GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

时间:2022-09-24 07:38:17

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

RNN
GRU
matlab codes

RNN网络考虑到了具有时间数列的样本数据,但是RNN仍存在着一些问题,比如随着时间的推移,RNN单元就失去了对很久之前信息的保存和处理的能力,而且存在着gradient vanishing问题。

所以有些特殊类型的RNN网络相继被提出,比如LSTM(long short term memory)和GRU(gated recurrent unit)(Chao,et al. 2014).这里我主要推导一下GRU参数的迭代过程

GRU单元结构如下图所示

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

1479126283494.jpg

数据流过程如下

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

其中GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现表示Hadamard积,即对应元素乘积;下标表示节点的index,上标表示时刻;GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现表示隐层到输出层的参数矩阵,GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现分别是隐层和输出层的节点个数;GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现分别表示输入和上一时刻隐层到更新门z的连接矩阵,GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现表示输入数据的维度;GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现分别表示输入和上一时刻隐层到重置门r的连接矩阵;GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现分别表示输入和上一时刻的隐层到待选状态GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现的连接矩阵。

针对于时刻t,使用链式求导法则,计算参数矩阵的梯度,其中E是代价函数,首先计算对隐层输出的梯度,因为隐层输出牵涉到多个时刻

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

所以

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

其中GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现分别是对应激活函数的线性和部分

现在对参数计算梯度

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

将上面的式子矢量化(行向量)表示:

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现
GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

那接下来使用matlab来实现一个小例子,看看GRU的效果,同样是二进制相加的问题

  1. function error= GRUtest( ) 

  2. % 初始化训练数据 

  3. uNum=16;%单元个数 

  4. maxInt=2^uNum; 

  5. % 初始化网络结构 

  6. xdim=2; 

  7. ydim=1; 

  8. hdim=16; 

  9. eta=0.1; 

  10. %初始化网络参数 

  11. Wy=rand(hdim,ydim)*2-1; 

  12. Wr=rand(xdim,hdim)*2-1; 

  13. Ur=rand(hdim,hdim)*2-1; 

  14. W =rand(xdim,hdim)*2-1; 

  15. U =rand(hdim,hdim)*2-1; 

  16. Wz=rand(xdim,hdim)*2-1; 

  17. Uz=rand(hdim,hdim)*2-1; 


  18. rvalues=zeros(uNum+1,hdim); 

  19. zvalues=zeros(uNum+1,hdim); 

  20. hbarvalues=zeros(uNum,hdim); 

  21. hvalues = zeros(uNum,hdim); 

  22. yvalues=zeros(uNum,ydim); 


  23. for p=1:10000 

  24. aInt=randi(maxInt/2); 

  25. bInt=randi(maxInt/2); 

  26. cInt=aInt+bInt; 

  27. at=dec2bin(aInt)-'0'; 

  28. bt=dec2bin(bInt)-'0'; 

  29. ct=dec2bin(cInt)-'0'; 

  30. a=zeros(1,uNum); 

  31. b=zeros(1,uNum); 

  32. c=zeros(1,uNum); 

  33. a(1:size(at,2))=at(end:-1:1); 

  34. b(1:size(bt,2))=bt(end:-1:1); 

  35. c(1:size(ct,2))=ct(end:-1:1); 

  36. xvalues=[a;b]'; 

  37. d=c'; 


  38. % 前向计算 

  39. rvalues(1,:)=sigmoid(xvalues(1,:)*Wr); 

  40. hbarvalues(1,:)=outTanh(xvalues(1,:)*W); 

  41. zvalues(1,:)=sigmoid(xvalues(1,:)*Wz); 

  42. hvalues(1,:)=zvalues(1,:).*hbarvalues(1,:); 

  43. yvalues(1,:)=sigmoid(hvalues(1,:)*Wy); 

  44. for t=2:uNum 

  45. rvalues(t,:)=sigmoid(xvalues(t,:)*Wr+hvalues(t-1,:)*Ur); 

  46. hbarvalues(t,:)=outTanh(xvalues(t,:)*W+(rvalues(t,:).*hvalues(t-1,:))*U); 

  47. zvalues(t,:)=sigmoid(xvalues(t,:)*Wz+hvalues(t-1,:)*Uz); 

  48. hvalues(t,:)=(1-zvalues(t,:)).*hvalues(t-1,:)+zvalues(t,:).*hbarvalues(t,:); 

  49. yvalues(t,:)=sigmoid(hvalues(t,:)*Wy);  

  50. end 


  51. % 误差反向传播 

  52. delta_r_next=zeros(1,hdim); 

  53. delta_z_next=zeros(1,hdim); 

  54. delta_h_next=zeros(1,hdim); 

  55. delta_next=zeros(1,hdim); 


  56. dWy=zeros(hdim,ydim); 

  57. dWr=zeros(xdim,hdim); 

  58. dUr=zeros(hdim,hdim); 

  59. dW=zeros(xdim,hdim); 

  60. dU=zeros(hdim,hdim); 

  61. dWz=zeros(xdim,hdim); 

  62. dUz=zeros(hdim,hdim); 


  63. for t=uNum:-1:2 

  64. delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:)); 

  65. delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:)); 

  66. delta_z=delta_h.*(hbarvalues(t,:)-hvalues(t-1,:)).*diffsigmoid(zvalues(t,:)); 

  67. delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)); 

  68. delta_r=hvalues(t-1,:).*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:)); 


  69. dWy=dWy+hvalues(t,:)'*delta_y; 

  70. dWz=dWz+xvalues(t,:)'*delta_z; 

  71. dUz=dUz+hvalues(t-1,:)'*delta_z; 

  72. dW =dW+xvalues(t,:)'*delta; 

  73. dU =dU+(rvalues(t,:).*hvalues(t-1,:))'*delta ; 

  74. dWr=dWr+xvalues(t,:)'*delta_r; 

  75. dUr=dUr+hvalues(t-1,:)'*delta_r; 


  76. delta_r_next=delta_r; 

  77. delta_z_next=delta_z; 

  78. delta_h_next=delta_h; 

  79. delta_next =delta; 


  80. end 


  81. t=1; 

  82. delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:)); 

  83. delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:)); 

  84. delta_z=delta_h.*(hbarvalues(t,:)-0).*diffsigmoid(zvalues(t,:)); 

  85. delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)); 

  86. delta_r=0.*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:)); 


  87. dWy=dWy+hvalues(t,:)'*delta_y; 

  88. dWz=dWz+xvalues(t,:)'*delta_z; 

  89. dW =dW+xvalues(t,:)'*delta; 

  90. dWr=dWr+xvalues(t,:)'*delta_r; 


  91. Wy = Wy-eta*dWy; 

  92. Wr = Wr-eta*dWr; 

  93. Ur = Ur-eta*dUr; 

  94. W = W -eta*dW; 

  95. U = U-eta*dU; 

  96. Wz = Wz-eta*dWz; 

  97. Uz = Uz-eta*dUz; 

  98. error = (norm(yvalues-d,2))/2.0; 

  99. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 

  100. if mod(p,500)==0 

  101. fprintf('******************第%s次迭代****************\n',int2str(p)); 

  102. yvalues=round(yvalues(end:-1:1)); 

  103. y=bin2dec(int2str(yvalues')); 

  104. fprintf('y=%d\n',y); 

  105. fprintf('c=%d\n',cInt); 

  106. fprintf('样本误差:e=%f\n',error); 

  107. end 

  108. end 

  109. end 


  110. function f=sigmoid(x) 

  111. f=1./(1+exp(-x)); 

  112. end 


  113. function fd = diffsigmoid(f) 

  114. fd=f.*(1-f); 

  115. end 


  116. function g=outTanh(x) 

  117. g=1-2./(1+exp(2*x)); 

  118. end 


  119. function gd=diffoutTanh(g) 

  120. gd=1-g.^2; 

  121. end 

部分实验结果

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

1479392393541.jpg

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现的更多相关文章

  1. Gated Recurrent Unit (GRU)

                                   Gated Recurrent Unit (GRU) Outline                             Backgr ...

  2. Gated Recurrent Unit (GRU)公式简介

    update gate $z_t$: defines how much of the previous memory to keep around. \[z_t = \sigma ( W^z x_t+ ...

  3. pytorch_SRU(Simple Recurrent Unit)

    导读 本文讨论了最新爆款论文(Training RNNs as Fast as CNNs)提出的LSTM变种SRU(Simple Recurrent Unit),以及基于pytorch实现了SRU,并 ...

  4. Simple Recurrent Unit,单循环单元

    SRU(Simple Recurrent Unit),单循环单元 src/nnet/nnet-recurrent.h 使用Tanh作为非线性单元 SRU不保留内部状态 训练时,每个训练序列以零向量开始 ...

  5. php网页,想弹出对话框, 消息框 简单代码

    php网页,想弹出对话框, 消息框 简单代码 <?php echo "<script language=\"JavaScript\">alert(\&q ...

  6. C&num; 客服端上传文件与服务器器端接收 (简单代码)

    简单代码: /*服务器端接收写入 可以实现断点续传*/ public string ConnectUpload(string newfilename,string filepath,byte[] fi ...

  7. Redis:安装、配置、操作和简单代码实例&lpar;C语言Client端&rpar;

    Redis:安装.配置.操作和简单代码实例(C语言Client端) - hj19870806的专栏 - 博客频道 - CSDN.NET Redis:安装.配置.操作和简单代码实例(C语言Client端 ...

  8. 1 go 开发环境搭建与简单代码实现

    什么是go语言 go是一门并发支持,垃圾回收的编译型 系统编程语言,旨在创造一门具有静态编译语言的高性能和动态语言的高效开发之间拥有一个良好平衡点 的一门编程语言. go有什么优点? 自动垃圾回收机制 ...

  9. 使用WinSCP进行简单代码文件同步

    前言传输协议FTPFTPSSFTPSCP为什么使用WinSCP?CMD的FTP命令FileZillaPuTTYrsyncSublime的SFTP插件WinSCPWinSCP进行简单代码文件同步总结备注 ...

随机推荐

  1. 【C&num;】MVC项目中搭建WebSocket服务器

    前言 因为项目需要,前端页面中需要不断向后台请求获取一个及一个以上的状态值.最初的方案是为每个状态值请求都建立一个定时器循环定时发起Ajax请求,结果显而 易见.在HTTP1.1协议中,同一客户端浏览 ...

  2. >Python下使用subprocess中文乱码的解决方案

    # -*- coding: CP936 -*- import subprocess cmd="cmd.exe" begin=101 end=110 while begin<e ...

  3. &lbrack;转&rsqb;WinExec、ShellExecute和CreateProcess及返回值判断方式

    [转]WinExec.ShellExecute和CreateProcess及返回值判断方式 http://www.cnblogs.com/ziwuge/archive/2012/03/12/23924 ...

  4. 【Mood-18】github 使用指南

    windows下使用教程: http://www.cnblogs.com/dongdong230/p/4211221.html repository not found error问题解决(需确定gi ...

  5. js的内置对象

    转载: https://www.cnblogs.com/liuluteresa/p/6413988.html   在js里,一切皆为或者皆可以被用作对象.可通过new一个对象或者直接以字面量形式创建变 ...

  6. C&num;工具&colon;ASP&period;NET MVC生成图片验证码

    1.复制下列代码,拷贝到控制器中. #region 生成验证码图片 // [OutputCache(Location = OutputCacheLocation.None, Duration = 0, ...

  7. Django——发送邮件

    Django--发送邮件 在web应用中,服务器对客户发送邮件来通知用户一些信息,可以使用邮件来实现. Django中提供了邮件接口,使我们可以快捷的建设一个邮件发送系统. 以下是一个简单实例: se ...

  8. kafka TimeoutException 超时问题解决

    1.报错:: java.util.concurrent.ExecutionException: org.apache.kafka.common.errors.NotLeaderForPartition ...

  9. 2018-2019-2 20165209 《网络对抗技术》Exp7: 网络欺诈防范

    2018-2019-2 20165209 <网络对抗技术>Exp7: 网络欺诈防范 1 基础问题回答和实验内容 1.1基础问题回答 (1)通常在什么场景下容易受到DNS spoof攻击. ...

  10. Kettle数据源连接配置

    说明: 通过(图3.1)我们可以看到创建数据源时需要配置相应的参数: Connection Name(必填):配置数据源使用名称,如:Rot_Source Host Name(必填):数据库主机IP地 ...