MXNet 宣布支持 Keras 2,可更加方便快捷地实现 CNN 及 RNN 分布式训练

2018-07-26 10:59:30 浏览数 (1)

AI 研习社按,近期,AWS 表示 MXNet 支持 Keras 2,开发者可以使用 Keras-MXNet 更加方便快捷地实现 CNN 及 RNN 分布式训练。AI 研习社将 AWS 官方博文编译如下。

Keras-MXNet 深度学习后端(https://github.com/awslabs/keras-apache-mxnet)现在可用,这要归功于 Keras 和 Apache MXNet(孵化)开源项目的贡献者。Keras 是用 Python 编写的高级神经网络 API,以快速简单的 CNN 和 RNN 原型而闻名。

Keras 开发人员现在可以使用高性能 MXNet 深度学习引擎进行 CNN 和递归神经网络 RNN 的分布式训练。通过更新几行代码,Keras 开发人员可以使用 MXNet 的多 GPU 分布式训练功能来提高训练速度。保存 MXNet 模型是该发行版本一个极具价值的功能。开发者可以在 Keras 中进行设计,使用 Keras-MXNet 进行训练,并且在生产中用 MXNet 进行大规模推算。

用 Keras 2 和 MXNet 做分布式训练

本文介绍如何安装 Keras-MXNet 并演示如何训练 CNN 和 RNN。如果您之前尝试过使用其他深度学习引擎做分布式训练,那么您应该知道这过程可能很乏味而且很困难。现在,让我们看看用 Keras-MXNet 训练会怎样。

安装只需要几步

  • 部署 AWS Deep Learning AMI
  • 安装 Keras-MXNet
  • 配置 Keras-MXNet

1.部署 AWS Deep Learning AMI

按照此教程部署 AWS Deep Learning AMI(DLAMI)。要利用多 GPU 训练示例,请启动一个 p3.8xlarge 或类似的多 GPU 实例类型。

想要自己安装依赖来运行 CUDA,Keras,MXNet 和其他框架(比如 TensorFlow)? 请按照 Keras-MXNet 安装指南来安装(https://github.com/awslabs/keras-apache-mxnet/blob/master/docs/mxnet_backend/installation.md)。

2.安装 Keras-MXNet

将 Keras-MXnet 及其依赖项安装在您 DLAMI 上的 MXNet Conda 环境中。 由于它已经有Keras 1.0,所以你需要首先卸载它。登录您的 DLAMI 并运行以下命令:

代码语言:javascript复制
# Activate the MXNet Python 3 environment on the DLAMI
$ source activate mxnet_p36

# Install a dependency needed for Keras datasets
$ pip install h5py

# Uninstall older versions Keras-MXNet
$ pip uninstall keras-mxnet

# Install Keras-MXNet v2.1.6 
$ pip install keras-mxnet

Keras-MXnet 及其依赖现已安装在 DLAMI 的 MXNet Conda 环境中。

3.验证 Keras-MXNet 安装

使用以下方式运行 MXNet 后端来验证你的 Keras:

代码语言:javascript复制
$ python
>>>import keras as k

   Using MXNet backend

CNN 支持

现在让我们在 CIFAR-10 数据集(https://www.cs.toronto.edu/~kriz/cifar.html)上训练一个 ResNet 模型以确定 10 个分类:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。

我们可以使用 Keras-MXNet 存储库的示例的部分 Keras 2 脚本。用 MXNet 作为 Keras 的后端只需要对脚本进行非常少的更新。

首先从 Keras-MXNet 库文件中下载示例脚本。

代码语言:javascript复制
$ wget https://raw.githubusercontent.com/awslabs/keras-apache-mxnet/master/examples/cifar10_resnet_multi_gpu.py

该脚本调用 multi_gpu_model API 并传递要使用的 GPU 数量。

其次,在终端窗口中运行 nvidia-smi 以确定 DLAMI 上可用的 GPU 数量。 在下一步中,如果您有四个 GPU,您将按原样运行脚本,否则运行以下命令打开脚本进行编辑。

代码语言:javascript复制
$ vi cifar10_resnet_multi_gpu.py

该脚本以下行可以定义 GPU 的数量,如果有必要的话可以更新它:

代码语言:javascript复制
model = multi_gpu_model(model, gpus=4)

训练:

代码语言:javascript复制
$ python cifar10_resnet_multi_gpu.py

(可选)在训练运行期间,使用 nvidia-smi 命令检查 GPU 利用率和内存使用情况。

RNN 支持

Keras-MXNet 目前提供 RNN 实验性的支持。 在使用带有 MXNet 后端的 RNN 时存在一些限制。更多相关信息,请查阅 Keras-MXNet 文档。 这里的例子包括你需要的解决方法,以便使用 LSTM 层训练 IMDB 数据集。尽管有解决方法,但在多 GPU AMI 上训练此 RNN 将比你习惯的要容易和快速。

使用 imdb_lstm 示例脚本。 在嵌入层中传递输入长度,并按如下所示设置 unroll = True。

首先,在 DLAMI 的终端会话中,从 Keras-MXNet repo 文件夹下载示例脚本。

代码语言:javascript复制
$ wget https://raw.githubusercontent.com/awslabs/keras-apache-mxnet/master/examples/imdb_lstm.py

其次,打开脚本并跳转到下面一行来查看它:

代码语言:javascript复制
model.add(Embedding(max_features, 128, input_length=maxlen))

model.add(LSTM(128, unroll=True))

第三,示例脚本已被修改为与 MXNet 后端兼容,因此您可以运行它:

代码语言:javascript复制
$ python imdb_lstm.py

(可选)在训练运行期间,使用 nvidia-smi 命令检查 GPU 利用率和内存使用情况。 为此打开另一个终端会话。

Benchmarks

为帮助您评估不同 Keras 后端的性能,我们为 Keras-MXNet 添加了基准测试模块。通过在该表中描述的 CPU,单 GPU 和多 GPU 机器上使用各种模型和数据集,您可以看到 Keras-MXNet 具有更快的 CNN 训练速度,以及跨多个 GPU 的高效缩放, 这将显示在训练速度的条形图中。有关如何运行基准脚本并生成详细基准测试结果的信息,请参阅 Keras 基准测试自述文件。

基准配置:

  • Keras Version 2.1.6
  • MXNet Version 1.2.0
  • Image Data Format: Channel first

由于数据集图像本身较小,因此对 CIFAR10 数据集进行训练会导致子线性缩放。该数据集由 50,000 个尺寸为 32×32 像素的图像组成,传送这些小图像的通信开销高于从四个跳转到八个 GPU 所提供的计算能力。

MXNet 宣布支持 Keras 2,可更加方便快捷地实现 CNN 及 RNN 分布式训练

与 Keras-MXNet 的图像处理速度比较

MXNet 宣布支持 Keras 2,可更加方便快捷地实现 CNN 及 RNN 分布式训练

下一步?

尝试一些额外的 Keras-MXNet 教程或阅读发行说明中的详细信息。

更多资料

  • 保存 MXNet-Keras 模型 https://github.com/awslabs/keras-apache-mxnet/blob/master/docs/mxnet_backend/installation.md
  • 性能指南 https://github.com/awslabs/keras-apache-mxnet/blob/master/docs/mxnet_backend/performance_guide.md
  • 多 GPU 训练 https://github.com/awslabs/keras-apache-mxnet/blob/master/docs/mxnet_backend/multi_gpu_training.md
  • RNN 限制和解决方法 https://github.com/awslabs/keras-apache-mxnet/blob/master/docs/mxnet_backend/using_rnn_with_mxnet_backend.md
  • 发行说明 https://github.com/awslabs/keras-apache-mxnet/releases/tag/v2.1.6

Via:

https://aws.amazon.com/cn/blogs/machine-learning/apache-mxnet-incubating-adds-support-for-keras-2/

0 人点赞