后端 Day 17:GRU (Gated Recurrent Unit) 概念介绍与实作

wwwdaoseacom · August 12, 2021 · 1 hits

前言

原来还想多介绍几个应用,但是,一直担心忘了另一个 RNN 的变形 -- GRU,所以,还是先把它处理掉,才好 focus 在应用上。另一方面,LSTM 运行速度非常慢/images/emoticon/emoticon18.gif,如果改用 GRU,希望测试可以快一点。

GRU (Gated Recurrent Unit) vs. LSTM

RNN 还有一个兄弟,与 LSTM 类似的模型,称为『GRU』(无译名,Gated Recurrent Unit),如下图,本来想把它忽略掉,但看到相关文章,说它能加快运行速度及减少内存的耗用,因此,还是花点时间实验看看。
https://d1dwq032kyr03c.cloudfront.net/upload/images/20171226/20001976rhen46Lrsx.png
图. GRU vs. LSTM 性能比较,图片来源:Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling

GRU 架构

回想一下,LSTM 的记忆 (Ct, Current Memeory State) 由输入 (input)、遗忘 (forget) 两个阀 (Gate) 控制,GRU 则简化为一个『更新阀』(Update Gate) 控制,公式如下,zt 为更新阀:
https://d1dwq032kyr03c.cloudfront.net/upload/images/20171226/20001976phdFedCH1q.png
https://d1dwq032kyr03c.cloudfront.net/upload/images/20171226/20001976lMy5V0uk4i.png
在记忆处理上,LSTM 与 GRU 的差异如下,两者使用的数学符号不同,看得有点累:
https://d1dwq032kyr03c.cloudfront.net/upload/images/20171227/20001976UgmqRVibwU.png
图. LSTM 记忆处理
https://d1dwq032kyr03c.cloudfront.net/upload/images/20171227/20001976apZkdgfLoj.png
图. GRU 记忆处理

整体架构如下图,详细说明请参考CS224d 笔记 4 续——RNN 隐藏层计算之 GRU 和 LSTMfrom vanilla RNN to GRU & LSTMs,后者包含视频及动画投视频,笔者自认功力有限,没办法说明的更清楚,就此打住。
https://d1dwq032kyr03c.cloudfront.net/upload/images/20171212/20001976QGf1IXC07N.jpg
图. GRU 架构,图片来源:Evolution: from vanilla RNN to GRU & LSTMs

实作

Keras 提供 GRU 函数可直接使用,我么就可以把前两篇的程序 LSTM.py 及 Sentiment1.py 中的 LSTM layer 直接换成 GRU Layer,测试看看性能及准确率是否有改善,可自这里下载,文件名为 GRU.py 及 Sentiment1_GRU.py。由于只改一行,就不列出代码了,以免让读者以为笔者滥竽充数。

运行

两支程序运行方式分别如下:
python GRU.py
python Sentiment1_GRU.py

测试结果如下:

  1. 性能及准确率不是很显著,运行时间相差无几 (GRU:7 or 8 秒,LSTM:9 秒)。/images/emoticon/emoticon02.gif
  2. 附带一提,RNN/LSTM/GRU 在优化时,常使用 rmsprop 函数取代 SGD 或 adam,因为,它收敛的速度会比较快,原因是 rmsprop 的学习速率 (learning rate) 会随着之前的梯度总和作反向的调整。
  3. 结论是读者如果认为少许的改善,也是值得的,就可以改用 GRU,但是,用谷大神搜索时,还是以 LSTM 为关键字,能找到较多的资源。

结论

我们已经将 RNN 三大算法介绍过了,也实作情绪分析应用,接下来,笔者还是会多找一些应用,来跟大家一起讨论,下次见了。


,

没有 GRU.py 及 Sentiment1_GRU.py

,

您好,读过您的文章后,最近在专题有实作 GRU,但一直尝试各种调整参数,accuracy 及 val_accuracy 都一直为 0.0000e+00,想请问该如何解决,谢谢~
https://d1dwq032kyr03c.cloudfront.net/upload/images/20210712/20139428GN1j1Z9YBO.jpghttps://d1dwq032kyr03c.cloudfront.net/upload/images/20210712/201394287VxzoaFwYA.jpghttps://d1dwq032kyr03c.cloudfront.net/upload/images/20210712/20139428Tk0DJeUyzn.jpg

No Reply at the moment.
You need to Sign in before reply, if you don't have an account, please Sign up first.