Code: https://github.com/davidfan1224/CAIL2021_Multi-span_MRC . A walkthrough of the model code:
```python
# /*
# * @Author: Yue.Fan
# * @Date: 2022-03-23 11:35:37
# * @Last Modified by: Yue.Fan
# * @Last Modified time: 2022-03-23 11:35:37
# */
# from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel
from pytorch_pretrained_bert.configuration_bert import BertConfig
from pytorch_pretrained_bert.modeling_bert import BertLayer, BertPreTrainedModel, BertModel
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torchcrf import CRF
VERY_NEGATIVE_NUMBER = -1e29
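# VERY_NEGATIVE_NUMBER is added to positions that should be ignored before a softmax
# (see the doc_att masking in forward), pushing their attention weights towards zero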
class CailModel(BertPreTrainedModel):
def __init__(self, config, answer_verification=True, hidden_dropout_prob=0.3, need_birnn=False, rnn_dim=128):
super(CailModel, self).__init__(config)
self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.qa_dropout = nn.Dropout(config.hidden_dropout_prob)
# max_n_answers=3
self.num_answers = 4 # args.max_n_answers + 1
self.qa_outputs = nn.Linear(config.hidden_size*4, 2)
self.qa_classifier = nn.Linear(config.hidden_size, self.num_answers)
# self.apply(self.init_bert_weights)
self.answer_verification = answer_verification
head_num = config.num_attention_heads // 4
self.coref_config = BertConfig(num_hidden_layers=1, num_attention_heads=head_num,
hidden_size=config.hidden_size, intermediate_size=256 * head_num)
self.coref_layer = BertLayer(self.coref_config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
out_dim = config.hidden_size
self.need_birnn = need_birnn
# if False, do not use the BiLSTM layer
if need_birnn:
self.birnn = nn.LSTM(config.hidden_size, rnn_dim, num_layers=1, bidirectional=True, batch_first=True)
self.gru = nn.GRU(config.hidden_size, rnn_dim, num_layers=1, bidirectional=True, batch_first=True)
out_dim = rnn_dim * 2
self.hidden2tag = nn.Linear(out_dim, 2) # binary I/O classification
# self.crf = CRF(config.num_labels, batch_first=True)
self.crf = CRF(2, batch_first=True)
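# the 2-tag I/O CRF head is what enables multi-span answers: tokens inside any gold span are
# labelled I (1), the rest O (0); at inference, runs of consecutive I tags give the individual spans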
self.init_weights()
if self.answer_verification:
self.retionale_outputs = nn.Linear(config.hidden_size*4, 1)
self.unk_ouputs = nn.Linear(config.hidden_size, 1)
self.doc_att = nn.Linear(config.hidden_size*4, 1)
self.yes_no_ouputs = nn.Linear(config.hidden_size*4, 2)
# self.yes_no_ouputs_noAttention = nn.Linear(config.hidden_size, 2)
self.ouputs_cls_3 = nn.Linear(config.hidden_size*4, 3)
self.beta = 100
else:
# self.unk_yes_no_outputs_dropout = nn.Dropout(config.hidden_dropout_prob)
self.unk_yes_no_outputs = nn.Linear(config.hidden_size, 3)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None,
unk_mask=None, yes_mask=None, no_mask=None, answer_masks=None, answer_nums=None, label_ids=None):
# sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
# output_all_encoded_layers=True)
# the example shapes below assume batch_size=2, seq_len=512, hidden_dim=768
# sequence_output is a tuple of length 2
# sequence_output[0].shape=[2,512,768]
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
# print("sequence_output:", sequence_output[0].shape)
# print("pooled_output.shape:", pooled_output.shape)
sequence_output = sequence_output[1]
# [2, 512, 768]
sequence_output_IO = sequence_output[-1] # take the output of the last layer
# sequence_output:[2, 512, 768*4]
sequence_output = torch.cat((sequence_output[-4], sequence_output[-3], sequence_output[-2],
sequence_output[-1]), -1) # concatenate the last four BERT layers
if self.answer_verification:
batch_size = sequence_output.size(0)
seq_length = sequence_output.size(1)
hidden_size = sequence_output.size(2)
# [2*512, 3072]
sequence_output_matrix = sequence_output.view(batch_size*seq_length, hidden_size)
# [2*512 , 1]
rationale_logits = self.retionale_outputs(sequence_output_matrix)
# print(rationale_logits.shape)
# [2, 512]
rationale_logits = rationale_logits.view(batch_size, seq_length)
# [2, 512]
# this computes an attention between the question and the passage
rationale_logits = F.softmax(rationale_logits, dim=-1)
# [batch, seq, hidden] * [batch, seq_len, 1] = [batch, seq, hidden]
# [2, 512, 3072]
final_hidden = sequence_output*rationale_logits.unsqueeze(2)
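# each token's concatenated representation is rescaled by its rationale weight, so tokens the
# model regards as evidence contribute more to the start/end logits computed below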
# print(final_hidden.shape)
# [2*512, 3072]
sequence_output = final_hidden.view(batch_size*seq_length, hidden_size)
logits = self.qa_outputs(sequence_output).view(batch_size, seq_length, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
# [000,11111], 1 marks the passage
# [batch, seq_len] * [batch, seq_len]
rationale_logits = rationale_logits * attention_mask.float()
# [batch, seq_len, 1] [batch, seq_len]
start_logits = start_logits*rationale_logits
end_logits = end_logits*rationale_logits
if self.need_birnn:
self.birnn.flatten_parameters()
self.gru.flatten_parameters()
sequence_output_IO, _ = self.birnn(sequence_output_IO)
# sequence_output_IO, _ = self.gru(sequence_output_IO)
sequence_output_IO = self.dropout(sequence_output_IO)
# [2, 512, 2], a binary classification for every token
emissions = self.hidden2tag(sequence_output_IO)
# answers num
# [2, 4], classification over the number of answers
switch_logits = self.qa_classifier(pooled_output) # classify the number of answers from the [CLS] vector
# extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) # fp16 compatibility
# extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# sequence_output_sw = self.coref_layer(sequence_output_switch, extended_attention_mask)[0]
#
# switch_logits = self.qa_classifier(sequence_output_sw[:,0,:])
# unk
# [2, 1]
unk_logits = self.unk_ouputs(pooled_output)
# doc_attn
# [2*512, 1]
attention = self.doc_att(sequence_output)
# [2, 512]
attention = attention.view(batch_size, seq_length)
# this computes attention over the passage itself
# [2, 512]
attention = attention*token_type_ids.float() + (1-token_type_ids.float())*VERY_NEGATIVE_NUMBER
attention = F.softmax(attention, 1)
# [2, 512, 1]
attention = attention.unsqueeze(2)
# [2, 512, 1]*[2, 512, 3072] = [2, 512, 3072]
attention_pooled_output = attention*final_hidden
# [2, 3072]
attention_pooled_output = attention_pooled_output.sum(1)
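# attention_pooled_output is a single passage-level vector (an attention-weighted sum over the
# passage tokens) that feeds the yes/no classifier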
# variant: drop the attention
# attention_pooled_output = pooled_output
# yes_no_logits = self.yes_no_ouputs_noAttention(attention_pooled_output)
# [2, 2]
yes_no_logits = self.yes_no_ouputs(attention_pooled_output)
# [2, 1]
yes_logits, no_logits = yes_no_logits.split(1, dim=-1)
# unk_yes_no_logits = self.ouputs_cls_3(attention_pooled_output)
# unk_logits, yes_logits, no_logits = unk_yes_no_logits.split(1, dim=-1)
else:
# sequence_output = self.qa_dropout(sequence_output)
logits = self.qa_outputs(sequence_output)
# self attention
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
# answers num
switch_logits = self.qa_classifier(pooled_output) # classify the number of answers from the [CLS] vector
# # unk yes_no_logits
# pooled_output = self.unk_yes_no_outputs_dropout(pooled_output)
unk_yes_no_logits = self.unk_yes_no_outputs(pooled_output)
unk_logits, yes_logits, no_logits= unk_yes_no_logits.split(1, dim=-1)
# # [batch, 1]
# unk_logits = unk_logits.squeeze(-1)
# yes_logits = yes_logits.squeeze(-1)
# no_logits = no_logits.squeeze(-1)
# concatenate the token logits with the unknown / yes / no logits
# [2, 515]
# index 512 marks no answer, 513 marks YES, 514 marks NO
new_start_logits = torch.cat([start_logits, unk_logits, yes_logits, no_logits], 1)
# print(new_start_logits.shape)
new_end_logits = torch.cat([end_logits, unk_logits, yes_logits, no_logits], 1)
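# for unknown / yes / no examples the gold start and end positions presumably point at these
# appended columns, so one cross-entropy over all 515 positions covers every answer type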
if self.answer_verification and start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
if len(answer_nums.size()) > 1:
answer_nums = answer_nums.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = new_start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
# torch.unbind: does not change the shape of the original tensor, it just returns the unbound slices
# print(answer_masks.shape)
# [1,3] --> (tensor([1]), tensor([1]), tensor([0]))
# print(torch.unbind(answer_masks, dim=1))
# print(torch.unbind(start_positions, dim=1))
"""
start_positions = torch.tensor([[1,2,3], [4,5,6]])
answer_mask = torch.tensor([[1,1,0],[1,0,0]])
print(torch.unbind(start_positions, dim=1))
print(torch.unbind(answer_mask, dim=1))
(tensor([1, 4]), tensor([2, 5]), tensor([3, 6]))
(tensor([1, 1]), tensor([1, 0]), tensor([0, 0]))
"""
start_losses = [(loss_fct(new_start_logits, _start_positions) * _span_mask)
for (_start_positions, _span_mask)
in zip(torch.unbind(start_positions, dim=1), torch.unbind(answer_masks, dim=1))] # torch.unbind removes the given dimension and returns a tuple of the slices along it
end_losses = [(loss_fct(new_end_logits, _end_positions) * _span_mask)
for (_end_positions, _span_mask)
in zip(torch.unbind(end_positions, dim=1), torch.unbind(answer_masks, dim=1))]
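# CRF negative log-likelihood over the per-token I/O tags (masked to the real tokens),
# plus a cross-entropy over the predicted number of answers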
loss_IO = -1 * self.crf(emissions, label_ids, mask=attention_mask.byte())
switch_loss = loss_fct(switch_logits, answer_nums)
# start_loss = loss_fct(new_start_logits, start_positions)
# end_loss = loss_fct(new_end_logits, end_positions)
rationale_positions = token_type_ids.float()
alpha = 0.25
gamma = 2.
# interesting that it can also be done this way
rationale_loss = -alpha * ((1 - rationale_logits) ** gamma) * rationale_positions * torch.log(
rationale_logits + 1e-8) - (1 - alpha) * (rationale_logits ** gamma) * (
1 - rationale_positions) * torch.log(1 - rationale_logits + 1e-8)
rationale_loss = (rationale_loss*token_type_ids.float()).sum() / token_type_ids.float().sum()
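# the rationale loss above has the form of a focal loss (alpha=0.25, gamma=2.0): it pushes the
# rationale weights up on passage positions and down elsewhere, averaged over the passage tokens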
# s_e_loss = sum(start_losses + end_losses) + rationale_loss*self.beta
# total_loss = torch.mean(s_e_loss + switch_loss)
s_e_loss = sum(start_losses + end_losses)
total_loss = torch.mean(s_e_loss + switch_loss + loss_IO) + rationale_loss * self.beta
# total_loss = (start_losses + end_losses) / 2
return total_loss
elif start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = new_start_logits.size(1)
start_positions.clamp_(1, ignored_index)
end_positions.clamp_(1, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(new_start_logits, start_positions)
end_loss = loss_fct(new_end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss
else:
IO_logits = self.crf.decode(emissions, attention_mask.byte())
for io in IO_logits:
while len(io) < 512:
io.append(0)
IO_logits = torch.Tensor(IO_logits)
IO_logits = IO_logits.cuda()
return start_logits, end_logits, unk_logits, yes_logits, no_logits, switch_logits, IO_logits
class MultiLinearLayer(nn.Module):
def __init__(self, layers, hidden_size, output_size, activation=None):
super(MultiLinearLayer, self).__init__()
self.net = nn.Sequential()
for i in range(layers-1):
self.net.add_module(str(i) + 'linear', nn.Linear(hidden_size, hidden_size))
self.net.add_module(str(i) + 'relu', nn.ReLU(inplace=True))
self.net.add_module('linear', nn.Linear(hidden_size, output_size))
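# note: MultiLinearLayer is defined here but is not referenced by CailModel in this file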
def forward(self, x):
return self.net(x)
if __name__ == '__main__':
import torch
input_ids = torch.tensor([[101, 839, 5442, 6158, 6843, 2518, 1525, 763, 1278, 7368, 8043, 102, 5307, 2144, 4415, 3389, 3209, 131,
123, 121, 122, 125, 2399, 129, 3299, 127, 3189, 677, 1286, 5276, 128, 4157, 117, 1333, 1440, 7942, 166,
121, 3341, 1168, 6158, 1440, 5529, 166, 124, 5307, 5852, 4638, 3717, 3799, 2421, 1079, 6579, 743, 697,
1259, 3717, 3799, 117, 4507, 1333, 1440, 1350, 1071, 707, 3198, 7416, 3341, 4638, 676, 6762, 6756, 1923,
5632, 6121, 6566, 6569, 3021, 6817, 3717, 3799, 511, 1762, 3021, 6817, 6814, 4923, 704, 117, 6158, 1440,
2421, 1079, 1831, 3123, 4638, 3717, 3799, 948, 1847, 678, 3341, 2199, 1333, 1440, 4790, 839, 511, 2496,
1921, 677, 1286, 117, 1333, 1440, 6158, 6843, 2518, 727, 3926, 2356, 5018, 676, 782, 3696, 1278, 7368,
117, 5307, 7305, 6402, 3466, 3389, 6402, 3171, 711, 100, 5587, 123, 510, 124, 3491, 860, 7755, 2835,
510, 2340, 1079, 6674, 7755, 2835, 100, 117, 1066, 3118, 1139, 1278, 4545, 6589, 127, 128, 122, 1039,
511, 1728, 4567, 2658, 698, 7028, 117, 2496, 1921, 6760, 1057, 3946, 2336, 1278, 4906, 1920, 2110, 7353,
2247, 5018, 753, 1278, 7368, 6822, 6121, 857, 7368, 3780, 4545, 117, 754, 123, 121, 122, 125, 2399, 129,
3299, 122, 122, 3189, 1762, 1059, 7937, 678, 6121, 100, 5587, 3491, 7755, 2835, 1147, 1908, 1121, 1327,
1079, 1743, 2137, 3318, 100, 1469, 100, 2340, 1079, 6674, 7755, 2835, 1079, 1743, 2137, 3318, 100, 117,
754, 123, 121, 122, 125, 2399, 129, 3299, 123, 122, 3189, 1139, 7368, 117, 1066, 6369, 3118, 1139, 857,
7368, 6589, 4500, 126, 125, 126, 121, 126, 119, 126, 128, 1039, 511, 5307, 3315, 7368, 1999, 2805, 3946,
2336, 1921, 3633, 1385, 3791, 7063, 2137, 2792, 7063, 2137, 117, 1333, 1440, 4638, 3655, 4565, 4923,
2428, 711, 736, 5277, 117, 5852, 1075, 3309, 7361, 6397, 2137, 711, 124, 702, 3299, 113, 794, 1358, 839,
722, 3189, 6629, 6369, 5050, 114, 117, 753, 3309, 2797, 3318, 113, 2858, 7370, 1079, 1743, 2137, 114,
4638, 5852, 1075, 3309, 7361, 6397, 2137, 711, 1288, 702, 3299, 117, 1400, 5330, 3780, 4545, 6589, 5276,
7444, 122, 121, 121, 121, 121, 1039, 2772, 2902, 2141, 7354, 1394, 4415, 1355, 4495, 6589, 4500, 711,
1114, 511, 1333, 1440, 857, 7368, 3780, 4545, 1350, 1139, 7368, 1400, 117, 6158, 1440, 5529, 166, 124,
1350, 1071, 1036, 2094, 3295, 1343, 2968, 3307, 2400, 6843, 677, 5852, 1075, 1501, 511, 1352, 3175,
2218, 6608, 985, 752, 2139, 3187, 3791, 6809, 2768, 671, 5636, 2692, 6224, 117, 3125, 3868, 6401, 511,
809, 677, 752, 2141, 117, 3300, 1333, 1440, 6716, 819, 6395, 510, 697, 6158, 1440, 2787, 5093, 6395,
3209, 510, 7305, 6402, 4567, 1325, 1350, 1355, 4873, 1063, 819, 510, 857, 7368, 6589, 4500, 1355, 4873,
1350, 3926, 1296, 510, 1139, 7368, 6381, 2497, 510, 1278, 4545, 6395, 3209, 741, 510, 1385, 3791, 7063,
2137, 2692, 6224, 741, 510, 7063, 2137, 6589, 1355, 4873, 1350, 2431, 2144, 5011, 2497, 1762, 3428, 858,
6395, 117, 3315, 7368, 750, 809, 6371, 2137, 511, 1333, 1440, 2990, 897, 4638, 6228, 7574, 6598, 3160,
117, 1377, 809, 6395, 102]])
input_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1]])
segment_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1]])
paragraph_len = 499
label_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
start_positions = torch.tensor([[118, 174, 0]])
end_positions = torch.tensor([[126, 185, 0]])
is_impossible = False
unk_mask = torch.tensor([[0]])
yes_mask = torch.tensor([[0]])
no_mask = torch.tensor([[0]])
answer_masks = torch.tensor([[1, 1, 0]])
answer_nums = torch.tensor([2])
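# this toy example supervises two answer spans (token indices 118-126 and 174-185); the third slot
# is padding, answer_masks marks the two valid slots, answer_nums is the span count, and label_ids
# carries the matching per-token I/O tags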
class Args:
bert_config_file = 'model_hub/chinese-bert-wwm-ext/config.json'
need_birnn = False
rnn_dim = 128
args = Args()
config = BertConfig.from_json_file(args.bert_config_file)
model = CailModel(config, need_birnn=args.need_birnn, rnn_dim=args.rnn_dim)
# print(model)
loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions,
unk_mask, yes_mask, no_mask, answer_masks, answer_nums, label_ids)
print(loss.item())
```

A few points worth noting:
- The attention between the query and the context, and the context's attention over itself.
- The losses: [the answer-start loss, the answer-end loss, the no-answer (unknown) loss, the yes loss, the no loss], the answer-count loss, the per-token loss for whether each token belongs to an answer (whose I/O tags are decoded back into spans at inference, see the sketch below), and the token_type (rationale) loss.
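
As a small add-on (not part of the repository; the helper name is made up), here is a minimal sketch of how the I/O tag sequence returned by `self.crf.decode` could be turned back into answer spans at inference time: consecutive tokens tagged 1 form one span, which is what makes the model multi-span.

```python
def io_tags_to_spans(io_tags):
    """Convert a per-token 0/1 (O/I) tag sequence into inclusive (start, end) token spans."""
    spans, start = [], None
    for i, tag in enumerate(io_tags):
        if tag == 1 and start is None:
            start = i                      # a span opens
        elif tag != 1 and start is not None:
            spans.append((start, i - 1))   # the span closes
            start = None
    if start is not None:                  # a span running to the end of the sequence
        spans.append((start, len(io_tags) - 1))
    return spans


print(io_tags_to_spans([0, 1, 1, 0, 0, 1, 1, 1, 0]))  # [(1, 2), (5, 7)]
```

The (start, end) pairs are token indices inside the 512-token input; mapping them back to character offsets in the original text would be the job of the surrounding data pipeline.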