稀疏DNC(一)

稀疏访存技术

CYQ October 28, 2019 views 3702 words
\[\newcommand{\bm}[1]{\boldsymbol{#1}}\]

DeepMind又水文章啦!

在提出DNC之后不久,发现效率太低的这帮人紧跟着对NTM,DNC中的一些共性操作做了稀疏化。单个操作速度快了\(1000\times\),使用的内存少了\(3000\times\)。新模型发表在2016年的NIPS1上。加速的最大意义是使得模型可以处理更大的,贴近现实的内存大小,从而使得记忆增强神经网络(MANN)技术接近实用。

The drastic efficiency and performance improvements reported in the present paper are likely to mean that the memory augmented neural networks are now ready to be used in practice.

Reviewer 5

不过这篇文章正文只讲了稀疏访存技术,而没有讲这些技术是如何直接运用在DNC上的。正文中的实验是将稀疏访存与稠密访存、NTM仅在速度上做了对比。具体的Sparse-DNC结构放在了补充材料里,下一篇文章会进行介绍。

稀疏访存

稀疏访存(Sparse Access Memory),或称为SAM的提出是为了解决DNC中访存开销过大的问题。

原始的DNC以及更早的NTM都存在一个问题,即:为了使得将通常意义下的访存这一离散的操作连续化,每次访存实际上访问了整个内存的内容。这在内存较小时尚可无视,但是复杂的任务往往需要更大的存储,这一开销就变得不可忽视。当内存条目数\(N=64000\)时,占用的物理内存达到了29GiB,这显然是不可接受的。

基于此,本文的正文提取了这些模型中共有的一部分稠密访存操作,将这些操作进行了稀疏化,并且比较了速度和内存占用。

为了方便起见,以下所有的稀疏化之后的变量都带有~标志,如稀疏化的\(w\)即为\(\tilde{w}\)

读的稀疏化

读操作中一个重要步骤是根据权重得到一个内存个条目的加权和。

\[r=\sum_{i=1}^Nw(i)\bm{M}(i)\]

为了尽可能降低计算量,我们只取权值最大的\(K\)个内存条目\(s_1,\dots,s_k\)。而将其它的所有分量全部置为0。

\[\tilde{r}_t=\sum_{i=1}^K\tilde{w}_t^R(s_i)\bm{M}_t(s_i)\]

具体来说,为了使\(\tilde{r}_t\approx r_t\),简单的办法是遍历\(w_t^R\),将前\(K\)大的分量找出来。但遍历这一步的复杂度是\(O(N)\)。

注意到\(w_t^R\)实际上是根据某个query(如key vector) \(q_t^R\)与内存各条目的相似度过softmax(或别的某种单调函数)后得到。换言之,我们可以从前一步相似度的计算就开始入手,转化为寻找\(q_t^R\)的\(K-\)最近邻问题。这个问题有成熟的一系列算法可以解决:近似最近邻数据结构(approximate nearest neighbor data-structure)。使用之后复杂度至多为\(O(\log(N))\)。

除此之外,还使用了矩阵的行稀疏表示方法:压缩稀疏行(Compressed Sparse Rows, CSR)使得上式的空间时间复杂度都降到常数级别。

写的稀疏化

写操作相较读操作要更复杂,除了同样使用了相似度寻址,还引入了内存的释放等机制。这个机制涉及维护每个内存位置的usage信息,基于该信息找到被使用最少或者空闲的内存位置进行写入。

写入操作可以简单地分为三步:

  • 决定写入的位置
  • 擦除内存中的原有内容
  • 写入新内容
\[\bm{M}_t\leftarrow(\bm{1}-\bm{R}_t)\odot\bm{M}_{t-1}+\bm{A}_t\]

这里因为是为了介绍稀疏化的技术,没有采用DNC或者NTM中的权值生成方式,而是直接使用上一时刻的读权重。

寻址机制

设\(t\)时刻记录各位置usage的向量为\(U_t\),那么写权重可以表达为下式

\[w_t^W=\alpha_t(\gamma_tw_{t-1}^R+(1-\gamma_t)\mathbb{I}_t^U)\]

根据读权重\(w_{t-1}^R\)的稀疏性,式子中的这一项可以稀疏表示。

上式中\(\mathbb{I}_t^U\)不记录所有的usage,而是只考虑usage最小的内存位置,变成一个one-hot向量。

\[\mathbb{I}_t^U(i)= \begin{cases} 1 & \text{if}\quad U_t(i)=\min_{j=1,\dots,N}U_t(j)\\ 0 & \text{otherwise} \end{cases}\]

也就是说,写操作的目标可能是

  • 上一时刻的读位置
  • 空闲的位置

两者的加权和。这样做存在一个明显问题:直接写在上一时刻的读位置显然不合理,除非读过的位置内容立马失效了。如何能够知道刚读过的位置里面的内容是否失效了呢?这就需要usage的信息。

usage

\(U_t\)的计算尝试了不同于DNC的,两种新的usage计算机制。

第一种需要稠密访存(DAM),使用较易理解的时间折扣(time-discounted)机制

\[U_T^{(1)}(i)=\sum_{t=0}^T\lambda^{T-t}(w_t^W(i)+w^R_t(i))\]

其中\(\lambda\)是折扣因子。这个式子的含义是,如果一个位置很久没有读操作没有写操作就应该是空闲的。相比原始DNC中读的强度越大,usage反而越低的设定,我倒是觉得更合理。

第二种是稀疏访存(SAM)的:

\[U_T^{(2)}(i)=T-\max\{t:w^W_t(i)+w^R_t(i)>\delta\}\]

考虑的是最近的一次显著读或者写究竟过去了多久

稀疏化之后,完整的写操作无论是forward还是backward都是\(O(1)\)时间和空间复杂度,证明仍然被放在了补充材料里。

稀疏数据结构

近似最近邻数据结构

读操作中的近似最近邻数据结构(approximate nearest neighbor data-structure, ANN)尝试了FLANN实现的随机k-d树算法2局部敏感哈希(LSH) 。对于较小的字长使用前者,较大的使用后者。这两种ANN每次插入、删除、查询的复杂度都是\(O(\log N)\)。为了避免不平衡的数据结构,每\(N\)次插入之后会重建ANN。

压缩稀疏行

CSR的思路非常简单,用一个val数组从左到右,从上到下顺序存储所有的非零元素就可以了。

为了能够重建矩阵,我们当然还要存储非零元素的位置。因此用一个与val等长的数组col_index记录这些非零元的列数。并且用一个指针数组 row_ptr 记录每行第一个非零元在val中的位置。 row_ptr 最后一个元素为val的长度加1。

\[\bm{A}= \begin{bmatrix} 4 & 0 & 5\\ 0 & 0 & 11\\ 0 & 12 & 0 \end{bmatrix}\]

对应的CSR为

val col_index row_ptr
4 1 1
5 3 3
11 3 4
12 2 5

设矩阵有\(N\)行,非零元个数为\(n_{nz}\)。占用空间大小为\(2n_{nz}+N+1\)。

参考文献

  1. Rae, J., Hunt, J. J., Danihelka, I., Harley, T., Senior, A. W., Wayne, G., … & Lillicrap, T. (2016). Scaling memory-augmented neural networks with sparse reads and writes. In Advances in Neural Information Processing Systems (pp. 3621-3629). 

  2. Muja, M., & Lowe, D. G. (2014). Scalable nearest neighbor algorithms for high dimensional data. IEEE transactions on pattern analysis and machine intelligence, 36(11), 2227-2240.