Notes on Learning to Solve Routing [KHW-19]

Paper link; the paper is super clearly written...

Graph attention recap

For simplicity, we omit the layer index. Unless stated otherwise, the discussion pertains to one layer of a GNN.

Single-head attention. Let $h_i$ denote the embedding of a vertex $i$. Define $d_k$ and $d_v$ to be the dimensions of keys and values, respectively. Let $d_h$ denote the dimension of $h_i$.

In graph attention, the new embedding in the next layer is computed as follows:

$$h_i' = \sigma \left( \sum_{j \in N(i)} a_{ij} v_j \right)$$

where

  1. $N(i)$ is the set of neighbors of $i$; $\sigma$ is a nonlinearity.

  2. $a_{ij} \in [0, 1]$ is the attention weight.

  3. $v_j \in \mathbb{R}^{d_v}$ is the value of vertex $j$.
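
To make the aggregation concrete, here is a minimal NumPy sketch of this step, assuming the weights $a_{ij}$ and values $v_j$ for the neighbors of $i$ are already available. The function name `aggregate` and the choice of `tanh` for $\sigma$ are placeholders of mine, not from the paper.

```python
import numpy as np

def aggregate(a, V, sigma=np.tanh):
    """Weighted aggregation of neighbor values for one vertex i.

    a:     (n,)      attention weights a_ij over j in N(i), summing to 1.
    V:     (n, d_v)  value vectors v_j of the neighbors.
    sigma: elementwise nonlinearity (tanh is an arbitrary placeholder).
    Returns h_i' with shape (d_v,).
    """
    return sigma(a @ V)  # sum_j a_ij * v_j, then the nonlinearity

# e.g. 3 neighbors with d_v = 4
a = np.array([0.2, 0.5, 0.3])
V = np.random.randn(3, 4)
h_i_new = aggregate(a, V)
```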

Formally, $a_{ij}$ and $v_j$ are computed as follows. Let $W^Q$ ($d_k \times d_h$), $W^K$ ($d_k \times d_h$), and $W^V$ ($d_v \times d_h$) be the learnable parameters. For a vertex $i$, its query, key, and value are computed by projecting $h_i$:

$$q_i = W^Q h_i, \; k_i = W^K h_i, \; v_i = W^V h_i$$

For a neighbor $j$ of $i$, the weight $a_{ij}$ is the output of a softmax over the set $\{u_{ij'} : j' \in N(i)\}$, where the compatibility $u_{ij} \in \mathbb{R}$ is computed as follows:

$$u_{ij} = \frac{q_i^T k_j}{\sqrt{d_k}}$$
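
Putting the projections and the softmax together, here is a sketch of how the weights $a_{ij}$ for one vertex could be computed in NumPy (the function and variable names are mine, not the paper's). The values are obtained analogously as $v_j = W^V h_j$.

```python
import numpy as np

def attention_weights(h_i, H_nbrs, W_Q, W_K):
    """Attention weights a_ij of vertex i over its neighbors j in N(i).

    h_i:      (d_h,)      embedding of vertex i.
    H_nbrs:   (n, d_h)    embeddings of the n neighbors of i.
    W_Q, W_K: (d_k, d_h)  learnable projection matrices.
    Returns a: (n,) weights that sum to 1.
    """
    d_k = W_K.shape[0]
    q_i = W_Q @ h_i             # query of vertex i, shape (d_k,)
    K = H_nbrs @ W_K.T          # keys k_j of the neighbors, shape (n, d_k)
    u = K @ q_i / np.sqrt(d_k)  # compatibilities u_ij = q_i^T k_j / sqrt(d_k)
    u = u - u.max()             # shift for numerical stability
    e = np.exp(u)
    return e / e.sum()          # softmax over j' in N(i)
```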

Multi-head attention. Let $M$ be the number of heads. We compute $h_i'$ independently $M$ times, resulting in a set of $M$ embeddings; importantly, each head has its own set of learnable projection matrices. One typically sets $d_v = d_k = d_h / M$, concatenates the $M$ embeddings, and applies another learned projection. One can instead average the $M$ embeddings, e.g., when multi-head attention is used in the final prediction layer.
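
A sketch of the concatenation variant, reusing `attention_weights` from the previous sketch; the `heads` list, the output projection `W_O`, and the use of `tanh` for $\sigma$ are illustrative assumptions, not the paper's exact parameterization.

```python
import numpy as np

def multi_head(h_i, H_nbrs, heads, W_O, sigma=np.tanh):
    """Multi-head graph attention for one vertex, concatenation variant.

    heads: list of M tuples (W_Q, W_K, W_V); each head has its own matrices,
           with W_Q, W_K of shape (d_k, d_h) and W_V of shape (d_v, d_h),
           where d_k = d_v = d_h / M.
    W_O:   (d_h, M * d_v) projection applied after concatenating the heads.
    """
    per_head = []
    for W_Q, W_K, W_V in heads:
        a = attention_weights(h_i, H_nbrs, W_Q, W_K)  # from the sketch above
        V = H_nbrs @ W_V.T                            # values v_j, shape (n, d_v)
        per_head.append(sigma(a @ V))                 # this head's h_i'
    return W_O @ np.concatenate(per_head)             # concatenate, then project
    # Averaging variant (e.g. for a final prediction layer):
    #   np.mean(per_head, axis=0)
```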