译者:hhxx2015
作者: Sean Robertson
我们将构建和训练字符级RNN来对单词进行分类。 字符级RNN将单词作为一系列字符读取,在每一步输出预测和“隐藏状态”,将其先前的隐藏状态输入至下一时刻。 我们将最终时刻输出作为预测结果,即表示该词属于哪个类。
具体来说,我们将在18种语言构成的几千个姓氏的数据集上训练模型,根据一个单词的拼写预测它是哪种语言的姓氏:
代码语言:javascript复制$ python predict.py Hinton
(-0.47) Scottish
(-1.52) English
(-3.57) Irish
$ python predict.py Schmidhuber
(-0.19) German
(-2.48) Czech
(-2.68) Dutch
阅读建议:
我默认你已经安装好了PyTorch,熟悉Python语言,理解“张量”的概念:
- https://pytorch.org/ PyTorch安装指南
- Deep Learning with PyTorch: A 60 Minute Blitz PyTorch入门
- Learning PyTorch with Examples 一些PyTorch的例子
- PyTorch for Former Torch Users Lua Torch 用户参考
事先学习并了解RNN的工作原理对理解这个例子十分有帮助:
- The Unreasonable Effectiveness of Recurrent Neural Networks shows a bunch of real life examples
- Understanding LSTM Networks is about LSTMs specifically but also informative about RNNs in general
阅读全文/改进本文