每天五分钟深度学习框架pytorch:搭建LSTM完成手写字体识别任务?
本文重点
前面我们学习了LSTM的搭建,我们也学习过很多卷积神经网络关于手写字体的识别,本文我们使用LSTM来完成手写字体的识别。
网络模型的搭建
class RNN(nn.Module):
def __init__(self,in_dim,hidden_dim,n_layer,n_class):
super(RNN,self).__init__()
self.n_layer=n_layer
self.hidden_dim=hidden_dim
self.lstm=nn.LSTM(in_dim,hidden_dim,n_layer)
self.classifier=nn.Linear(hidden_dim,n_class)
def forward(self,x):
out,_ =self.lstm(x)
out =out[-1,:,:]
out =self.classifier(out)
return out
代码解析:
搭建了一个循环神经网络,然后循环神经网络会有两个输出,然后我们使用out输出,这个out输出是所有时间步的的最后一层的输出,它的维度为[时间步数,batch,词维度]ÿ