循环神经网络(高级篇)
例:人名分类
数据准备
根据人名的英文拼写判断人所在的国家,数据形式如下:
这个问题是序列输入,而只有一个输出,若采用RNN,模型形式如下:
.assets/image-20200927161458425.png)
采用GRU的话,模型形式如下:
.assets/image-20200927161924466.png)
数据处理:
名字序列转ASCII码值,进一步表示成one-hot
.assets/image-20200927200005921.png)
对ASCII码表示的输入做padding,统一长度好形成一个张量
.assets/image-20200927200047083.png)
国家表示成数字label
.assets/image-20200927200209709.png)
模型与代码
双向神经网络图:
.assets/image-20200927214633630.png)
注:$hidden=[h_N^f,h_N^b]$
1 | class NameDataset(Dataset): |
1 | class RNNClassifier(torch.nn.Module): |
1 | def trainModel(): |
1 | def testModel(): |
课程来源:《PyTorch深度学习实践》完结合集
.assets/image-20200927161615676.png)