03/25/2024-03/31/2024
Self-Attention Does Not Need O(N^2) Memory
Key Idea
- Split Q, K, V along batch_dim, seq_dim, and head_dim, so that each query, key, and value is a 1-D vector of shape [feature_dim].
- For a single query, q @ k_i produces a scalar score s_i, and exp(s_i) * v_i produces a vector contribution to the output.
- Delay the softmax: accumulate running sums of exp(s_i) and exp(s_i) * v_i over all key/value vectors, and normalize only after the last one.
- Correspondingly, self-attention needs only [batch_dim, seq_dim, feature_dim] memory per head instead of [batch_dim, seq_dim, seq_dim]; when seq_dim >> feature_dim, this is a massive memory saving.
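A minimal NumPy sketch of this lazy-softmax accumulation for a single query and a single head; the function name, shapes, and the running-max rescaling (a standard numerical-stability trick) are illustrative details rather than the paper's exact code.

```python
import numpy as np

def memory_efficient_attention(q, K, V):
    """Attention output for one query vector q ([d]) against keys K and
    values V (each [n, d]), processed one key/value pair at a time so the
    [n, n] score matrix is never materialized."""
    num = np.zeros(V.shape[1])   # running sum of exp(s_i) * v_i
    den = 0.0                    # running sum of exp(s_i)
    m = -np.inf                  # running max score, for numerical stability
    for k_i, v_i in zip(K, V):
        s_i = float(q @ k_i)     # scalar score for this key
        m_new = max(m, s_i)
        rescale = np.exp(m - m_new) if np.isfinite(m) else 0.0
        num = num * rescale + np.exp(s_i - m_new) * v_i
        den = den * rescale + np.exp(s_i - m_new)
        m = m_new
    return num / den             # delayed softmax normalization
```

On random inputs this matches softmax(q @ K.T) @ V up to floating-point error; the paper's practical variant processes keys in chunks (roughly sqrt(n) at a time) rather than strictly one by one.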
Mamba: Linear-Time Sequence Modeling with Selective State Spaces
State Space Model
- Discretization.
- Transforms the continuous parameters (Δ, A, B) into discrete parameters (Ā, B̄).
- Computation (see the sketch at the end of this section).
- Global convolution for parallelizable training.
- Linear recurrence for autoregressive inference.
- Linear Time Invariant (LTI)
- Model dynamics are constant throughout time.
- This constraint is the key limitation of prior SSMs.
- Maps a 1-D sequence into an implicit latent state.
- Successful models in this line of work
- Linear Attention -> H3 -> Hyena -> RetNet -> RWKV
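A toy NumPy sketch of the discretization and the two computation modes above, reduced to a scalar state and a single channel; the function names, the zero-order-hold formula written for a scalar A, and the example parameter values are assumptions for illustration.

```python
import numpy as np

def discretize(delta, A, B):
    """Zero-order-hold discretization for a scalar state:
    A_bar = exp(delta * A), B_bar = (exp(delta * A) - 1) / A * B."""
    A_bar = np.exp(delta * A)
    B_bar = (A_bar - 1.0) / A * B
    return A_bar, B_bar

def ssm_recurrent(u, A_bar, B_bar, C):
    """Linear recurrence h_t = A_bar * h_{t-1} + B_bar * u_t, y_t = C * h_t.
    Sequential; this is the mode used for autoregressive inference."""
    h, ys = 0.0, []
    for u_t in u:
        h = A_bar * h + B_bar * u_t
        ys.append(C * h)
    return np.array(ys)

def ssm_convolution(u, A_bar, B_bar, C):
    """The same map computed as a global causal convolution with kernel
    K = (C*B_bar, C*A_bar*B_bar, C*A_bar^2*B_bar, ...); parallelizable for training."""
    L = len(u)
    K = C * B_bar * A_bar ** np.arange(L)
    return np.convolve(u, K)[:L]

u = np.random.randn(16)
A_bar, B_bar = discretize(delta=0.1, A=-1.0, B=1.0)
# Both computation modes give the same output because the dynamics are LTI.
assert np.allclose(ssm_recurrent(u, A_bar, B_bar, C=0.5),
                   ssm_convolution(u, A_bar, B_bar, C=0.5))
```

The final assertion holds precisely because the model is LTI: a single fixed kernel summarizes the dynamics for every time step.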
Selective State Space Model
- Algorithm
- LTI models (computed as RNNs/CNNs) are limited by their constant, input-independent dynamics.
- -> To improve performance, the model should be selective about what is written into the state.
- -> Let the parameters that affect interactions along the sequence (Δ, B, C) be input-dependent.
- Implementation
- Materialize the expanded hidden state only in fast SRAM (hardware-aware scan), instead of writing it to HBM.
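A toy NumPy sketch of selection, again with a scalar state per channel; the input-dependent Δ, B, C are produced here by hypothetical scalar weights (W_delta, W_B, W_C) standing in for the paper's linear projections, and the hardware-aware SRAM kernel itself is not modeled.

```python
import numpy as np

def selective_ssm(u, A, W_delta, W_B, W_C):
    """Selective SSM over a 1-D input u: delta_t, B_t, C_t depend on u_t,
    so the dynamics change every step, the global-convolution trick no
    longer applies, and the model must be computed as a recurrence (scan)."""
    h, ys = 0.0, []
    for u_t in u:
        # Selection: parameters are (toy) functions of the current input.
        delta_t = np.log1p(np.exp(W_delta * u_t))  # softplus keeps the step size positive
        B_t = W_B * u_t
        C_t = W_C * u_t
        # Discretize with the input-dependent step size, then recur.
        A_bar = np.exp(delta_t * A)
        B_bar = (A_bar - 1.0) / A * B_t
        h = A_bar * h + B_bar * u_t
        ys.append(C_t * h)
    return np.array(ys)

y = selective_ssm(np.random.randn(16), A=-1.0, W_delta=0.5, W_B=1.0, W_C=1.0)
```

Because Ā_t and B̄_t now change at every step, no fixed convolution kernel exists and the model must be computed as a scan; the paper's kernel fuses this scan and keeps the expanded hidden state in SRAM so it is never written out to HBM.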