华为的哪吒模型已经面世有一阵子了,而网上一直没有关于哪吒模型的实践文章,所以我打算通过这份指南教会你如何使用nezha进行文本分类。(官网上有一份文本分类的示例代码,但是上千行的代码实在是不利用快速上手)
数据集:
这里使用的是ChnSentiCorp_htl_all数据集,有7000 多条酒店评论数据,其中5000 多条正向评论,2000 多条负向评论。
代码语言:javascript复制1,我们住的三人间,房间很宽敞,卫生间不大,但都很干净。宾馆餐厅很实惠,我们只要能回宾馆,就在餐厅用餐。周边交通非常方便,基本上一部车就到各个景点,走到什刹海等处也很近。最近宾馆在外部装修,不过不影响内部卫生和晚间休息。总体来说,我们住的都挺满意。
1,绝对是超三星标准,地处商业区,购物还是很方便的,对门有家羊杂店,绝对正宗。除了价格稍贵,总体还是很满意的
1,前台服务较差,不为客户着想。房间有朋友来需要打扫,呼叫了两个小时也未打扫。房间下水道臭气熏天,卫生间漏水堵水。
0,标准间太差房间还不如3星的而且设施非常陈旧.建议酒店把老的标准间从新改善.
0,服务态度极其差,前台接待好象没有受过培训,连基本的礼貌都不懂,竟然同时接待几个客人;大堂副理更差,跟客人辩解个没完,要总经理的电话投诉竟然都不敢给。要是没有作什么亏心事情,跟本不用这么怕。
0,地理位置还不错,到哪里都比较方便,但是服务不象是豪生集团管理的,比较差。下午睡了一觉并洗了一个澡,本来想让酒店再来打扫一下,所以,打开了,请打扫的服务灯,可是到晚上回酒店,发现打扫得服务灯被关掉了,而房间还是没有打扫过。
主要代码:
代码语言:javascript复制from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, TensorDataset
import random
from NEZHA.model_nezha import BertConfig, BertForSequenceClassification
from NEZHA import nezha_utils
import numpy as np
from transformers import BertTokenizer, AdamW, AutoModel, AutoTokenizer, AutoModelForSequenceClassification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 1
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if __name__ == '__main__':
# tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='NEZHA/nezha-base-wwm', do_lower_case=True)
data = []
with open('data.txt', 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
label, text = int(line[0]), line[2:]
text = tokenizer.encode_plus(text, max_length=128, padding='max_length', truncation=True)
data.append((text, label))
random.shuffle(data)
train_data = data[:int(len(data)*0.8)]
test_data = data[len(train_data):]
input_ids_train = torch.LongTensor([each[0]['input_ids'] for each in train_data]).to(device)
token_type_ids_train = torch.LongTensor([each[0]['token_type_ids'] for each in train_data]).to(device)
attention_mask_train = torch.LongTensor([each[0]['attention_mask'] for each in train_data]).to(device)
label_train = torch.LongTensor([each[1] for each in train_data]).to(device)
input_ids_test = torch.LongTensor([each[0]['input_ids'] for each in test_data]).to(device)
token_type_ids_test = torch.LongTensor([each[0]['token_type_ids'] for each in test_data]).to(device)
attention_mask_test = torch.LongTensor([each[0]['attention_mask'] for each in test_data]).to(device)
label_test = torch.LongTensor([each[1] for each in test_data]).to(device)
train_dataset = TensorDataset(input_ids_train, token_type_ids_train, attention_mask_train, label_train)
test_dataset = TensorDataset(input_ids_test, token_type_ids_test, attention_mask_test, label_test)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)
Bert_config = BertConfig.from_json_file('NEZHA/nezha-base-wwm/bert_config.json')
model = BertForSequenceClassification(config=Bert_config, num_labels=2)
nezha_utils.torch_init_model(model, 'NEZHA/nezha-base-wwm/pytorch_model.bin')
# model = AutoModelForSequenceClassification.from_pretrained("hfl/chinese-roberta-wwm-ext")
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
for epoch in range(epochs):
print('epoch:', epoch)
model.train()
for input_ids, token_type_ids, attention_mask, labels in tqdm(train_loader):
optimizer.zero_grad()
# loss = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels).loss
loss = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
loss.backward()
optimizer.step()
model.eval()
total = 0
acc = 0
with torch.no_grad():
for input_ids, token_type_ids, attention_mask, labels in tqdm(test_loader):
# logits = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).logits
logits = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
pred = torch.argmax(logits, dim=-1)
total = pred.size(0)
acc = pred.eq(labels).sum().item()
print(acc / total)
这份代码十分简洁,因为我已经尽可能地去掉了一些无关紧要的东西。
在代码中,我将nezha和roberta做了对比分析(注释中的是roberta)。
我只跑了一个epoch,得到验证集的准确率是nezha 89多,roberta 87多,你也可以试着多跑几个epoch看看结果。不过在各大赛事中,大家普遍发现nezha会比roberta效果更好。
有一点值得注意的是,nezha的训练速度会比roberta慢很多,我实测下来大概要花三倍的时间,原因不详。
数据集、完整代码、预训练权重可以从这里获取:
https://github.com/luxuantao/huawei_nezha_practice