参考链接: CRF Layer on the Top of BiLSTM
- 输入单词序列的词表征经过 BiLSTM 处理,为每个单词生成其属于各个实体类别的得分(权重)
- 再将这些得分分布组成的序列输入到 CRF 层,得到最终的实体类别序列
- 仅靠 BiLSTM 层其实已经可以得到每个单词的实体类别
- 但 CRF 层会为上一层的输出加入一些约束规则(即 CRF 的特征函数/转移得分),例如句子的第一个标签应为 "B-*" 或 "O",而不能是 "I-*"
以含 5 个单词的句子为例,所有可能的标签路径为:
1) START B-Person B-Person B-Person B-Person B-Person END
2) START B-Person I-Person B-Person B-Person B-Person END
......
10) START B-Person I-Person O B-Organization O END
......
N) START O O O O O END
# Initialize the viterbi variables in log space init_vvars = torch.full((1, self.tagset_size), -10000.) init_vvars[0][self.tag_to_ix[START_TAG]] = 0
# 保存上一步各个标签对应的最佳分数 forward_var = init_vvars for feat in feats: bptrs_t = [] # holds the backpointers for this step viterbivars_t = [] # holds the viterbi variables for this step
best_path = [best_tag_id] # 最佳路径 for bptrs_t inreversed(backpointers): best_tag_id = bptrs_t[best_tag_id] best_path.append(best_tag_id) # Pop off the start tag (we dont want to return that to the caller) start = best_path.pop() assert start == self.tag_to_ix[START_TAG] # Sanity check best_path.reverse() return path_score, best_path
BiLSTM+CRF完整代码
1 2 3 4 5 6 7 8 9
import torch import torch.autograd as autograd import torch.nn as nn import torch.optim as optim
def prepare_sequence(seq, to_ix):
    """Convert a token sequence into a tensor of vocabulary indices.

    Args:
        seq: iterable of tokens (words for sentences, tag strings for labels).
        to_ix: mapping from each token to its integer index.

    Returns:
        A 1-D ``torch.long`` tensor with one index per token of *seq*
        (shape ``(0,)`` for an empty sequence).

    Raises:
        KeyError: if a token of *seq* is missing from *to_ix*.
    """
    idxs = [to_ix[w] for w in seq]
    # These are integer indices, not scores: they must stay int64 (long).
    # Presumably they are looked up by an embedding layer inside
    # _get_lstm_features (not shown here), which requires an integer index
    # tensor. They are never passed to torch.exp, so a float dtype would be
    # wrong here.
    return torch.tensor(idxs, dtype=torch.long)
# Maps the output of the LSTM into tag space. self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
# Matrix of transition parameters. Entry i,j is the score of # transitioning *to* i *from* j. self.transitions = nn.Parameter( torch.randn(self.tagset_size, self.tagset_size))
# These two statements enforce the constraint that we never transfer # to the start tag and we never transfer from the stop tag self.transitions.data[tag_to_ix[START_TAG], :] = -10000 self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000
def _forward_alg(self, feats):
    """Compute the log partition function log Z of the linear-chain CRF.

    Runs the forward algorithm in log space. It sums the scores of every
    tag path, including the transition out of START_TAG and the transition
    into STOP_TAG.

    Args:
        feats: emission scores, iterated one row per token. Each row is
            assumed to hold exactly ``self.tagset_size`` scores (presumably
            the ``hidden2tag`` output -- TODO confirm against
            ``_get_lstm_features``).

    Returns:
        A 0-dim tensor: log-sum-exp over the scores of all tag paths.
    """
    # Initial state: all of the score on START_TAG. The other tags get
    # -10000, a stand-in for log(0).
    # new_full puts this tensor on the same device/dtype as the transition
    # matrix. torch.full would always allocate on the CPU, and the addition
    # below would then fail for a model moved to another device.
    forward_var = self.transitions.new_full((1, self.tagset_size), -10000.)
    forward_var[0][self.tag_to_ix[START_TAG]] = 0.

    for feat in feats:
        # scores[next, prev] = forward_var[prev]         (broadcast over rows)
        #                    + transitions[next, prev]   (score of prev -> next)
        #                    + feat[next]                (emission, over cols)
        scores = forward_var + self.transitions + feat.unsqueeze(1)
        # Log-sum-exp out the previous tag for every next tag at once. This
        # replaces the original per-next_tag Python loop.
        forward_var = torch.logsumexp(scores, dim=1).view(1, -1)

    # Add the transition into STOP_TAG, then reduce over the final tag.
    terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
    return torch.logsumexp(terminal_var.view(-1), dim=0)
# Initialize the viterbi variables in log space init_vvars = torch.full((1, self.tagset_size), -10000.) init_vvars[0][self.tag_to_ix[START_TAG]] = 0
# forward_var at step i holds the viterbi variables for step i-1 forward_var = init_vvars for feat in feats: bptrs_t = [] # holds the backpointers for this step viterbivars_t = [] # holds the viterbi variables for this step
for next_tag inrange(self.tagset_size): # next_tag_var[i] holds the viterbi variable for tag i at the # previous step, plus the score of transitioning # from tag i to next_tag. # We don't include the emission scores here because the max # does not depend on them (we add them in below) next_tag_var = forward_var + self.transitions[next_tag] best_tag_id = argmax(next_tag_var) bptrs_t.append(best_tag_id) viterbivars_t.append(next_tag_var[0][best_tag_id].view(1)) # Now add in the emission scores, and assign forward_var to the set # of viterbi variables we just computed forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1) backpointers.append(bptrs_t)
# Follow the back pointers to decode the best path. best_path = [best_tag_id] for bptrs_t inreversed(backpointers): best_tag_id = bptrs_t[best_tag_id] best_path.append(best_tag_id) # Pop off the start tag (we dont want to return that to the caller) start = best_path.pop() assert start == self.tag_to_ix[START_TAG] # Sanity check best_path.reverse() return path_score, best_path
def forward(self, sentence):
    """Decode the best tag sequence for *sentence*.

    This is the inference entry point. Do not confuse it with
    ``_forward_alg``, which computes the CRF partition function for
    training. Emission scores come from ``_get_lstm_features``, and the
    best path is found by ``_viterbi_decode``.

    Returns:
        ``(score, tag_seq)`` as produced by ``_viterbi_decode``.
    """
    emissions = self._get_lstm_features(sentence)
    path_score, best_tags = self._viterbi_decode(emissions)
    return path_score, best_tags
# Toy training corpus: each entry pairs a tokenized sentence with its
# BIO tag sequence (one tag per token).
training_data = [
    (
        "the wall street journal reported today that apple corporation made money".split(),
        "B I I I O O O B I O O".split(),
    ),
    (
        "georgia tech is a university in georgia".split(),
        "B I O O O O B".split(),
    ),
]

# Vocabulary: give each distinct word an index, in order of first appearance.
word_to_ix = {}
for sentence, tags in training_data:
    for word in sentence:
        # len() is evaluated before the insert, so an unseen word gets the
        # next free index and a known word keeps the index it already has.
        word_to_ix.setdefault(word, len(word_to_ix))
# Baseline: decode the first training sentence before any training.
# No autograd graph is needed for inference.
with torch.no_grad():
    precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
    precheck_tags = prepare_sequence(training_data[0][1], tag_to_ix)
    print(model(precheck_sent))

# Train for a fixed 300 epochs, taking one optimizer step per sentence.
for epoch in range(300):
    for sentence, tags in training_data:
        # Encode the words and the gold tags as long index tensors.
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)

        # PyTorch accumulates gradients, so clear them before each example.
        model.zero_grad()

        # Negative log-likelihood of the gold path, then backprop and update.
        loss = model.neg_log_likelihood(sentence_in, targets)
        loss.backward()
        optimizer.step()

# Decode the same sentence again after training, for comparison.
with torch.no_grad():
    precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
    print(model(precheck_sent))