Github 项目推荐 | 类 Keras 的 PyTorch 深度学习框架 —— PyToune

2018-03-28 10:04:28 浏览数 (1)

PyToune 是一个类 Keras 的 Pytorch 深度学习框架,可用来处理训练神经网络所需的大部分模板代码。

用 PyToune 你可以:

  • 更容易地训练模型
  • 用回调来保存你最好的模型,执行 early stopping 方法等

Pytoune 官方页面:http://pytoune.org/

Pytoune Github 页面:https://github.com/GRAAL-Research/pytoune

Pytoune 兼容 PyTorch >= 0.3.0 版本和 Python >= 3.5 版本。

入门:快速上手 PyToune

PyToune 的核心数据结构是一种 Model,一种训练你的神经网络的方法。创建 PyToune 的方法和平常创建 PyTorch 模块(神经网络)的方式一样,但是你花时间去训练它,将其反馈到 PyToune 模型中,它会处理所有的步骤、统计数据、回调,就像 Keras 那样。

下面是个示例:

代码语言:javascript复制
# Import the PyToune Model and define a toy dataset
from pytoune.framework import Model

num_train_samples = 800
train_x = torch.rand(num_train_samples, num_features)
train_y = torch.rand(num_train_samples, 1)

num_valid_samples = 200
valid_x = torch.rand(num_valid_samples, num_features)
valid_y = torch.rand(num_valid_samples, 1)

创建你自己的 PyTorch 神经网络,一个损失函数和优化器:

代码语言:javascript复制
pytorch_module = torch.nn.Linear(num_features, 1)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(pytorch_module.parameters(), lr=1e-3)

你可以用 PyToune 非常容易地训练神经网络:

代码语言:javascript复制
model = Model(pytorch_module, optimizer, loss_function)
model.fit(
    train_x, train_y,
    validation_x=valid_x,
    validation_y=valid_y,
    epochs=num_epochs,
    batch_size=batch_size
  )

这与 Keras 中的 model.compile 函数非常相似:

代码语言:javascript复制
# Keras way to compile and train
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=32)

你可以使用 PyToune 模型的评估方法评估你的网络的性能:

代码语言:javascript复制
loss_and_metrics = model.evaluate(x_test, y_test)

或者只预测新数据:

代码语言:javascript复制
predictions = model.predict(x_test)

正如你所见,PyToune 受到 Keras 很多启发,详细信息,请参阅 PyToune.org 上的 PyToune 文档。

安装

在使用 PyToune 之前,你应该先装上 PyTorch 0.3.0。

安装稳定的 PyToune 版本:

代码语言:javascript复制
pip install pytoune

安装最新的 PyToune:

代码语言:javascript复制
pip install -U git https://github.com/GRAAL-Research/pytoune.git

为什么叫 PyToune

PyToune(或 Québécois 的 pitoune)曾指代的是河流里的原木,用河流运输原木是非常有效的一种运输方式。PyToune 的作者希望 PyToune 能够帮助开发者更加方便地训练神经网络模型,就像「pitoune」那样。

0 人点赞