LSTM Explained
在时序相关或者文本相关的任务中,人们常常会使用
RNN-Based
的序列模型,本文将主要介绍LSTM
涉及到的一些知识点。主要参考几篇优秀的博文,链接将放在最后的Reference中,感兴趣的建议直接阅读原文🚘。
为了避免Vanilla RNN
在输入长文本后带来的梯度消失问题,研究者提出了LSTM unit
,即long-short term memory unit
,使用它们作为整个序列模型中的最小单元。
在LSTM unit
中共存在两种state
,一类是hidden state
,另一类是cell state
。同时,引入了门机制
,分别为Forget Gate
、Update Gate/Input Gate
以及Output Gate
。在每一个time step
时,模型通过上述的门机制
决定应该存放哪些信息,同时过滤掉哪些信息。
首先,为了阐明LSTM unit
的计算机制,先对其内部架构进行一定的了解,如下图所示:
通过Fig 1
,我们可以预览到LSTM
传播时数据的大致流通过程。更近一步的,通过以下动图可以一目了然地观察到数据是怎么在一个unit
中流动的。
可以清晰地看到,在时间步t
时,unit
接收来自时间步t-1
的$c_{t-1}$与$h_{t-1}$,经过若干次的线性变化以及point-wise操作后,输出c_t
与h_t
。下面将具体介绍每个门机制的运行过程以及个人对于该设计的理解。
首先,我们可以将内部计算中涉及到门计算的部分分为以下几种。
其中:
- 最左边的部分为
Forget Gate
; - 中间的部分为
Input Gate或Update Gate
; - 最右边的部分为
Output Gate
.
Forget Gate
对于上一个时间步传递的
cell state
信息,遗忘门
决定哪些信息需要被继续保持,哪些信息应被遗忘。
假设当前所在的时间步为t
,记当前forget gate
的输入为$X_t$,$c_{t-1}$以及$h_{t-1}$,输出为$C_{tmp}$,根据上面动图中的计算,其相应的计算伪代码如下:
- 将$X_t$与$h_{t-1}$进行
concatenation
,得到$[h_{t-1}, X_{t}]$; - 通过$W_{forget}$与$b_{forget}$进行线性转换,再通过
sigmoid
将计算结果转换到[0, 1]
区间; - 最后,将上述输出的
概率向量
与$c_{t-1}$中保存的向量做point-wise
乘法。
下面是Forget Gate
的计算公式:
- $Conbine = Concatenation(h_{t-1}, X_t)$
- $Z_f = Sigmoid(W_{forget} Conbine + b_{forget})$
- $c_{tmp} = c_{t-1} * Z_f$
通过Forget Gate
的一系列操作,上一个时间步的$c_{t-1}$第一次得到了更新,得到了$c_{tmp}$。我们用它表示在时间步t
输出最终cell state
前的中间值。
我认为遗忘门运作的机制是:将之前保存的信息(存放在$h_{t-1}$)与当前时间步的输入信息($X_t$)进行比较,进而去判断是否对$c_{t-1}$中的某些信息进行遗忘。当sigmoid
输出的向量中某个位置输出的概率值更偏向于0,则说明$c_{t-1}$对应位置上的信息应该被遗忘;而当该位置的概率值更偏向1时,则$c_{t-1}$对应位置上的信息更应该保留。
Update Gate/Input Gate
更新门
需要考虑将哪些新信息保存或者贴加到cell state
中,并且输出到下一个时间步。
在Update Gate
中,unit
决定什么信息
要被更新到$c_{tmp}$中。这组门运算的输入包括:$h_{t-1}, X_t, c_{tmp}$,输出为$c_{t}$:
- 第一步还是同
Forget Gate
中形式的一样,先将$X_t$与$h_{t-1}$进行concatenation
,得到$[h_{t-1}, X_{t}]$; - 接下来,分别进行两次线性变换,但是各自
activation function
不同,一个为sigmod
,而另一个则为tanh
; - 最后,将两个结果进行
point-wise
的乘法,得到Update Gate
的输出。
同样的,下面是Update Gate
的计算公式:
- $Conbine = Concatenation(h_{t-1}, X_t)$
- $Z_u = Sigmoid(W_{update1} Conbine + b_{update1})$
- $O_u = Tanh(W_{update2} Conbine + b_{update2})$
- $c_{t} = c_{tmp} + Z_u * O_u$
更新门
中共存在两组线性转化
的参数:$(W_{update1}, b_{update1})$以及$(W_{update2}, b_{update2})$,各自的结果转化之后相乘,再加到$c_{tmp}$上,最后得到当前时间步将要输出的cell state
。
Output Gate
输出门
的作用在于结合各种信息,在当前的时间步做出决策,同时为下一个时间步的计算提供信息。
除去上述的两组,unit
中还剩下一组门运算,即输出门
。它的输入包括$h_{t-1}, X_t, c_{t}$,最后的输出为$h_t$:
- 第一步依旧是拼接$X_t$与$h_{t-1}$,得到$[h_{t-1}, X_{t}]$;
- 通过$W_{output}$与$b_{output}$进行线性转换,再通过
sigmoid
将计算结果转换到[0, 1]
区间; - 将$c_t$通过
tanh
运算; - 最后将2、3步的输出做
point-wise
的乘法,得到Output Gate
的输出,也就是$h_t$。
同样,下面是Output Gate
的计算公式:
- $Conbine = Concatenation(h_{t-1}, X_t)$
- $Z_o = Sigmoid(W_{output} Conbine + b_{output})$
- $O_o = Tanh(c_t)$
- $h_t = Z_o * O_o$
总结
到此,一个LSTM unit
中的计算流程已经走完,下面我想谈一下个人对于其中几点的理解。首先是出现了两种激活函数sigmoid
以及tanh
。其中,sigmoid
共出现了3次,分别在三个门中都出现;tanh
则是出现在Update Gate
以及Output Gate
中。
sigmoid
能将数值缩放至[0, 1]区间,它在LSTM unit
中的使用都是与其他向量做point-wise
的乘法,所以我认为使用它主要是为了使unit
在处理时能够将信息进行保存或者遗忘。
tanh
能将数值缩放至[-1, 1]区间,它在LSTM unit
中被使用时基本上是为了将信息
转化为数值
,在Update Gate
中,该数值会被记录到cell state
中,在Output Gate
中,该数值会被转化为下一个时间步的hidden state
。