RNN Principles

  • Recurrent Neural Network (RNN)

SimpleRNN

  • A SimpleRNN processes an input sequence as follows:
    • The input is a sequence of vectors \(\{x_0,x_1,x_2,\dots,x_n\}\)
    • At time step \(t\), the element \(x_t\) and the previous time step's hidden state \(h_{t-1}\) are processed together by the RNN cell, which produces \(h_t\): \[h_t=\phi(Wx_t+Uh_{t-1})\] \[y_t=Vh_t\]
    • \(h_t\) is the hidden state, carrying the information of the sequence up to time step \(t\); \(y_t\) is the output at time step \(t\); \(h_t\) is passed on as input to the next time step
    • Once the whole sequence has been processed, the final output \(y_n\) is the RNN's output; depending on the task, the full output sequence \(\{y_0,y_1,y_2,\dots,y_n\}\) can also be returned
    • Every element of the sequence is processed by the same RNN cell, so there is only one set of parameters to learn: \(W,U,V\) (see the sketch after this list)
  • As the sequence elements pass one after another through the RNN's activation function (sigmoid/tanh), information is gradually lost; in addition, backpropagation through time suffers from vanishing gradients, so a SimpleRNN can only hold short-term memory
    • For example, in the sentence "what time is it", when the vector for the word "it" is produced, only the information from "time" and "is" is still available, while the information from "what" has already been lost
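To make the recurrence concrete, here is a minimal PyTorch sketch of the update \(h_t=\phi(Wx_t+Uh_{t-1})\), \(y_t=Vh_t\); the toy dimensions and the choice of tanh as \(\phi\) are illustrative assumptions, not taken from the original figure:

import torch

torch.manual_seed(0)
input_dim, hidden_dim, output_dim = 4, 3, 2
W = torch.randn(hidden_dim, input_dim)    # input-to-hidden weights W
U = torch.randn(hidden_dim, hidden_dim)   # hidden-to-hidden weights U
V = torch.randn(output_dim, hidden_dim)   # hidden-to-output weights V

xs = [torch.randn(input_dim) for _ in range(5)]  # toy input sequence {x_0, ..., x_4}
h = torch.zeros(hidden_dim)                      # initial hidden state

for x_t in xs:
    h = torch.tanh(W @ x_t + U @ h)  # h_t = phi(W x_t + U h_{t-1}); phi = tanh here
    y = V @ h                        # y_t = V h_t
print(y)                             # y_n, the output after the whole sequence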

LSTM

LSTM Principles

  • Long Short-Term Memory. An LSTM cell takes the current input, the short-term memory, and the long-term memory, updates both the long-term and short-term memories, and produces an output

LSTM Structure

  • An LSTM cell contains four gates: the forget gate, the learn gate, the remember gate, and the output (use) gate

    • forget gate: decides which information in the long-term memory \(c_{t-1}\) to keep and which to forget

      • First, the current input \(x_t\) and the short-term memory \(h_{t-1}\) are combined to produce a vector \(f_t\)
      • Each entry of \(f_t\) lies between \(0\) and \(1\) and corresponds to one entry of the long-term memory; \(1\) means keep it completely, \(0\) means discard it completely \[f_t=\sigma(W_f[h_{t-1},x_t]+b_f)\] \[Out_f = c_{t-1}\cdot f_t\]
    • learn gate: decides what information to learn from the short-term memory and the current input

      • First, combine the information of \(x_t\) and the short-term memory \(h_{t-1}\) into a candidate \(\hat c_t\)
      • Then compute a gating (ignore) factor \(i_t\) from \(x_t\) and \(h_{t-1}\), with values between \(0\) and \(1\)
      • Finally, combine the results of the previous two steps \[\hat c_t=\tanh(W_n[h_{t-1},x_t]+b_n)\]
        \[i_t=\sigma(W_i[h_{t-1},x_t]+b_i)\] \[Out_n = i_t\cdot \hat c_t\]
    • remember gate: combines the outputs of the forget gate and the learn gate to update the long-term memory \[c_t = Out_f+Out_n\]

    • output (use) gate: combines the updated long-term memory with the current input and short-term memory to update the short-term memory \[o_t=\sigma(W_o[h_{t-1},x_t]+b_o)\] \[h_t=o_t\cdot \tanh(c_t)\]

  • The short-term memory \(h_t\) is also the LSTM's output at the current time step \(t\)

  • In summary, an LSTM cell has four groups of trainable parameters: the forget gate parameters \(\{W_f,b_f\}\), the learn gate parameters \(\{W_n,b_n\}\) and \(\{W_i,b_i\}\), and the output gate parameters \(\{W_o,b_o\}\)

  • The choice between sigmoid and tanh at the different positions in the LSTM, and whether vectors are added or multiplied, is somewhat arbitrary; the current structure is used because it works well in practice (a step-by-step sketch of the gate equations follows below)
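The four gates above can be written out directly as a small sketch; this is not PyTorch's nn.LSTM implementation, and the parameter names and toy sizes are assumptions chosen to mirror the formulas:

import torch

def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_n, b_n, W_i, b_i, W_o, b_o):
    # One LSTM step following the four gates above (illustrative sketch)
    z = torch.cat([h_prev, x_t])           # [h_{t-1}, x_t]
    f_t = torch.sigmoid(W_f @ z + b_f)     # forget gate
    c_hat = torch.tanh(W_n @ z + b_n)      # learn gate: candidate information
    i_t = torch.sigmoid(W_i @ z + b_i)     # learn gate: ignore factor
    c_t = f_t * c_prev + i_t * c_hat       # remember gate: new long-term memory
    o_t = torch.sigmoid(W_o @ z + b_o)     # output (use) gate
    h_t = o_t * torch.tanh(c_t)            # new short-term memory = output
    return h_t, c_t

input_dim, hidden_dim = 4, 3                 # toy sizes
mk = lambda: (torch.randn(hidden_dim, hidden_dim + input_dim), torch.randn(hidden_dim))
(W_f, b_f), (W_n, b_n), (W_i, b_i), (W_o, b_o) = mk(), mk(), mk(), mk()
h, c = torch.zeros(hidden_dim), torch.zeros(hidden_dim)
for _ in range(5):                           # run a few steps on random inputs
    h, c = lstm_step(torch.randn(input_dim), h, c, W_f, b_f, W_n, b_n, W_i, b_i, W_o, b_o)
print(h, c)

(PyTorch's nn.LSTM stores the corresponding weights stacked into weight_ih_l0 and weight_hh_l0 rather than as four separate matrices.)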

Peephole Mechanism

  • The sigmoid activation in a gate maps its input to values between \(0\) and \(1\); multiplying the sigmoid output by another vector then decides which parts of that vector's information to keep;
  • In the LSTM structure above, the input to all three sigmoid gates is the current input together with the short-term memory, \([h_{t-1},x_t]\); in other words, only the short-term memory decides what information the LSTM cell keeps;
  • peephole connections: the long-term memory is added to the input of the sigmoid gates as well, raising its influence on these decisions, so the long-term and short-term memories jointly decide what to keep and what to discard (see the sketch below) \[f_t=\sigma(W_f[c_{t-1},h_{t-1},x_t]+b_f)\] \[i_t=\sigma(W_i[c_{t-1},h_{t-1},x_t]+b_i)\] \[o_t=\sigma(W_o[c_{t-1},h_{t-1},x_t]+b_o)\]
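A minimal sketch of the change for the forget gate (names and toy sizes are illustrative); note that PyTorch's nn.LSTM does not provide peephole connections, so they would have to be implemented manually like this:

import torch

input_dim, hidden_dim = 4, 3
W_f = torch.randn(hidden_dim, 2 * hidden_dim + input_dim)  # now acts on [c_{t-1}, h_{t-1}, x_t]
b_f = torch.randn(hidden_dim)

x_t = torch.randn(input_dim)
h_prev, c_prev = torch.zeros(hidden_dim), torch.zeros(hidden_dim)
f_t = torch.sigmoid(W_f @ torch.cat([c_prev, h_prev, x_t]) + b_f)  # peephole forget gate
print(f_t)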

GRU

  • Gated Recurrent Unit: merges the forget gate and the learn gate into a single update gate, and merges the cell state (long-term memory) \(c_{t}\) with the hidden state (short-term memory) \(h_t\) into one state (a sketch of a single GRU step follows below)
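A sketch of a single GRU step under one common formulation (the weight names and the exact update convention are illustrative assumptions, not taken from the post):

import torch

def gru_step(x_t, h_prev, W_z, b_z, W_r, b_r, W_h, b_h):
    z = torch.cat([h_prev, x_t])
    z_t = torch.sigmoid(W_z @ z + b_z)        # update gate (plays the role of forget + learn)
    r_t = torch.sigmoid(W_r @ z + b_r)        # reset gate
    h_hat = torch.tanh(W_h @ torch.cat([r_t * h_prev, x_t]) + b_h)  # candidate state
    return (1 - z_t) * h_prev + z_t * h_hat   # single merged state h_t

input_dim, hidden_dim = 4, 3
mk = lambda: (torch.randn(hidden_dim, hidden_dim + input_dim), torch.randn(hidden_dim))
(W_z, b_z), (W_r, b_r), (W_h, b_h) = mk(), mk(), mk()
h = torch.zeros(hidden_dim)
for _ in range(5):
    h = gru_step(torch.randn(input_dim), h, W_z, b_z, W_r, b_r, W_h, b_h)
print(h)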

RNN Implementation

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

SimpleRNN

For each vector \(x_t\) of the input sequence, nn.RNN computes: \(h_t=\tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_{hh})\)

# Specify the input feature size, the hidden state size, and the number of RNN layers
rnn = nn.RNN(
    input_size=6,
    hidden_size=10,
    num_layers=2,
    batch_first=True,     # input/output tensor shape: (batch, seq, feature)
    bidirectional=False,  # bidirectional RNN
)

# Input tensors
input = torch.randn(5, 3, 6)  # (batch, seq, feature)
h0 = torch.randn(2, 5, 10)    # (num_layers, batch, hidden_size)
print("Input shape:", input.shape)

output, hn = rnn(input, h0)
print("Output shape:", output.shape, " Hidden state shape:", hn.shape)
Input shape: torch.Size([5, 3, 6])
Output shape: torch.Size([5, 3, 10])  Hidden state shape: torch.Size([2, 5, 10])
print("输入向量对应的权重 W_ih:",rnn.weight_ih_l0.shape)
print("隐藏状态对应的权重 W_hh:",rnn.weight_hh_l0.shape)
输入向量对应的权重 W_ih: torch.Size([10, 6])
隐藏状态对应的权重 W_hh: torch.Size([10, 10])
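To tie these weights back to the formula above, the following added sketch unrolls layer 0 of the nn.RNN by hand and checks that it reproduces hn[0], the final hidden state of layer 0:

with torch.no_grad():
    W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
    b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0
    h = h0[0]                        # layer-0 initial hidden state: (batch, hidden_size)
    for t in range(input.shape[1]):  # iterate over the sequence dimension
        h = torch.tanh(input[:, t] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
    print(torch.allclose(h, hn[0], atol=1e-6))  # True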

LSTM

rnn = nn.LSTM(
    input_size=6,
    hidden_size=10,
    num_layers=2,
    batch_first=True,     # input/output tensor shape: (batch, seq, feature)
    bidirectional=False,
)
input = torch.randn(5, 3, 6)
h0 = torch.randn(2, 5, 10)  # (num_layers, batch, hidden_size)
c0 = torch.randn(2, 5, 10)

output, (hn, cn) = rnn(input, (h0, c0))
print("Output shape:", output.shape)
print("hidden state:", hn.shape)
print("cell state:", cn.shape)
Output shape: torch.Size([5, 3, 10])
hidden state: torch.Size([2, 5, 10])
cell state: torch.Size([2, 5, 10])
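If the initial states are omitted, nn.LSTM (like nn.RNN and nn.GRU) initializes them to zeros, so the following call is equivalent to passing all-zero h0 and c0:

output, (hn, cn) = rnn(input)  # h0 and c0 default to zero tensors
print(output.shape, hn.shape, cn.shape)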

GRU

rnn = nn.GRU(10, 20, 2)        # input_size=10, hidden_size=20, num_layers=2
input = torch.randn(5, 3, 10)  # (seq, batch, feature) -- batch_first defaults to False
h0 = torch.randn(2, 3, 20)     # (num_layers, batch, hidden_size)
output, hn = rnn(input, h0)
h0.shape
torch.Size([2, 3, 20])

RNN Training Workflow

Training Data

plt.figure(figsize=(8, 5))

# Sequence data
seq_length = 20
time_steps = np.linspace(0, np.pi, seq_length + 1)
data = np.sin(time_steps)
data.resize((seq_length + 1, 1))

x = data[:-1]  # inputs
y = data[1:]   # targets

# Plot the data
plt.plot(time_steps[1:], x, 'r.', label='input, x')
plt.plot(time_steps[1:], y, 'b.', label='target, y')

plt.legend(loc='best')
plt.show()

Defining the Model

class RNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_dim, n_layers):
        super(RNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_size)

    def forward(self, x, hidden):
        batch_size = x.size(0)
        r_out, hidden = self.rnn(x, hidden)
        r_out = r_out.view(-1, self.hidden_dim)
        output = self.fc(r_out)
        return output, hidden
# Sanity-check the model
test_rnn = RNN(input_size=1, output_size=1, hidden_dim=10, n_layers=2)

test_input = torch.Tensor(data).unsqueeze(0)
print('Input size:', test_input.size())

test_out, test_h = test_rnn(test_input, None)
print('Output size:', test_out.size())
print('Hidden state size:', test_h.size())
Input size: torch.Size([1, 21, 1])
Output size: torch.Size([21, 1])
Hidden state size: torch.Size([2, 1, 10])

Training the Model

# Hyperparameters
input_size = 1
output_size = 1
hidden_dim = 32
n_layers = 1
# Instantiate the model
rnn = RNN(input_size, output_size, hidden_dim, n_layers)
print(rnn)
RNN(
  (rnn): RNN(1, 32, batch_first=True)
  (fc): Linear(in_features=32, out_features=1, bias=True)
)
# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01)
# Train the model
def train(rnn, n_steps, print_every):
    hidden = None
    for batch_i, step in enumerate(range(n_steps)):
        x_tensor = torch.Tensor(x).unsqueeze(0)
        y_tensor = torch.Tensor(y)

        # Forward pass
        prediction, hidden = rnn(x_tensor, hidden)
        hidden = hidden.data  # detach the hidden state from the graph

        # Loss
        loss = criterion(prediction, y_tensor)

        # Zero the gradients
        optimizer.zero_grad()

        # Backpropagation
        loss.backward()

        # Update the parameters
        optimizer.step()

        if batch_i % print_every == 0:
            print('Loss: ', loss.item())
            plt.plot(time_steps[1:], x, 'r.')
            plt.plot(time_steps[1:], prediction.data.numpy().flatten(), 'b.')
            plt.show()
    return rnn
n_steps = 75
print_every = 15

# Train
trained_rnn = train(rnn, n_steps, print_every)
Loss:  0.40589970350265503

Loss:  0.035483404994010925

Loss:  0.012853428721427917

Loss:  0.00824706070125103

Loss:  0.010340889915823936


RNN Example: Character-Level Text Generation

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

Dataset

with open('datasets/anna.txt', 'r') as f:
text = f.read()

text[:100]
'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'
# Vectorize the text (map each character to an integer)

chars = tuple(set(text))
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii, ch in int2char.items()}

encoded = np.array([char2int[ch] for ch in text])
encoded[:20]
array([77, 41, 28, 66,  7, 21, 47, 58,  4, 35, 35, 35, 23, 28, 66, 66, 31,
       58,  9, 28])

Data Preprocessing

def one_hot_encode(arr, n_labels):
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)
    one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.
    one_hot = one_hot.reshape((*arr.shape, n_labels))
    return one_hot
test_seq = np.array([[3, 5, 1]])
one_hot = one_hot_encode(test_seq, 8)
print(one_hot)
[[[0. 0. 0. 1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 1. 0. 0.]
  [0. 1. 0. 0. 0. 0. 0. 0.]]]
# Create mini-batches
def get_batches(arr, batch_size, seq_length):

    batch_size_total = batch_size * seq_length
    n_batches = len(arr) // batch_size_total
    arr = arr[:n_batches * batch_size_total]

    arr = arr.reshape((batch_size, -1))
    for n in range(0, arr.shape[1], seq_length):
        x = arr[:, n:n + seq_length]
        y = np.zeros_like(x)
        try:
            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, n + seq_length]
        except IndexError:
            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, 0]
        yield x, y


batches = get_batches(encoded, 8, 50)
x, y = next(batches)
print('x\n', x[:4, :10])
print('\ny\n', y[:4, :10])
x
 [[77 41 28 66  7 21 47 58  4 35]
 [39 81 37 58  7 41 28  7 58 28]
 [21 37 67 58 81 47 58 28 58  9]
 [39 58  7 41 21 58 11 41 12 21]]

y
 [[41 28 66  7 21 47 58  4 35 35]
 [81 37 58  7 41 28  7 58 28  7]
 [37 67 58 81 47 58 28 58  9 81]
 [58  7 41 21 58 11 41 12 21  9]]

Building the Model

# Check whether a GPU is available
train_on_gpu = torch.cuda.is_available()
if (train_on_gpu):
    print('Training on GPU!')
else:
    print(
        'No GPU available, training on CPU; consider making n_epochs very small.'
    )
Training on GPU!
# Define the model
class CharRNN(nn.Module):
    def __init__(self,
                 tokens,
                 n_hidden=256,
                 n_layers=2,
                 drop_prob=0.5,
                 lr=0.001):
        super(CharRNN, self).__init__()
        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.lr = lr

        self.chars = tokens
        self.int2char = dict(enumerate(self.chars))
        self.char2int = {ch: ii for ii, ch in self.int2char.items()}

        self.lstm = nn.LSTM(len(self.chars),
                            n_hidden,
                            n_layers,
                            dropout=drop_prob,
                            batch_first=True)

        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(n_hidden, len(self.chars))

    def forward(self, x, hidden):
        r_output, hidden = self.lstm(x, hidden)
        out = self.dropout(r_output)
        out = out.contiguous().view(-1, self.n_hidden)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data

        if train_on_gpu:
            hidden = (weight.new(self.n_layers, batch_size,
                                 self.n_hidden).zero_().cuda(),
                      weight.new(self.n_layers, batch_size,
                                 self.n_hidden).zero_().cuda())
        else:
            hidden = (weight.new(self.n_layers, batch_size,
                                 self.n_hidden).zero_(),
                      weight.new(self.n_layers, batch_size,
                                 self.n_hidden).zero_())
        return hidden

Training the Model

def train(net,
          data,
          epochs=10,
          batch_size=10,
          seq_length=10,
          lr=0.001,
          clip=5,
          val_frac=0.1,
          print_every=10):
    net.train()

    opt = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    val_idx = int(len(data) * (1 - val_frac))
    data, val_data = data[:val_idx], data[val_idx:]

    if train_on_gpu:
        net.cuda()

    counter = 0
    n_chars = len(net.chars)

    for e in range(epochs):
        h = net.init_hidden(batch_size)

        for x, y in get_batches(data, batch_size, seq_length):
            counter += 1

            x = one_hot_encode(x, n_chars)
            inputs, targets = torch.from_numpy(x), torch.from_numpy(y)

            if train_on_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()

            h = tuple([each.data for each in h])

            net.zero_grad()

            output, h = net(inputs, h)

            loss = criterion(output,
                             targets.view(batch_size * seq_length).long())
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), clip)
            opt.step()

            if counter % print_every == 0:
                val_h = net.init_hidden(batch_size)
                val_losses = []
                net.eval()
                for x, y in get_batches(val_data, batch_size, seq_length):
                    x = one_hot_encode(x, n_chars)
                    x, y = torch.from_numpy(x), torch.from_numpy(y)

                    val_h = tuple([each.data for each in val_h])

                    inputs, targets = x, y
                    if (train_on_gpu):
                        inputs, targets = inputs.cuda(), targets.cuda()

                    output, val_h = net(inputs, val_h)
                    val_loss = criterion(
                        output,
                        targets.view(batch_size * seq_length).long())

                    val_losses.append(val_loss.item())

                net.train()

                print("Epoch: {}/{}...".format(e + 1, epochs),
                      "Step: {}...".format(counter),
                      "Loss: {:.4f}...".format(loss.item()),
                      "Val Loss: {:.4f}".format(np.mean(val_losses)))
n_hidden = 512
n_layers = 2

net = CharRNN(chars, n_hidden, n_layers)
print(net)
CharRNN(
  (lstm): LSTM(83, 512, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=512, out_features=83, bias=True)
)
batch_size = 128
seq_length = 100
n_epochs = 20

train(net,
      encoded,
      epochs=n_epochs,
      batch_size=batch_size,
      seq_length=seq_length,
      lr=0.001,
      print_every=10)
Epoch: 1/20... Step: 10... Loss: 3.2684... Val Loss: 3.2099
Epoch: 1/20... Step: 20... Loss: 3.1553... Val Loss: 3.1399
Epoch: 1/20... Step: 30... Loss: 3.1438... Val Loss: 3.1250
Epoch: 1/20... Step: 40... Loss: 3.1109... Val Loss: 3.1204
Epoch: 1/20... Step: 50... Loss: 3.1416... Val Loss: 3.1175
Epoch: 1/20... Step: 60... Loss: 3.1164... Val Loss: 3.1145
Epoch: 1/20... Step: 70... Loss: 3.1047... Val Loss: 3.1109
Epoch: 1/20... Step: 80... Loss: 3.1169... Val Loss: 3.1029
Epoch: 1/20... Step: 90... Loss: 3.1012... Val Loss: 3.0809
Epoch: 1/20... Step: 100... Loss: 3.0419... Val Loss: 3.0239
Epoch: 1/20... Step: 110... Loss: 2.9835... Val Loss: 2.9680
Epoch: 1/20... Step: 120... Loss: 2.8329... Val Loss: 2.8223
Epoch: 1/20... Step: 130... Loss: 2.8383... Val Loss: 2.8556
Epoch: 2/20... Step: 140... Loss: 2.7164... Val Loss: 2.6663
Epoch: 2/20... Step: 150... Loss: 2.6229... Val Loss: 2.5753
Epoch: 2/20... Step: 160... Loss: 2.5504... Val Loss: 2.5114
Epoch: 2/20... Step: 170... Loss: 2.4817... Val Loss: 2.4664
Epoch: 2/20... Step: 180... Loss: 2.4539... Val Loss: 2.4292
Epoch: 2/20... Step: 190... Loss: 2.4003... Val Loss: 2.3941
Epoch: 2/20... Step: 200... Loss: 2.4008... Val Loss: 2.3660
Epoch: 2/20... Step: 210... Loss: 2.3633... Val Loss: 2.3358
Epoch: 2/20... Step: 220... Loss: 2.3302... Val Loss: 2.3069
Epoch: 2/20... Step: 230... Loss: 2.3156... Val Loss: 2.2811
Epoch: 2/20... Step: 240... Loss: 2.2904... Val Loss: 2.2567
Epoch: 2/20... Step: 250... Loss: 2.2315... Val Loss: 2.2277
Epoch: 2/20... Step: 260... Loss: 2.1987... Val Loss: 2.1988
Epoch: 2/20... Step: 270... Loss: 2.2062... Val Loss: 2.1753
Epoch: 3/20... Step: 280... Loss: 2.2010... Val Loss: 2.1486
Epoch: 3/20... Step: 290... Loss: 2.1719... Val Loss: 2.1246
Epoch: 3/20... Step: 300... Loss: 2.1350... Val Loss: 2.1097
Epoch: 3/20... Step: 310... Loss: 2.1090... Val Loss: 2.0875
Epoch: 3/20... Step: 320... Loss: 2.0769... Val Loss: 2.0644
Epoch: 3/20... Step: 330... Loss: 2.0504... Val Loss: 2.0463
Epoch: 3/20... Step: 340... Loss: 2.0679... Val Loss: 2.0248
Epoch: 3/20... Step: 350... Loss: 2.0545... Val Loss: 2.0122
Epoch: 3/20... Step: 360... Loss: 1.9831... Val Loss: 1.9931
Epoch: 3/20... Step: 370... Loss: 2.0144... Val Loss: 1.9765
Epoch: 3/20... Step: 380... Loss: 1.9839... Val Loss: 1.9569
Epoch: 3/20... Step: 390... Loss: 1.9620... Val Loss: 1.9429
Epoch: 3/20... Step: 400... Loss: 1.9336... Val Loss: 1.9274
Epoch: 3/20... Step: 410... Loss: 1.9439... Val Loss: 1.9143
Epoch: 4/20... Step: 420... Loss: 1.9417... Val Loss: 1.8957
Epoch: 4/20... Step: 430... Loss: 1.9184... Val Loss: 1.8853
Epoch: 4/20... Step: 440... Loss: 1.9016... Val Loss: 1.8775
Epoch: 4/20... Step: 450... Loss: 1.8396... Val Loss: 1.8572
Epoch: 4/20... Step: 460... Loss: 1.8320... Val Loss: 1.8479
Epoch: 4/20... Step: 470... Loss: 1.8746... Val Loss: 1.8396
Epoch: 4/20... Step: 480... Loss: 1.8527... Val Loss: 1.8245
Epoch: 4/20... Step: 490... Loss: 1.8512... Val Loss: 1.8171
Epoch: 4/20... Step: 500... Loss: 1.8388... Val Loss: 1.8028
Epoch: 4/20... Step: 510... Loss: 1.8212... Val Loss: 1.7942
Epoch: 4/20... Step: 520... Loss: 1.8299... Val Loss: 1.7832
Epoch: 4/20... Step: 530... Loss: 1.7957... Val Loss: 1.7752
Epoch: 4/20... Step: 540... Loss: 1.7593... Val Loss: 1.7640
Epoch: 4/20... Step: 550... Loss: 1.8066... Val Loss: 1.7529
Epoch: 5/20... Step: 560... Loss: 1.7772... Val Loss: 1.7441
Epoch: 5/20... Step: 570... Loss: 1.7592... Val Loss: 1.7396
Epoch: 5/20... Step: 580... Loss: 1.7381... Val Loss: 1.7289
Epoch: 5/20... Step: 590... Loss: 1.7341... Val Loss: 1.7203
Epoch: 5/20... Step: 600... Loss: 1.7233... Val Loss: 1.7146
Epoch: 5/20... Step: 610... Loss: 1.7124... Val Loss: 1.7046
Epoch: 5/20... Step: 620... Loss: 1.7138... Val Loss: 1.7029
Epoch: 5/20... Step: 630... Loss: 1.7224... Val Loss: 1.6903
Epoch: 5/20... Step: 640... Loss: 1.6983... Val Loss: 1.6863
Epoch: 5/20... Step: 650... Loss: 1.6905... Val Loss: 1.6752
Epoch: 5/20... Step: 660... Loss: 1.6594... Val Loss: 1.6704
Epoch: 5/20... Step: 670... Loss: 1.6819... Val Loss: 1.6640
Epoch: 5/20... Step: 680... Loss: 1.6872... Val Loss: 1.6568
Epoch: 5/20... Step: 690... Loss: 1.6595... Val Loss: 1.6552
Epoch: 6/20... Step: 700... Loss: 1.6551... Val Loss: 1.6462
Epoch: 6/20... Step: 710... Loss: 1.6496... Val Loss: 1.6408
Epoch: 6/20... Step: 720... Loss: 1.6318... Val Loss: 1.6312
Epoch: 6/20... Step: 730... Loss: 1.6589... Val Loss: 1.6292
Epoch: 6/20... Step: 740... Loss: 1.6186... Val Loss: 1.6267
Epoch: 6/20... Step: 750... Loss: 1.6037... Val Loss: 1.6149
Epoch: 6/20... Step: 760... Loss: 1.6439... Val Loss: 1.6133
Epoch: 6/20... Step: 770... Loss: 1.6214... Val Loss: 1.6056
Epoch: 6/20... Step: 780... Loss: 1.6137... Val Loss: 1.6016
Epoch: 6/20... Step: 790... Loss: 1.5932... Val Loss: 1.5931
Epoch: 6/20... Step: 800... Loss: 1.6040... Val Loss: 1.5969
Epoch: 6/20... Step: 810... Loss: 1.5964... Val Loss: 1.5870
Epoch: 6/20... Step: 820... Loss: 1.5575... Val Loss: 1.5828
Epoch: 6/20... Step: 830... Loss: 1.6043... Val Loss: 1.5756
Epoch: 7/20... Step: 840... Loss: 1.5554... Val Loss: 1.5712
Epoch: 7/20... Step: 850... Loss: 1.5727... Val Loss: 1.5696
Epoch: 7/20... Step: 860... Loss: 1.5676... Val Loss: 1.5616
Epoch: 7/20... Step: 870... Loss: 1.5734... Val Loss: 1.5599
Epoch: 7/20... Step: 880... Loss: 1.5708... Val Loss: 1.5537
Epoch: 7/20... Step: 890... Loss: 1.5672... Val Loss: 1.5536
Epoch: 7/20... Step: 900... Loss: 1.5418... Val Loss: 1.5477
Epoch: 7/20... Step: 910... Loss: 1.5274... Val Loss: 1.5427
Epoch: 7/20... Step: 920... Loss: 1.5348... Val Loss: 1.5420
Epoch: 7/20... Step: 930... Loss: 1.5371... Val Loss: 1.5350
Epoch: 7/20... Step: 940... Loss: 1.5318... Val Loss: 1.5341
Epoch: 7/20... Step: 950... Loss: 1.5469... Val Loss: 1.5285
Epoch: 7/20... Step: 960... Loss: 1.5396... Val Loss: 1.5269
Epoch: 7/20... Step: 970... Loss: 1.5491... Val Loss: 1.5230
Epoch: 8/20... Step: 980... Loss: 1.5196... Val Loss: 1.5151
Epoch: 8/20... Step: 990... Loss: 1.5156... Val Loss: 1.5145
Epoch: 8/20... Step: 1000... Loss: 1.5137... Val Loss: 1.5077
Epoch: 8/20... Step: 1010... Loss: 1.5480... Val Loss: 1.5094
Epoch: 8/20... Step: 1020... Loss: 1.5163... Val Loss: 1.5056
Epoch: 8/20... Step: 1030... Loss: 1.4934... Val Loss: 1.5022
Epoch: 8/20... Step: 1040... Loss: 1.5129... Val Loss: 1.5031
Epoch: 8/20... Step: 1050... Loss: 1.4812... Val Loss: 1.4964
Epoch: 8/20... Step: 1060... Loss: 1.4982... Val Loss: 1.4925
Epoch: 8/20... Step: 1070... Loss: 1.4994... Val Loss: 1.4875
Epoch: 8/20... Step: 1080... Loss: 1.4974... Val Loss: 1.4881
Epoch: 8/20... Step: 1090... Loss: 1.4706... Val Loss: 1.4843
Epoch: 8/20... Step: 1100... Loss: 1.4743... Val Loss: 1.4796
Epoch: 8/20... Step: 1110... Loss: 1.4622... Val Loss: 1.4781
Epoch: 9/20... Step: 1120... Loss: 1.4796... Val Loss: 1.4796
Epoch: 9/20... Step: 1130... Loss: 1.4893... Val Loss: 1.4764
Epoch: 9/20... Step: 1140... Loss: 1.4759... Val Loss: 1.4679
Epoch: 9/20... Step: 1150... Loss: 1.4925... Val Loss: 1.4718
Epoch: 9/20... Step: 1160... Loss: 1.4514... Val Loss: 1.4657
Epoch: 9/20... Step: 1170... Loss: 1.4605... Val Loss: 1.4629
Epoch: 9/20... Step: 1180... Loss: 1.4495... Val Loss: 1.4687
Epoch: 9/20... Step: 1190... Loss: 1.4862... Val Loss: 1.4622
Epoch: 9/20... Step: 1200... Loss: 1.4313... Val Loss: 1.4548
Epoch: 9/20... Step: 1210... Loss: 1.4401... Val Loss: 1.4512
Epoch: 9/20... Step: 1220... Loss: 1.4504... Val Loss: 1.4553
Epoch: 9/20... Step: 1230... Loss: 1.4249... Val Loss: 1.4508
Epoch: 9/20... Step: 1240... Loss: 1.4332... Val Loss: 1.4452
Epoch: 9/20... Step: 1250... Loss: 1.4445... Val Loss: 1.4436
Epoch: 10/20... Step: 1260... Loss: 1.4480... Val Loss: 1.4434
Epoch: 10/20... Step: 1270... Loss: 1.4383... Val Loss: 1.4392
Epoch: 10/20... Step: 1280... Loss: 1.4405... Val Loss: 1.4332
Epoch: 10/20... Step: 1290... Loss: 1.4421... Val Loss: 1.4338
Epoch: 10/20... Step: 1300... Loss: 1.4320... Val Loss: 1.4349
Epoch: 10/20... Step: 1310... Loss: 1.4420... Val Loss: 1.4338
Epoch: 10/20... Step: 1320... Loss: 1.4085... Val Loss: 1.4336
Epoch: 10/20... Step: 1330... Loss: 1.4142... Val Loss: 1.4289
Epoch: 10/20... Step: 1340... Loss: 1.3981... Val Loss: 1.4257
Epoch: 10/20... Step: 1350... Loss: 1.3931... Val Loss: 1.4213
Epoch: 10/20... Step: 1360... Loss: 1.3841... Val Loss: 1.4259
Epoch: 10/20... Step: 1370... Loss: 1.3755... Val Loss: 1.4218
Epoch: 10/20... Step: 1380... Loss: 1.4250... Val Loss: 1.4148
Epoch: 10/20... Step: 1390... Loss: 1.4314... Val Loss: 1.4170
Epoch: 11/20... Step: 1400... Loss: 1.4284... Val Loss: 1.4185
Epoch: 11/20... Step: 1410... Loss: 1.4427... Val Loss: 1.4147
Epoch: 11/20... Step: 1420... Loss: 1.4309... Val Loss: 1.4073
Epoch: 11/20... Step: 1430... Loss: 1.3992... Val Loss: 1.4105
Epoch: 11/20... Step: 1440... Loss: 1.4362... Val Loss: 1.4060
Epoch: 11/20... Step: 1450... Loss: 1.3561... Val Loss: 1.4054
Epoch: 11/20... Step: 1460... Loss: 1.3756... Val Loss: 1.4073
Epoch: 11/20... Step: 1470... Loss: 1.3766... Val Loss: 1.4047
Epoch: 11/20... Step: 1480... Loss: 1.3972... Val Loss: 1.3992
Epoch: 11/20... Step: 1490... Loss: 1.3747... Val Loss: 1.3967
Epoch: 11/20... Step: 1500... Loss: 1.3645... Val Loss: 1.3986
Epoch: 11/20... Step: 1510... Loss: 1.3538... Val Loss: 1.3992
Epoch: 11/20... Step: 1520... Loss: 1.3887... Val Loss: 1.3923
Epoch: 12/20... Step: 1530... Loss: 1.4452... Val Loss: 1.3915
Epoch: 12/20... Step: 1540... Loss: 1.3953... Val Loss: 1.3887
Epoch: 12/20... Step: 1550... Loss: 1.3962... Val Loss: 1.3859
Epoch: 12/20... Step: 1560... Loss: 1.4023... Val Loss: 1.3845
Epoch: 12/20... Step: 1570... Loss: 1.3618... Val Loss: 1.3864
Epoch: 12/20... Step: 1580... Loss: 1.3288... Val Loss: 1.3844
Epoch: 12/20... Step: 1590... Loss: 1.3300... Val Loss: 1.3844
Epoch: 12/20... Step: 1600... Loss: 1.3490... Val Loss: 1.3835
Epoch: 12/20... Step: 1610... Loss: 1.3478... Val Loss: 1.3864
Epoch: 12/20... Step: 1620... Loss: 1.3535... Val Loss: 1.3791
Epoch: 12/20... Step: 1630... Loss: 1.3670... Val Loss: 1.3753
Epoch: 12/20... Step: 1640... Loss: 1.3440... Val Loss: 1.3791
Epoch: 12/20... Step: 1650... Loss: 1.3304... Val Loss: 1.3763
Epoch: 12/20... Step: 1660... Loss: 1.3709... Val Loss: 1.3695
Epoch: 13/20... Step: 1670... Loss: 1.3404... Val Loss: 1.3753
Epoch: 13/20... Step: 1680... Loss: 1.3551... Val Loss: 1.3698
Epoch: 13/20... Step: 1690... Loss: 1.3301... Val Loss: 1.3665
Epoch: 13/20... Step: 1700... Loss: 1.3306... Val Loss: 1.3616
Epoch: 13/20... Step: 1710... Loss: 1.3086... Val Loss: 1.3666
Epoch: 13/20... Step: 1720... Loss: 1.3270... Val Loss: 1.3703
Epoch: 13/20... Step: 1730... Loss: 1.3601... Val Loss: 1.3625
Epoch: 13/20... Step: 1740... Loss: 1.3294... Val Loss: 1.3621
Epoch: 13/20... Step: 1750... Loss: 1.2962... Val Loss: 1.3605
Epoch: 13/20... Step: 1760... Loss: 1.3291... Val Loss: 1.3585
Epoch: 13/20... Step: 1770... Loss: 1.3347... Val Loss: 1.3585
Epoch: 13/20... Step: 1780... Loss: 1.3094... Val Loss: 1.3519
Epoch: 13/20... Step: 1790... Loss: 1.3037... Val Loss: 1.3556
Epoch: 13/20... Step: 1800... Loss: 1.3259... Val Loss: 1.3502
Epoch: 14/20... Step: 1810... Loss: 1.3300... Val Loss: 1.3498
Epoch: 14/20... Step: 1820... Loss: 1.3115... Val Loss: 1.3474
Epoch: 14/20... Step: 1830... Loss: 1.3319... Val Loss: 1.3431
Epoch: 14/20... Step: 1840... Loss: 1.2688... Val Loss: 1.3421
Epoch: 14/20... Step: 1850... Loss: 1.2637... Val Loss: 1.3406
Epoch: 14/20... Step: 1860... Loss: 1.3170... Val Loss: 1.3437
Epoch: 14/20... Step: 1870... Loss: 1.3217... Val Loss: 1.3407
Epoch: 14/20... Step: 1880... Loss: 1.3195... Val Loss: 1.3401
Epoch: 14/20... Step: 1890... Loss: 1.3369... Val Loss: 1.3401
Epoch: 14/20... Step: 1900... Loss: 1.2999... Val Loss: 1.3382
Epoch: 14/20... Step: 1910... Loss: 1.3153... Val Loss: 1.3349
Epoch: 14/20... Step: 1920... Loss: 1.3041... Val Loss: 1.3391
Epoch: 14/20... Step: 1930... Loss: 1.2676... Val Loss: 1.3367
Epoch: 14/20... Step: 1940... Loss: 1.3331... Val Loss: 1.3396
Epoch: 15/20... Step: 1950... Loss: 1.3017... Val Loss: 1.3355
Epoch: 15/20... Step: 1960... Loss: 1.3021... Val Loss: 1.3367
Epoch: 15/20... Step: 1970... Loss: 1.2851... Val Loss: 1.3269
Epoch: 15/20... Step: 1980... Loss: 1.2824... Val Loss: 1.3319
Epoch: 15/20... Step: 1990... Loss: 1.2737... Val Loss: 1.3260
Epoch: 15/20... Step: 2000... Loss: 1.2606... Val Loss: 1.3242
Epoch: 15/20... Step: 2010... Loss: 1.2767... Val Loss: 1.3311
Epoch: 15/20... Step: 2020... Loss: 1.3080... Val Loss: 1.3249
Epoch: 15/20... Step: 2030... Loss: 1.2721... Val Loss: 1.3248
Epoch: 15/20... Step: 2040... Loss: 1.2891... Val Loss: 1.3206
Epoch: 15/20... Step: 2050... Loss: 1.2806... Val Loss: 1.3202
Epoch: 15/20... Step: 2060... Loss: 1.2827... Val Loss: 1.3219
Epoch: 15/20... Step: 2070... Loss: 1.2918... Val Loss: 1.3219
Epoch: 15/20... Step: 2080... Loss: 1.2858... Val Loss: 1.3213
Epoch: 16/20... Step: 2090... Loss: 1.2936... Val Loss: 1.3203
Epoch: 16/20... Step: 2100... Loss: 1.2737... Val Loss: 1.3163
Epoch: 16/20... Step: 2110... Loss: 1.2669... Val Loss: 1.3130
Epoch: 16/20... Step: 2120... Loss: 1.2843... Val Loss: 1.3172
Epoch: 16/20... Step: 2130... Loss: 1.2545... Val Loss: 1.3134
Epoch: 16/20... Step: 2140... Loss: 1.2673... Val Loss: 1.3119
Epoch: 16/20... Step: 2150... Loss: 1.2944... Val Loss: 1.3089
Epoch: 16/20... Step: 2160... Loss: 1.2658... Val Loss: 1.3128
Epoch: 16/20... Step: 2170... Loss: 1.2693... Val Loss: 1.3130
Epoch: 16/20... Step: 2180... Loss: 1.2581... Val Loss: 1.3123
Epoch: 16/20... Step: 2190... Loss: 1.2856... Val Loss: 1.3108
Epoch: 16/20... Step: 2200... Loss: 1.2443... Val Loss: 1.3063
Epoch: 16/20... Step: 2210... Loss: 1.2143... Val Loss: 1.3135
Epoch: 16/20... Step: 2220... Loss: 1.2763... Val Loss: 1.3081
Epoch: 17/20... Step: 2230... Loss: 1.2509... Val Loss: 1.3117
Epoch: 17/20... Step: 2240... Loss: 1.2526... Val Loss: 1.3090
Epoch: 17/20... Step: 2250... Loss: 1.2455... Val Loss: 1.3028
Epoch: 17/20... Step: 2260... Loss: 1.2519... Val Loss: 1.3079
Epoch: 17/20... Step: 2270... Loss: 1.2622... Val Loss: 1.3007
Epoch: 17/20... Step: 2280... Loss: 1.2646... Val Loss: 1.2985
Epoch: 17/20... Step: 2290... Loss: 1.2591... Val Loss: 1.3006
Epoch: 17/20... Step: 2300... Loss: 1.2187... Val Loss: 1.3061
Epoch: 17/20... Step: 2310... Loss: 1.2488... Val Loss: 1.3003
Epoch: 17/20... Step: 2320... Loss: 1.2377... Val Loss: 1.3023
Epoch: 17/20... Step: 2330... Loss: 1.2464... Val Loss: 1.3072
Epoch: 17/20... Step: 2340... Loss: 1.2496... Val Loss: 1.2971
Epoch: 17/20... Step: 2350... Loss: 1.2634... Val Loss: 1.2999
Epoch: 17/20... Step: 2360... Loss: 1.2586... Val Loss: 1.2965
Epoch: 18/20... Step: 2370... Loss: 1.2422... Val Loss: 1.2945
Epoch: 18/20... Step: 2380... Loss: 1.2481... Val Loss: 1.2942
Epoch: 18/20... Step: 2390... Loss: 1.2385... Val Loss: 1.3051
Epoch: 18/20... Step: 2400... Loss: 1.2693... Val Loss: 1.2974
Epoch: 18/20... Step: 2410... Loss: 1.2627... Val Loss: 1.2962
Epoch: 18/20... Step: 2420... Loss: 1.2337... Val Loss: 1.2883
Epoch: 18/20... Step: 2430... Loss: 1.2447... Val Loss: 1.2922
Epoch: 18/20... Step: 2440... Loss: 1.2341... Val Loss: 1.2936
Epoch: 18/20... Step: 2450... Loss: 1.2323... Val Loss: 1.2899
Epoch: 18/20... Step: 2460... Loss: 1.2457... Val Loss: 1.2914
Epoch: 18/20... Step: 2470... Loss: 1.2332... Val Loss: 1.2993
Epoch: 18/20... Step: 2480... Loss: 1.2231... Val Loss: 1.2909
Epoch: 18/20... Step: 2490... Loss: 1.2181... Val Loss: 1.2913
Epoch: 18/20... Step: 2500... Loss: 1.2346... Val Loss: 1.2918
Epoch: 19/20... Step: 2510... Loss: 1.2335... Val Loss: 1.2922
Epoch: 19/20... Step: 2520... Loss: 1.2471... Val Loss: 1.2908
Epoch: 19/20... Step: 2530... Loss: 1.2490... Val Loss: 1.2854
Epoch: 19/20... Step: 2540... Loss: 1.2635... Val Loss: 1.2854
Epoch: 19/20... Step: 2550... Loss: 1.2124... Val Loss: 1.2876
Epoch: 19/20... Step: 2560... Loss: 1.2234... Val Loss: 1.2851
Epoch: 19/20... Step: 2570... Loss: 1.2193... Val Loss: 1.2875
Epoch: 19/20... Step: 2580... Loss: 1.2569... Val Loss: 1.2992
Epoch: 19/20... Step: 2590... Loss: 1.2133... Val Loss: 1.2876
Epoch: 19/20... Step: 2600... Loss: 1.2198... Val Loss: 1.2921
Epoch: 19/20... Step: 2610... Loss: 1.2273... Val Loss: 1.2813
Epoch: 19/20... Step: 2620... Loss: 1.2029... Val Loss: 1.2822
Epoch: 19/20... Step: 2630... Loss: 1.2044... Val Loss: 1.2786
Epoch: 19/20... Step: 2640... Loss: 1.2234... Val Loss: 1.2819
Epoch: 20/20... Step: 2650... Loss: 1.2344... Val Loss: 1.2802
Epoch: 20/20... Step: 2660... Loss: 1.2336... Val Loss: 1.2819
Epoch: 20/20... Step: 2670... Loss: 1.2329... Val Loss: 1.2765
Epoch: 20/20... Step: 2680... Loss: 1.2183... Val Loss: 1.2775
Epoch: 20/20... Step: 2690... Loss: 1.2269... Val Loss: 1.2758
Epoch: 20/20... Step: 2700... Loss: 1.2301... Val Loss: 1.2785
Epoch: 20/20... Step: 2710... Loss: 1.1988... Val Loss: 1.2817
Epoch: 20/20... Step: 2720... Loss: 1.2052... Val Loss: 1.2792
Epoch: 20/20... Step: 2730... Loss: 1.1917... Val Loss: 1.2770
Epoch: 20/20... Step: 2740... Loss: 1.1990... Val Loss: 1.2773
Epoch: 20/20... Step: 2750... Loss: 1.1967... Val Loss: 1.2780
Epoch: 20/20... Step: 2760... Loss: 1.1961... Val Loss: 1.2748
Epoch: 20/20... Step: 2770... Loss: 1.2353... Val Loss: 1.2739
Epoch: 20/20... Step: 2780... Loss: 1.2497... Val Loss: 1.2741

Picking the Best Model

  • Overfitting and underfitting
    • Monitor the training and validation losses as training runs; if the training loss is much lower than the validation loss, the model is overfitting: add regularization or dropout, or use a smaller model;
    • If the training and validation losses are close, the model is underfitting, and you can increase the model size
  • Hyperparameters
    • When defining the model: the number of hidden units n_hidden and the number of LSTM layers n_layers
      • A value of 2 or 3 is recommended for n_layers; the total number of model parameters should be on the same order of magnitude as the amount of training data. For example, with 100MB of data and only 150K parameters the model will severely underfit, whereas with 10MB of data and 10M parameters it will overfit, so increase dropout (see the parameter-count sketch after this list)
      • It usually pays to train a larger model and then experiment with different dropout values
    • When training the model: batch_size, seq_length, lr, and the ratio used to split the data into training and validation sets
      • Try different hyperparameter combinations and keep the best-performing model
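As a quick way to apply the rule of thumb above, compare the model's trainable parameter count with the dataset size; a small added sketch using the net and encoded objects defined earlier:

n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"trainable parameters: {n_params:,}   dataset size: {len(encoded):,} characters")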

Saving the Model

model_name = 'rnn_20_epoch.net'

checkpoint = {'n_hidden': net.n_hidden,
              'n_layers': net.n_layers,
              'state_dict': net.state_dict(),
              'tokens': net.chars}

with open(model_name, 'wb') as f:
    torch.save(checkpoint, f)
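To reload the checkpoint later, a sketch along these lines should work (it assumes the CharRNN class defined above is already in scope):

with open('rnn_20_epoch.net', 'rb') as f:
    checkpoint = torch.load(f)

loaded = CharRNN(checkpoint['tokens'],
                 n_hidden=checkpoint['n_hidden'],
                 n_layers=checkpoint['n_layers'])
loaded.load_state_dict(checkpoint['state_dict'])
loaded.eval()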

Text Generation

def predict(net, char, h=None, top_k=None):
    # tensor inputs
    x = np.array([[net.char2int[char]]])
    x = one_hot_encode(x, len(net.chars))
    inputs = torch.from_numpy(x)

    if (train_on_gpu):
        inputs = inputs.cuda()

    # detach hidden state from history
    h = tuple([each.data for each in h])
    # get the output of the model
    out, h = net(inputs, h)

    # get the character probabilities
    p = F.softmax(out, dim=1).data
    if (train_on_gpu):
        p = p.cpu()  # move to cpu

    # top-k sampling
    if top_k is None:
        top_ch = np.arange(len(net.chars))
    else:
        p, top_ch = p.topk(top_k)
        top_ch = top_ch.numpy().squeeze()

    # select the likely next character with some element of randomness
    p = p.numpy().squeeze()
    char = np.random.choice(top_ch, p=p/p.sum())

    # return the encoded value of the predicted char and the hidden state
    return net.int2char[char], h
# Generate text
def sample(net, size, prime='The', top_k=None):

    if (train_on_gpu):
        net.cuda()
    else:
        net.cpu()

    net.eval()  # eval mode

    # First off, run through the prime characters
    chars = [ch for ch in prime]
    h = net.init_hidden(1)
    for ch in prime:
        char, h = predict(net, ch, h, top_k=top_k)

    chars.append(char)

    # Now pass in the previous character and get a new one
    for ii in range(size):
        char, h = predict(net, chars[-1], h, top_k=top_k)
        chars.append(char)

    return ''.join(chars)

print(sample(net, 1000, prime='Anna', top_k=5))
Anna
with a smile to holding a person was, and a line white sheel, who
did not know. His father was a long while, and her friend, and the
secundes, time and some of or that some made a dress so as happy, and
to see it. To her that he was not finished, he went into the same
time, the princess had always taken up to the corridor, was in which
the prevent position he was an expression to the table, and she could
do to triem to herself what they seemed to the cletched of the
coup of the sick man are a sinting state of charming head, was not to
such her for his brother's woman that when they were since he was sitting to
the soft might have been seening that her family and with the proviss,
which she had been set off the same thing, who had been disagreeable
and had sore of the propersy of always. And he set her. He spoke tran
out of the stranger and her husband who had no supported
that he had not heard the face of her starts, began to say, the pissons were
far in shame, and her heart, their sho