PyTorch 1.0 中文官方教程:使用字符级别特征的 RNN 网络进行姓氏分类

2022-05-07 14:01:06 浏览数 (1)

译者:hhxx2015

作者: Sean Robertson

我们将构建和训练字符级RNN来对单词进行分类。 字符级RNN将单词作为一系列字符读取,在每一步输出预测和“隐藏状态”,将其先前的隐藏状态输入至下一时刻。 我们将最终时刻输出作为预测结果,即表示该词属于哪个类。

具体来说,我们将在18种语言构成的几千个姓氏的数据集上训练模型,根据一个单词的拼写预测它是哪种语言的姓氏:

代码语言:javascript复制
$ python predict.py Hinton
(-0.47) Scottish
(-1.52) English
(-3.57) Irish

$ python predict.py Schmidhuber
(-0.19) German
(-2.48) Czech
(-2.68) Dutch

阅读建议:

我默认你已经安装好了PyTorch,熟悉Python语言,理解“张量”的概念:

  • https://pytorch.org/ PyTorch安装指南
  • Deep Learning with PyTorch: A 60 Minute Blitz PyTorch入门
  • Learning PyTorch with Examples 一些PyTorch的例子
  • PyTorch for Former Torch Users Lua Torch 用户参考

事先学习并了解RNN的工作原理对理解这个例子十分有帮助:

  • The Unreasonable Effectiveness of Recurrent Neural Networks shows a bunch of real life examples
  • Understanding LSTM Networks is about LSTMs specifically but also informative about RNNs in general

阅读全文/改进本文

0 人点赞