Question
Computational Complexity of Self-Attention in the Transformer Model
I recently went through the Transformer paper from Google Research describing how self-attention layers could completely replace traditional RNN-based sequence encoding layers for machine translation. In Table 1 of the paper, the authors compare the computational complexities of different sequence encoding layers, and state (later on) that self-attention layers are faster than RNN layers when the sequence length $n$ is smaller than the dimension of the vector representations $d$.
However, the self-attention layer seems to have a worse complexity than claimed, if my understanding of the computations is correct. Let $X$ be the input to a self-attention layer. Then $X$ has shape $(n, d)$, since there are $n$ word vectors (corresponding to rows), each of dimension $d$. Computing the output of self-attention requires the following steps (consider single-headed self-attention for simplicity):
- Linearly transforming the rows of $X$ to compute the query $Q$, key $K$, and value $V$ matrices, each of which has shape $(n, d)$. This is accomplished by post-multiplying $X$ with 3 learned matrices of shape $(d, d)$, amounting to a computational complexity of $O(n d^2)$.
- Computing the layer output, specified in Equation 1 of the paper as $\mathrm{softmax}\left(\frac{Q K^T}{\sqrt{d}}\right) V$, where the softmax is computed over each row. Computing $Q K^T$ has complexity $O(n^2 d)$, and post-multiplying the resultant with $V$ has complexity $O(n^2 d)$ as well (see the sketch after this list for both stages).
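To make the two stages concrete, here is a minimal NumPy sketch (my own, not code from the paper) that runs single-headed self-attention on a random $(n, d)$ input and tallies the dominant multiply-add counts of each stage; the function name and the rough FLOP bookkeeping are illustrative assumptions.

```python
import numpy as np

def self_attention_costs(n: int, d: int, seed: int = 0):
    """Single-headed self-attention on a random (n, d) input, returning
    the output together with rough multiply-add counts for each stage."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, d))
    W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

    # Stage 1: Q = X W_q, K = X W_k, V = X W_v.
    # Each is an (n, d) x (d, d) product, i.e. O(n d^2), and there are three of them.
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    projection_cost = 3 * n * d * d

    # Stage 2: softmax(Q K^T / sqrt(d)) V, with the softmax taken row-wise.
    # Q K^T is (n, d) x (d, n) -> O(n^2 d); the product with V is (n, n) x (n, d) -> O(n^2 d).
    scores = Q @ K.T / np.sqrt(d)
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    scores /= scores.sum(axis=-1, keepdims=True)
    output = scores @ V
    attention_cost = 2 * n * n * d

    return output, projection_cost, attention_cost

# With n = 50 and d = 512 (so n < d), the projection term is the larger one:
_, proj, attn = self_attention_costs(n=50, d=512)
print(f"projection term (3 n d^2): ~{proj:.2e} multiply-adds")
print(f"attention terms (2 n^2 d): ~{attn:.2e} multiply-adds")
```

For $n = 50$ and $d = 512$, the projection stage costs about $3.9 \times 10^7$ multiply-adds versus about $2.6 \times 10^6$ for the attention stage, i.e. exactly the $n d^2$ term I am asking about.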
Therefore, the total complexity of the layer is $O(n^2 d + n d^2)$, which is worse than that of a traditional RNN layer. I obtained the same result for multi-headed attention too, on considering the appropriate intermediate representation dimensions ($d_k$, $d_v$) and finally multiplying by the number of heads $h$.
Why have the authors ignored the cost of computing the Query, Key, and Value matrices while reporting total computational complexity?
I understand that the proposed layer is fully parallelizable across the $n$ positions, but I believe that Table 1 does not take this into account anyway.