As discussed earlier, plain RNNs suffer from the long-term dependency problem, and the cause of that problem is gradient vanishing or explosion. This section therefore analyzes why gradients vanish and explode in an RNN.

First, consider the structure of an RNN, shown in the figure below.
[Figure: RNN structure]
Here we take time step t = 2, so t-1, t, and t+1 correspond to 1, 2, and 3. From the figure we can write the following expressions:
$S_{1} = W_{s}S_{0}+W_{x}X_{1}+b_{1}$ $O_{1} = W_{o}S_{1}+b_{2}$
$S_{2} = W_{s}S_{1}+W_{x}X_{2}+b_{1}$ $O_{2} = W_{o}S_{2}+b_{2}$
$S_{3} = W_{s}S_{2}+W_{x}X_{3}+b_{1}$ $O_{3} = W_{o}S_{3}+b_{2}$
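The three-step unrolling above can be sketched numerically. The dimensions, random weights, and inputs below are hypothetical, and, matching the equations as written, no activation function is applied yet:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 4-dim hidden state, 3-dim input, 2-dim output.
W_s = rng.normal(size=(4, 4))   # hidden-to-hidden weights
W_x = rng.normal(size=(4, 3))   # input-to-hidden weights
W_o = rng.normal(size=(2, 4))   # hidden-to-output weights
b_1 = rng.normal(size=4)        # hidden bias
b_2 = rng.normal(size=2)        # output bias

S = [np.zeros(4)]                           # S_0 = 0
X = [rng.normal(size=3) for _ in range(3)]  # X_1, X_2, X_3
O = []
for t in range(3):
    # S_t = W_s S_{t-1} + W_x X_t + b_1   (linear recurrence, as in the text)
    S.append(W_s @ S[-1] + W_x @ X[t] + b_1)
    # O_t = W_o S_t + b_2
    O.append(W_o @ S[-1] + b_2)

print(len(O), O[-1].shape)  # three outputs, each 2-dimensional
```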

Suppose that at t = 3 the loss function is $L_{3} = \frac{1}{2}(Y_{3}-O_{3})^{2}$.
The loss for one training pass is $L= \sum_{t=0}^{k}L_{t}$, i.e. the losses of all time steps accumulated.
Gradient descent takes partial derivatives with respect to $W_{s}$, $W_{x}$, $W_{o}$, $b_{1}$, and $b_{2}$ in order to reduce $L$. At t = 3, the loss above is $L_{3} = \frac{1}{2}(Y_{3}-O_{3})^{2}=\frac{1}{2}(Y_{3}-(W_{o}S_{3}+b_{2}))^{2}$. By the chain rule, the partial derivatives with respect to $W_{s}$, $W_{x}$, $W_{o}$, $b_{1}$, and $b_{2}$ are:
$\frac{\partial L_{3}}{\partial b_{1}}=\frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial S_{3}}\frac{\partial S_{3}}{\partial b_{1}}+\frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial S_{3}}\frac{\partial S_{3}}{\partial S_{2}}\frac{\partial S_{2}}{\partial b_{1}}+\frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial S_{3}}\frac{\partial S_{3}}{\partial S_{2}}\frac{\partial S_{2}}{\partial S_{1}}\frac{\partial S_{1}}{\partial b_{1}}$

$\frac{\partial L_{3}}{\partial b_{2}}=\frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial b_{2}}$

$\frac{\partial L_{3}}{\partial W_{o}} = \frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial W_{o}}$

$\frac{\partial L_{3}}{\partial W_{x}} = \frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial S_{3}}\frac{\partial S_{3}}{\partial W_{x}}+\frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial S_{3}}\frac{\partial S_{3}}{\partial S_{2}}\frac{\partial S_{2}}{\partial W_{x}}+\frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial S_{3}}\frac{\partial S_{3}}{\partial S_{2}}\frac{\partial S_{2}}{\partial S_{1}}\frac{\partial S_{1}}{\partial W_{x}}$

$\frac{\partial L_{3}}{\partial W_{s}} = \frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial S_{3}}\frac{\partial S_{3}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial S_{3}}\frac{\partial S_{3}}{\partial S_{2}}\frac{\partial S_{2}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}}\frac{\partial O_{3}}{\partial S_{3}}\frac{\partial S_{3}}{\partial S_{2}}\frac{\partial S_{2}}{\partial S_{1}}\frac{\partial S_{1}}{\partial W_{s}}$
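The expansion of $\frac{\partial L_{3}}{\partial W_{s}}$ can be verified numerically. The sketch below uses a scalar toy model with hypothetical parameter values, where the linear recurrence makes each factor $\frac{\partial S_{j}}{\partial S_{j-1}} = W_{s}$ and the local term $\frac{\partial S_{k}}{\partial W_{s}} = S_{k-1}$; the chain-rule sum is checked against a finite-difference estimate:

```python
import numpy as np

# Scalar toy model (hypothetical values) matching the unrolled equations:
# S_t = W_s*S_{t-1} + W_x*X_t + b_1,  O_t = W_o*S_t + b_2,  L_3 = 0.5*(Y_3 - O_3)^2
W_s, W_x, W_o, b_1, b_2 = 0.8, 0.5, 1.2, 0.1, -0.2
X = [0.3, -0.7, 0.5]
Y_3 = 1.0

def forward(W_s):
    S = [0.0]                                  # S_0 = 0
    for x in X:
        S.append(W_s * S[-1] + W_x * x + b_1)
    O_3 = W_o * S[3] + b_2
    return S, 0.5 * (Y_3 - O_3) ** 2

S, L_3 = forward(W_s)
O_3 = W_o * S[3] + b_2
dL_dO3 = -(Y_3 - O_3)                          # dL_3/dO_3

# Chain rule: dL3/dWs = sum_k  dL3/dO3 * dO3/dS3 * (prod_{j=k+1}^{3} Ws) * S_{k-1}
grad = sum(dL_dO3 * W_o * W_s ** (3 - k) * S[k - 1] for k in (1, 2, 3))

# Numerical check by central finite differences.
eps = 1e-6
num = (forward(W_s + eps)[1] - forward(W_s - eps)[1]) / (2 * eps)
print(grad, num)  # the two estimates should agree closely
```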

The derivations above show that the derivative with respect to $W_{o}$ involves no long-term dependency, whereas the derivatives with respect to $W_{x}$ and $W_{s}$ accumulate longer and longer chains of dependencies as the sequence grows.
The derivatives with respect to $W_{s}$ and $W_{x}$ share the same structure, so they can be written in the general form:
$\frac{\partial L_{t}}{\partial W_{s}} = \sum_{k=0}^{t} \frac{\partial L_{t}}{\partial O_{t}}\frac{\partial O_{t}}{\partial S_{t}}(\prod_{j=k+1}^{t}\frac{\partial S_{j}}{\partial S_{j-1}})\frac{\partial S_{k}}{\partial W_{s}}$

$\frac{\partial L_{t}}{\partial W_{x}} = \sum_{k=0}^{t} \frac{\partial L_{t}}{\partial O_{t}}\frac{\partial O_{t}}{\partial S_{t}}(\prod_{j=k+1}^{t}\frac{\partial S_{j}}{\partial S_{j-1}})\frac{\partial S_{k}}{\partial W_{x}}$

When tanh is used as the activation function,
$S_{j} = \tanh(W_{x}X_{j}+W_{s}S_{j-1}+b_{1})$
so:
$\prod_{j=k+1}^{t}\frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t}\tanh'\,W_{s}$
where $\tanh'$ denotes the derivative of tanh evaluated at the pre-activation $W_{x}X_{j}+W_{s}S_{j-1}+b_{1}$.
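The single-step Jacobian behind each factor can be checked numerically. In the vector case it is $\frac{\partial S_{j}}{\partial S_{j-1}} = \mathrm{diag}(1-\tanh^{2}(z))\,W_{s}$ with $z$ the pre-activation; the sketch below, with hypothetical random weights, compares this against finite differences:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical weights for one tanh recurrence step:
# S_j = tanh(W_s S_{j-1} + c), where c stands in for W_x X_j + b_1.
W_s = rng.normal(size=(3, 3))
c = rng.normal(size=3)
S_prev = rng.normal(size=3)

def step(S_prev):
    return np.tanh(W_s @ S_prev + c)

# Analytic Jacobian: dS_j/dS_{j-1} = diag(tanh'(z)) @ W_s, tanh'(z) = 1 - tanh(z)^2
z = W_s @ S_prev + c
J = np.diag(1 - np.tanh(z) ** 2) @ W_s

# Finite-difference check, one column per coordinate of S_{j-1}.
eps = 1e-6
J_num = np.column_stack([
    (step(S_prev + eps * e) - step(S_prev - eps * e)) / (2 * eps)
    for e in np.eye(3)
])
print(np.max(np.abs(J - J_num)))  # should be tiny
```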
The tanh function and its derivative are plotted below:
[Figure: tanh and its derivative]

The tanh function itself takes values in $(-1,1)$, and its derivative takes values in $(0,1]$, equal to 1 only when its argument is 0; the pre-activation $W_{x}X_{j}+W_{s}S_{j-1}+b_{1}$ is rarely exactly 0, so $\tanh'$ is almost always strictly less than 1. If $W_{s}$ takes a value between 0 and 1, then each factor $\tanh'\,W_{s}$ is less than 1, and for large $t$ the product $\prod_{j=k+1}^{t}\frac{\partial S_{j}}{\partial S_{j-1}}$ tends to 0, producing a vanishing gradient. Conversely, if $W_{s}$ is large enough that $\tanh'\,W_{s}$ is greater than 1, the product $\prod_{j=k+1}^{t}\frac{\partial S_{j}}{\partial S_{j-1}}$ grows very large, producing an exploding gradient.
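Both regimes can be illustrated with a small sketch. Since $\tanh' \le 1$, the product is largest when every pre-activation is exactly 0 (zero inputs and zero initial state, a deliberately constructed best case), where it reduces to $W_{s}^{t}$; the values below are hypothetical:

```python
import numpy as np

# Each factor in the product over j is tanh'(z_j) * W_s, tanh'(z) = 1 - tanh(z)^2.
# With the pre-activation held at z_j = 0, tanh'(z_j) = 1, so the product is
# exactly W_s ** t -- the most favorable case for gradient flow.
def product_at_zero(W_s, t=50):
    factor = (1 - np.tanh(0.0) ** 2) * W_s   # = W_s when z = 0
    return factor ** t

print(product_at_zero(0.9))   # |W_s| < 1: shrinks toward 0 (vanishing gradient)
print(product_at_zero(1.5))   # W_s > 1: grows without bound (exploding gradient)
```

Because this is the best case, $|W_{s}| < 1$ guarantees the product vanishes for any inputs; for $W_{s} > 1$ the product can explode whenever the activations stay away from saturation.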
