Keras 3.0一统江湖!大更新整合PyTorch、JAX,全球250万开发者在用了

2023-11-30 17:42:25 浏览数 (1)

转载新智元报道

编辑:编辑部

【新智元导读】今天,备受广大开发者欢迎的深度学习框架Keras,正式更新了3.0版本,实现了对PyTorch和JAX的支持,同时性能提升,还能轻松实现大规模分布式训练。

刚刚,Keras 3.0正式发布!

经过5个月的公开Beta测试,深度学习框架Keras 3.0终于面向所有开发者推出。

全新的Keras 3对Keras代码库进行了完全重写,可以在JAX、TensorFlow和PyTorch上运行,能够解锁全新大模型训练和部署的新功能。

「Keras之父」François Chollet在最新版本发布之前,也是做了多次预告。目前,有250 万的开发者都在使用Keras框架。

重磅消息:我们刚刚发布了 Keras 3.0! 在 JAX、TensorFlow 和 PyTorch 上运行 Keras 使用 XLA 编译更快地训练 通过新的 Keras 分发 API 解锁任意数量的设备和主机的训练运行 它现在在 PyPI 上上线

开发者甚至可以将Keras用作低级跨框架语言,以开发自定义组件,例如层、模型或指标。

只需一个代码库,这些组件便可用在JAX、TensorFlow、PyTorch中的原生工作流。

再次让Keras成为多后端

最初的Keras可以在Theano、TensorFlow、CNTK,甚至MXNet上运行。

2018年,由于Theano和CNTK已停止开发,TensorFlow似乎成为了唯一可行的选择,于是,Keras将开发重点放在了TensorFlow上。

而到了今年,情况发生了变化。

根据2023年StackOverflow开发者调查,和2022年Kaggle机器学习和数据科学调查等显示,

TensorFlow拥有55%到60%的市场份额,是ML在生产领域的首选。

而PyTorch拥有40%到45%的市场份额,是ML在研究领域的首选。

与此同时,JAX虽然市场份额要小得多,但已被Google DeepMind、Midjourney、Cohere等生成式AI领域的顶级参与者所接受。

于是,开发团队对Keras代码库进行了完全重写,新诞生的Keras 3.0基于模块化后端架构进行了重构,有能力在任意框架上运行。

同时新的Keras也保证了兼容性,比如在使用TensorFlow后端时,你可以简单地使用 import keras_core as keras 来替换from tensorflow import keras

——现有的代码将毫无问题地运行,而且由于 XLA 编译,通常性能略有提高。

Keras vs. TensorFlow

小编在这里给大家举一个例子,说明如何从TensorFlow的代码转换成Keras的形式。

TensorFlow Core Implementation

Keras implementation

相比之下,我们可以清楚地看到Keras带来的简洁性。

TensorFlow可以对每个变量进行更精细的控制,而Keras提供了易用性和快速原型设计的能力。

对于一些开发者来说,Keras省去了开发中的一些麻烦,降低了编程复杂性,节省了时间成本。

Keras 3.0新特性

Keras最大的优势在于,通过出色的UX、API设计和可调试性可实现高速开发。

而且,它还是一个经过实战考验的框架,并为世界上一些最复杂、最大规模的ML系统提供支持,比如Waymo自动驾驶车、YouTube推荐引擎。

那么,使用新的多后端Keras 3还有哪些额外的优势呢?

- 始终为模型获得最佳性能。

在基准测试中,发现JAX通常在GPU、TPU和CPU上提供最佳的训练和推理性能,但结果因模型而异,因为非XLA TensorFlow在GPU上偶尔会更快。

它能够动态选择为模型提供最佳性能的后端,而无需对代码进行任何更改,这意味着开发者可以以最高效率进行训练和服务。

- 为模型解锁生态系统可选性。

任何Keras 3模型都可以作为PyTorch模块实例化,可以作为 TensorFlow SavedModel 导出,也可以作为无状态 JAX 函数实例化。

这意味着开发者可以将Keras 3模型与PyTorch生态系统包,全系列TensorFlow部署和生产工具(如TF-Serving,TF.js和TFLite)以及JAX大规模TPU训练基础架构一起使用。使用 Keras 3 API 编写一个 model.py ,即可访问 ML 世界提供的一切。

- 利用JAX的大规模模型并行性和数据并行性。

Keras 3包含一个全新的分布式 API,即keras.distribution 命名空间,目前已在JAX后端实现(即将在TensorFlow和PyTorch后端实现)。

通过它,可以在任意模型尺度和聚类尺度上轻松实现模型并行、数据并行以及两者的组合。由于它能将模型定义、训练逻辑和分片配置相互分离,因此使分发工作流易于开发和维护。

- 最大限度地扩大开源模型版本的覆盖面。

想要发布预训练模型?想让尽可能多的人能够使用它吗?如果你在纯TensorFlow或PyTorch中实现它,它将被大约一半的社区使用。

如果你在Keras 3中实现了它,那么任何人都可以立即使用它,无论他们选择的框架是什么(即使他们自己不是Keras用户)。在不增加开发成本的情况下实现2倍的影响。

- 使用来自任何来源的数据管道。

Keras 3 / fit() / evaluate() predict() 例程与 tf.data.Dataset 对象、PyTorch DataLoader 对象、NumPy 数组、Pandas 数据帧兼容——无论你使用什么后端。你可以在 PyTorch DataLoader 上训练 Keras 3 TensorFlow 模型,也可以在tf.data.Dataset上训练Keras 3 PyTorch模型。

预训练模型

现在,开发者即可开始使用Keras 3的各种预训练模型。

所有40个Keras应用程序模型( keras.applications 命名空间)在所有后端都可用。KerasCV和KerasNLP中的大量预训练模型也适用于所有后端。

其中包括:

- BERT - OPT - Whisper - T5 - Stable Diffusion - YOLOv8

跨框架开发

Keras 3能够让开发者创建在任何框架中都相同的组件(如任意自定义层或预训练模型),它允许访问适用于所有后端的 keras.ops 命名空间。

Keras 3包含NumPy API的完整实现,——不是「类似 NumPy」,而是真正意义上的 NumPy API,具有相同的函数和参数。比如 ops.matmul、ops.sum、ops.stack、ops.einsum 等函数。

Keras 3还包含NumPy中没有的,一组特定于神经网络的函数,例如 ops.softmax, ops.binary_crossentropy, ops.conv等。

另外,只要开发者使用的运算,全部来自于keras.ops ,那么自定义的层、损失函数、优化器就可以跨越JAX、PyTorch和TensorFlow,使用相同的代码。

开发者只需要维护一个组件实现,就可以在所有框架中使用它。

Keras架构

下面,我们来稍稍理解一下Keras的机制和架构。

在Keras中,Sequential 和 Model 类是模型构建的核心,为组装层和定义计算图提供了一个框架。

Sequential 是层的线性堆栈。它是Model 的子类,专为简单情况而设计,模型由具有一个输入和一个输出的线性层堆栈组成。

Sequential 类有以下一些主要特点:

简单性:只需按照要执行的顺序列出图层即可。 自动前向传递:当向Sequential模型添加层时,Keras会自动将每一层的输出连接到下一层的输入,从而创建前向传递,而无需手动干预。 内部状态管理:Sequential管理层的状态(如权重和偏置)和计算图。调用compile时,它会通过指定优化器、损失函数和指标来配置学习过程。 训练和推理:Sequential类提供了fit、evaluate和predict等方法,分别用于训练、评估和预测模型。这些方法在内部处理训练循环和推理过程。

Model类与函数式API一起使用,提供了比Sequential更大的灵活性。它专为更复杂的架构而设计,包括具有多个输入或输出、共享层和非线性拓扑的模型。

Model 类的主要特点有:

层图:Model允许创建层图,允许一个层连接到多个层,而不仅仅是上一个层和下一个层。 显式输入和输出管理:在函数式API中,可以显式定义模型的输入和输出。相比于Sequential,可以允许更复杂的架构。 连接灵活性:Model类可以处理具有分支、多个输入和输出以及共享层的模型,使其适用于简单前馈网络以外的广泛应用。 状态和训练管理:Model类管理所有层的状态和训练过程,同时提供了对层的连接方式,以及数据在模型中的流动方式的更多控制。

Model 类和 Sequential类都依赖于以下机制:

层注册:在这些模型中添加层时,层会在内部注册,其参数也会添加到模型的参数列表中。 自动微分:在训练过程中,Keras使用后端引擎(TensorFlow等)提供的自动微分来计算梯度。这一过程对用户而言是透明的。 后端执行:实际计算(如矩阵乘法、激活等)由后端引擎处理,后端引擎执行模型定义的计算图。 序列化和反序列化:这些类包括保存和加载模型的方法,其中涉及模型结构和权重的序列化。

从本质上讲,Keras中的Model和Sequential类抽象掉了定义和管理计算图所涉及的大部分复杂性,使用户能够专注于神经网络的架构,而不是底层的计算机制。

Keras 自动处理各层如何相互连接、数据如何在网络中流动以及如何进行训练和推理操作等错综复杂的细节。

对于Keras的大更新,有网友使用下面的图片表达自己的看法:

虽然小编也不知道为什么要炸TensorFlow。

还有网友表示刚好可以用上:

另一位网友发来贺电,「在PyTorch之上使用Keras是一项了不起的成就!」

当然也有网友唱反调,「我想知道为什么有人会使用Keras Torch而不是普通的 Torch,因为Torch与Tensorflow不同,它有一组很好的API」。

此时Tensorflow的内心:啊对对对,你们说得都对。

参考资料:

https://twitter.com/fchollet/status/1729512791894012011

https://keras.io/keras_3/

0 人点赞