作者 | DR. VAIBHAV KUMAR 编译 | VK 来源 | Analytics In Diamag
【导读】文本分类是自然语言处理的重要应用之一。在机器学习中有多种方法可以对文本进行分类。但是这些分类技术大多需要大量的预处理和大量的计算资源。
在这篇文章中,我们使用PyTorch来进行多类文本分类,因为它有如下优点:
- PyTorch提供了一种强大的方法来实现复杂的模型体系结构和算法,其预处理量相对较少,计算资源(包括执行时间)的消耗也较少。
- PyTorch的基本单元是张量,它具有在运行时改变架构和跨gpu分布训练的优点。
- PyTorch提供了一个名为TorchText的强大库,其中包含用于预处理文本的脚本和一些流行的NLP数据集的源代码。
在本文中,我们将使用TorchText演示多类文本分类,TorchText是PyTorch中一个强大的自然语言处理库。
对于这种分类,将使用由EmbeddingBag层和线性层组成的模型。EmbeddingBag通过计算嵌入的平均值来处理长度可变的文本条目。
这个模型将在DBpedia数据集上进行训练,其中文本属于14个类。训练成功后,模型将预测输入文本的类标签。
DBpedia数据集
DBpedia是自然语言处理领域中流行的基准数据集。它包含14个类别的文本,如公司、教育机构、艺术家、电影等。
它实际上是从维基百科项目创建的信息中提取的结构化内容集。TorchText提供的DBpedia数据集有63000个属于14个类的文本实例。它包括5600个训练实例和70000个测试实例。
用TorchText实现文本分类
首先,我们需要安装最新版本的TorchText。
代码语言:javascript复制!pip install torchtext==<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0.4
</span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;">
之后,我们将导入所有必需的库。
代码语言:javascript复制<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import torch
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import torchtext
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">from torchtext.datasets <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import text_classification
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import os
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import torch.nn <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">as nn
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import torch.nn.functional <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">as F
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">from torch.utils.data <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import DataLoader
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import time
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">from torch.utils.data.dataset <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import random_split
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import re
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">from torchtext.data.utils <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import ngrams_iterator
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">from torchtext.data.utils <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import get_tokenizer
</span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;">
在下一步中,我们将定义ngrams和batch大小。ngrams特征用于捕获有关本地语序的重要信息。
我们使用bigram,数据集中的示例文本将是单个单词加上bigrams字符串的列表。
代码语言:javascript复制NGRAMS = <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">2
BATCH_SIZE = <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">16
</span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;">
现在,我们将读取TorchText提供的DBpedia数据集。
代码语言:javascript复制<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">if <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">not os.path.isdir(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'./.data'):
os.mkdir(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'./.data')
train_dataset, test_dataset = text_classification.DATASETS[<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'DBpedia'](
root=<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'./.data', ngrams=NGRAMS, vocab=<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">None)
</span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;">
下载数据集后,我们将验证下载数据集的长度和标签数量。
代码语言:javascript复制print(len(train_dataset))
print(len(test_dataset))
代码语言:javascript复制print(len(train_dataset.get_labels()))
print(len(test_dataset.get_labels()))
我们将使用CUDA架构来加速运行和执行。
代码语言:javascript复制device = torch.device(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">"cuda" <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">if torch.cuda.is_available() <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">else <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">"cpu")
device
</span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;">
在下一步中,我们将定义分类的模型。
代码语言:javascript复制<span class="hljs-class" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">class <span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">TextSentiment<span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(nn.Module):
<span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">def <span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">__init__<span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(self, vocab_size, embed_dim, num_class):
super().__init__()
self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">True)
self.fc = nn.Linear(embed_dim, num_class)
self.init_weights()
<span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">def <span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">init_weights<span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(self):
initrange = <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0.5
self.embedding.weight.data.uniform_(-initrange, initrange)
self.fc.weight.data.uniform_(-initrange, initrange)
self.fc.bias.data.zero_()
<span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">def <span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">forward<span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(self, text, offsets):
embedded = self.embedding(text, offsets)
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">return self.fc(embedded)
print(model)
</span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-class" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;">
现在,我们将初始化超参数并定义函数以生成训练batch。
代码语言:javascript复制VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">32
NUN_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)
<span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">def <span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">generate_batch<span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(batch):
label = torch.tensor([entry[<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0] <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">for entry <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">in batch])
text = [entry[<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1] <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">for entry <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">in batch]
offsets = [<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0] [len(entry) <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">for entry <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">in text]
offsets = torch.tensor(offsets[:<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">-1]).cumsum(dim=<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0)
text = torch.cat(text)
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">return text, offsets, label
</span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;">
在下一步中,我们将定义用于训练和测试模型的函数。
代码语言:javascript复制<span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">def <span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">train_func<span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(sub_train_):
<span class="hljs-comment" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(128, 128, 128); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"># 训练模型
train_loss = <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0
train_acc = <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0
data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">True,
collate_fn=generate_batch)
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">for i, (text, offsets, cls) <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">in enumerate(data):
optimizer.zero_grad()
text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
output = model(text, offsets)
loss = criterion(output, cls)
train_loss = loss.item()
loss.backward()
optimizer.step()
train_acc = (output.argmax(<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1) == cls).sum().item()
<span class="hljs-comment" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(128, 128, 128); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"># 调整学习率
scheduler.step()
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">return train_loss / len(sub_train_), train_acc / len(sub_train_)
<span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">def <span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">test<span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(data_):
loss = <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0
acc = <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0
data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">for text, offsets, cls <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">in data:
text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">with torch.no_grad():
output = model(text, offsets)
loss = criterion(output, cls)
loss = loss.item()
acc = (output.argmax(<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1) == cls).sum().item()
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">return loss / len(data_), acc / len(data_)
</span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-comment" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(128, 128, 128); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-comment" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(128, 128, 128); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;">
我们将用5个epoch训练模型。
代码语言:javascript复制N_EPOCHS = <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">5
min_valid_loss = float(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'inf')
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1, gamma=<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0.9)
train_len = int(len(train_dataset) * <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0.95)
sub_train_, sub_valid_ = <br> random_split(train_dataset, [train_len, len(train_dataset) - train_len])
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">for epoch <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">in range(N_EPOCHS):
start_time = time.time()
train_loss, train_acc = train_func(sub_train_)
valid_loss, valid_acc = test(sub_valid_)
secs = int(time.time() - start_time)
mins = secs / <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">60
secs = secs % <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">60
print(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Epoch: %d' %(epoch <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1), <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">" | time in %d minutes, %d seconds" %(mins, secs))
print(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">f'tLoss: <span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{train_loss:<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">.4f}(train)t|tAcc: <span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{train_acc * <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">100:<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">.1f}%(train)')
print(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">f'tLoss: <span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{valid_loss:<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">.4f}(valid)t|tAcc: <span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{valid_acc * <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">100:<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">.1f}%(valid)')
</span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;">
下一步,我们将在测试数据集上测试我们的模型,并检查模型的准确性。
代码语言:javascript复制print(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Checking the results of test dataset...')
test_loss, test_acc = test(test_dataset)
print(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">f'tLoss: <span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{test_loss:<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">.4f}(test)t|tAcc: <span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{test_acc * <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">100:<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">.1f}%(test)')
</span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-subst" style="font-size: inherit; color: inherit; line-height: inherit; margin: 0px; padding: 0px; word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;">
现在,我们将在单个新闻文本字符串上测试我们的模型,并预测给定新闻文本的类标签。
代码语言:javascript复制DBpedia_label = {<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Company',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'EducationalInstitution',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">2: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Artist',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">3: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Athlete',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">4: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'OfficeHolder',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">5: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'MeanOfTransportation',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">6: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Building',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">7: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'NaturalPlace',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">8: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Village',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">9: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Animal',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">10: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Plant',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">11: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Album',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">12: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'Film',
<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">13: <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'WrittenWork'}
<span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">def <span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">predict<span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(text, model, vocab, ngrams):
tokenizer = get_tokenizer(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">"basic_english")
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">with torch.no_grad():
text = torch.tensor([vocab[token]
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">for token <span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">in ngrams_iterator(tokenizer(text), ngrams)])
output = model(text, torch.tensor([<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0]))
<span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">return output.argmax(<span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1).item() <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1
vocab = train_dataset.get_vocab()
model = model.to(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">"cpu")
/*
* 提示:该行代码过长,系统自动注释不进行高亮。一键复制会移除系统注释
* </span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-params" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(255, 152, 35); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-title" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(165, 218, 45); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-keyword" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-function" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(248, 35, 117); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;">
*/
现在,我们将从测试数据中随机抽取一些文本并检查预测的类标签。
第一个预测:
代码语言:javascript复制ex_text_str = <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">"Brekke Church (Norwegian: Brekke kyrkje) is a parish church in Gulen Municipality in Sogn og Fjordane county, Norway. It is located in the village of Brekke. The church is part of the Brekke parish in the Nordhordland deanery in the Diocese of Bjørgvin. The white, wooden church, which has 390 seats, was consecrated on 19 November 1862 by the local Dean Thomas Erichsen. The architect Christian Henrik Grosch made the designs for the church, which is the third church on the site."
print(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">"This is a %s news" �pedia_label[predict(ex_text_str, model, vocab, <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">2)])
</span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;">
第二个预测:
代码语言:javascript复制ex_text_str2 = <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">"Cerithiella superba is a species of very small sea snail, a marine gastropod mollusk in the family Newtoniellidae. This species is known from European waters. It was described by Thiele, 1912."
print(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">"This text belongs to %s class" �pedia_label[predict(ex_text_str2, model, vocab, <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">2)])
</span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;">
第三个预测:
代码语言:javascript复制ex_text_str3 = <span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">" Nithari is a village in the western part of the state of Uttar Pradesh India bordering on New Delhi. Nithari forms part of the New Okhla Industrial Development Authority's planned industrial city Noida falling in Sector 31. Nithari made international news headlines in December 2006 when the skeletons of a number of apparently murdered women and children were unearthed in the village."
print(<span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">"This text belongs to %s class" �pedia_label[predict(ex_text_str3, model, vocab, <span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;" style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">2)])
</span class="hljs-number" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(174, 135, 250); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;"></span class="hljs-string" style="font-size: inherit; line-height: inherit; margin: 0px; padding: 0px; color: rgb(238, 220, 112); word-wrap: inherit !important; word-break: inherit !important;">
因此,通过这种方式,我们使用TorchText实现了多类文本分类。
这是一种简单易行的文本分类方法,使用这个PyTorch库只需很少的预处理量。在5600个训练实例上训练模型只花了不到5分钟。
通过将ngram从2更改为3来重新运行这些代码并查看结果是否有改进。同样的实现也可以在TorchText提供的其他数据集上实现。
参考文献:
- ‘Text Classification with TorchText’, PyTorch tutorial
- Allen Nie, ‘A Tutorial on TorchText’