TensorFlow被废了,谷歌家的新王储JAX到底是啥?

2022-09-20 19:33:14 浏览数 (1)

这几天各大科技媒体都在唱衰TensorFlow,鼓吹JAX。恰好前两个月我都在用JAX,算是从JAX新人进阶为小白,过来吹吹牛。

放弃TensorFlow,谷歌全面转向JAX

TensorFlow,危!抛弃者正是谷歌自己

吃瓜群众都在疯狂吐槽TensorFlow的API多混乱,PyTorch多好用,但是好像,并没有多少人真正说到JAX。

JAX:自动微分 NumPy JIT

JAX到底是啥?简单说,JAX是一种自动微分的NumPy。所以JAX并不是一个深度学习框架,而是一个科学计算框架。深度学习是JAX功能的一个子集。

既然是NumPy,那就可以用NumPy接口做各类科学计算。

而且还带自动微分,科学计算世界中,微分是最常用的一种计算。JAX的自动微分包含了前向微分、反向微分等各种接口。反正各类花式微分,几乎都可以用JAX实现。

除了"NumPy" "自动微分",JAX还有几个其他的功能:

JIT编译

将NumPy接口写的计算转成高效的二进制代码,可以在CPU/GPU/TPU上获得极高加速比。JIT编译主要还是基于XLA(accelerated linear algebra)。XLA是一种编译器,可以将TF/JAX的代码在CPU/GPU/TPU上加速。

说到JAX速度快,主要就靠XLA!

并行化

比起简单的NumPy,JAX提供了大量接口做并行。无论是tf还是torch,一个简单的并行方法是:batch size。JAX用 vmap 做并行, 用户只用实现一条数据的处理,JAX帮我们将做拓展,可以拓展到batch size大小。vmap 的思想与 Spark 中的 map 一样。用户关注 map 里面的一条数据的处理方法,JAX 帮我们做并行化。

函数式编程

到这就不得不提JAX的函数式编程。函数式编程相对“面向对象”(Object Oriented)就难很多了。毕竟,绝大多数中国程序员都没有系统学习过函数式编程。

JAX是纯函数式的。

第一让人不适应的就是数据的不可变(Immutable)。不能原地改数据,只能创建新数据。

第二就是各类闭包。“闭包”这个名字就很抽象,更不用说真正写起来了。

然后就是partial

这些东西在torch用户那里可能一辈子都用不到。

来到JAX世界,你都会怀疑自己到底学没学过Python。

深度学习框架

JAX并不是一个深度学习框架。想要做深度学习,还要再在JAX上套一层。

要想在JAX上实现一个全连接网络,要 np.dot(w, x) b。竟然没有现成的 nn.Dense 或者 nn.Linear

于是有了DeepMind的 haiku ,Google的 flax,和其他各种各样的库。

JAX是纯函数的,代码写起来和tf、torch也不太一样。没有了 .fit() 这样傻瓜式的接口,没有 MSELoss 这样的损失函数。而且要适应数据的不可变:模型参数先初始化init,才能使用。

不过,flax 和 haiku 也有不少市场了。大名鼎鼎的AlphaFold就是用 haiku 写的。

但大家都在学JAX

JAX到底好不好我不敢说。但是大家都在学它。看看PyTorch刚发布的 torchfunc,里面的vmap就是学得JAX。还有各个框架都开始提供的前向微分 jvp,都是JAX的影子。

0 人点赞