PyTorch 1.0 中文官方教程:torch.nn 到底是什么?

2022-05-07 14:16:00 浏览数 (1)

译者:lhc741

作者:Jeremy Howard,fast.ai。感谢Rachel Thomas和Francisco Ingham的帮助和支持。

我们推荐使用notebook来运行这个教程,而不是脚本,点击这里下载notebook(.ipynb)文件。

Pytorch提供了torch.nn、torch.optim、Dataset和DataLoader这些设计优雅的模块和类以帮助使用者创建和训练神经网络。 为了最大化利用这些模块和类的功能,并使用它们做出适用于你所研究问题的模型,你需要真正理解他们是如何工作的。 为了做到这一点,我们首先基于MNIST数据集训练一个没有任何特征的简单神经网络。 最开始我们只会用到PyTorch中最基本的tensor功能,然后我们将会逐渐的从torch.nntorch.optimDatasetDataLoader中选择一个特征加入到模型中,来展示新加入的特征会对模型产生什么样的效果,以及它是如何使模型变得更简洁或更灵活。

在这个教程中,我们假设你已经安装好了PyTorch,并且已经熟悉了基本的tensor运算。(如果你熟悉Numpy的数组运算,你将会发现这里用到的PyTorch tensor运算和numpy几乎是一样的)

MNIST数据安装

我们将要使用经典的MNIST数据集,这个数据集由手写数字(0到9)的黑白图片组成。

我们将使用pathlib来处理文件路径的相关操作(python3中的一个标准库),使用request来下载数据集。 我们只会在用到相关库的时候进行引用,这样你就可以明确在每个操作中用到了哪些库。

代码语言:javascript复制
from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

阅读全文/改进本文

0 人点赞