Basically how Torch does it in the background: optimizing the operations.
Alex first gave the example of a transistor grid: a 1000-nanometer square where each cell represents a 20-nanometer transistor. He then went over base-2 representation of numbers, with the Sign | Exponent | Significand layout and the different precisions.
The IEEE standard is half-precision float (16 bits): (1 | 5 | 10). We can represent values in the range of roughly $\pm 65504$.
Brain half-precision float (bfloat16): preserve the 32-bit dynamic range, reduce the fraction/significand precision. The bit budget is shifted from significand to exponent: (1 | 8 | 7). We can represent a much wider range of magnitudes, but with less precision.
With Torch, it’s better to use Brain half precision float.
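A quick way to see the trade-off in PyTorch (a minimal sketch; the `finfo` numbers are what PyTorch reports, and the autocast usage is just one common pattern):

```python
import torch

# Compare the two 16-bit formats: float16 (1|5|10) vs bfloat16 (1|8|7).
print(torch.finfo(torch.float16))   # max ~65504, finer resolution
print(torch.finfo(torch.bfloat16))  # max ~3.4e38 (float32 range), coarser resolution

# Mixed precision with autocast: compute-heavy ops run in bfloat16.
x = torch.randn(1024, 1024)
w = torch.randn(1024, 1024)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = x @ w
print(y.dtype)  # torch.bfloat16
```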
Hardware Improvements
Here he started with parallelism on the GPU: NVIDIA Streaming Multiprocessors (SMs), where each SM is a computational unit in itself.
L1 cache (i.e., on-chip SRAM) — responsible for fast (low-latency) memory access close to the compute units
L2 cache — shared memory accessible by all SMs; higher latency because it sits physically farther from the compute units.
DRAM (High Bandwidth Memory) is not on the GPU die: it is separate silicon connected by wires to the GPU die. This is basically the memory slots we have.
Attention in these terms:
- Read $Q, K$ from DRAM; compute $S = QK^\top$; write $S$ back to DRAM
- Read $S$ from DRAM; compute $P = \text{softmax}(S)$; write $P$ back to DRAM
- Read $P, V$ from DRAM; compute $O = PV$; write $O$ back to DRAM
Which is not very efficient, no? It's sequential, and every intermediate makes a round-trip through slow DRAM. That's why we use L1 to keep the work on-chip and parallelize it. But L1 caches are very small, only some kB. So we must optimize the process.
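To make the round-trips concrete, here is a minimal sketch of that naive attention in PyTorch; the names $S$, $P$, $O$ follow the list above, everything else (shapes, scaling) is my own illustration:

```python
import torch

def naive_attention(Q, K, V):
    # Each intermediate (S, P) is materialized as a full (seq, seq) tensor,
    # which on a GPU means a round-trip through DRAM/HBM.
    S = Q @ K.transpose(-2, -1) / K.shape[-1] ** 0.5  # compute S, write it out
    P = torch.softmax(S, dim=-1)                      # read S, write P
    O = P @ V                                         # read P and V, write O
    return O

Q = K = V = torch.randn(8, 1024, 64)  # (heads, seq, d_head)
O = naive_attention(Q, K, V)
```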
TILING
Let’s say we have 2 matrices and we want to multiply them (a big grid of dot products). In total we have $O(n^3)$ complexity. Tiling is about selecting smaller segments (blocks) and computing those. See the slides, it’s quite smart.
The memory-access cost now decreases to roughly $O(n^3/b)$ loads, since we load $b \times b$ blocks from memory once and reuse them.
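A minimal sketch of tiled matrix multiplication (block size and shapes are illustrative assumptions): the accumulator tile plays the role of fast on-chip memory, and each block of $A$ and $B$ is loaded once per output tile.

```python
import torch

def tiled_matmul(A, B, b=64):
    # C = A @ B computed block by block: each (b, b) tile of A and B is
    # "loaded" once per output tile instead of streaming full rows/columns.
    n = A.shape[0]
    C = torch.zeros(n, n)
    for i in range(0, n, b):
        for j in range(0, n, b):
            acc = torch.zeros(b, b)  # accumulator tile: lives in fast memory
            for k in range(0, n, b):
                acc += A[i:i+b, k:k+b] @ B[k:k+b, j:j+b]  # one tile load each
            C[i:i+b, j:j+b] = acc
    return C

A, B = torch.randn(256, 256), torch.randn(256, 256)
assert torch.allclose(tiled_matmul(A, B), A @ B, atol=1e-3)
```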
So for attention as before, $S = QK^\top$ and $O = PV$ will be computed with tiling, but for the softmax we will need the sum of the individual row, which no single tile contains.
Softmax, $\text{softmax}(x)_i = e^{x_i} / \sum_j e^{x_j}$, is not numerically stable in binary floating point (large $x_i$ overflow the exponential). The fix is stable softmax, which is removing the max of the row from the row sequence: $x_i - \max_j x_j$. You shift the values to be non-positive, the max will be 0, and by applying the exponential, we scale into a much smaller range than before, $(0, 1]$. He lost me at the running max $m_i$ and the running denominator $\ell_i$. It’s about optimizing the second loop of the softmax (the normalization sum). Finally it comes down to a recursive formulation (online softmax). Apparently you can also include the computation of $O$ in these loops.
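Here is my reconstruction of the $m_i$/$\ell_i$ recursion as a one-pass "online softmax" sketch (this is the standard formulation; the variable names are assumptions):

```python
import torch

def online_softmax(x):
    # Numerically stable softmax in a single pass, keeping a running max m
    # and a running denominator l; when the max grows, rescale the old sum.
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    for xi in x:
        m_new = torch.maximum(m, xi)
        l = l * torch.exp(m - m_new) + torch.exp(xi - m_new)  # rescale + add
        m = m_new
    return torch.exp(x - m) / l

x = torch.randn(16) * 50  # large values would overflow a naive exp in fp16
assert torch.allclose(online_softmax(x), torch.softmax(x, dim=0), atol=1e-6)
```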
There is also FlashAttention, which I see gets the job done in only one loop, fusing the max, the sum, and the accumulation of $O$.
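In PyTorch this is exposed as `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a fused FlashAttention-style kernel when one is available, so $S$ and $P$ are never materialized in DRAM:

```python
import torch
import torch.nn.functional as F

Q = K = V = torch.randn(1, 8, 1024, 64)  # (batch, heads, seq, d_head)
O = F.scaled_dot_product_attention(Q, K, V)  # fused kernel when available
```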
KV caching is storing the keys and values in cache memory so they are not recomputed for every generated token. Required memory:
$2 \times n_{\text{layers}} \times n_{\text{heads}} \times d_{\text{head}} \times \text{seq\_len} \times \text{bytes per element}$ (the 2 is for keys and values). This is where experts make the difference. They understand these things when designing neural networks. Of the entire formula, the one thing we can control easily is the number of heads, i.e., Multi-Query Attention (MQA), which means using only one head for the keys and the values: all query heads share a single K/V head.
With MHA we might need 4 MB for the KV cache, and with MQA we need only 31 kB (a ~128× reduction).
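A small calculator for the formula above (my numbers are hypothetical, not the slide's; the point is that the reduction factor is exactly the number of K/V heads):

```python
def kv_cache_bytes(n_layers, seq_len, n_kv_heads, d_head, bytes_per_elem=2):
    # 2x for keys AND values; bytes_per_elem=2 assumes bfloat16/float16.
    return 2 * n_layers * seq_len * n_kv_heads * d_head * bytes_per_elem

mha = kv_cache_bytes(n_layers=1, seq_len=2048, n_kv_heads=32, d_head=128)
mqa = kv_cache_bytes(n_layers=1, seq_len=2048, n_kv_heads=1, d_head=128)
print(mha / 2**20, "MiB vs", mqa / 2**10, "KiB")  # reduction factor = 32 heads
```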
Can we find a better balance?
Yes: share each K/V head across a group of query heads, i.e., Grouped-Query Attention (GQA). Very popular.
Try to understand Multi-Head Latent Attention (MLA), where you break the process in two: compress the keys/values into a low-rank latent, then up-project everything.
See the slide with the comparison of what each technique stores (kB vs MB). MLA even works better than MHA, so this is the most efficient way of doing it.
MLA is the new norm, not MHA. It is both faster and better in results, and has been for the last 1–2 years.
State matrix $S_t$: $S_t = S_{t-1} + v_t k_t^\top$, with output $o_t = S_t q_t$. Again, recursion. Insert the slide he has on the computations. Again, smart. This is Linear Attention, which replaces Softmax Attention.
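A minimal sketch of that recurrent form (plain unnormalized linear attention; names and shapes are my own):

```python
import torch

def linear_attention(Q, K, V):
    # Recurrent form: S_t = S_{t-1} + v_t k_t^T, output o_t = S_t q_t.
    # Normalization terms are omitted for clarity.
    seq, d = Q.shape
    S = torch.zeros(d, d)  # state matrix: constant size, independent of seq
    O = torch.zeros(seq, d)
    for t in range(seq):
        S = S + torch.outer(V[t], K[t])  # rank-1 state update
        O[t] = S @ Q[t]
    return O

Q = K = V = torch.randn(128, 64)
O = linear_attention(Q, K, V)
```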
But linearity brings challenges. There’s no real parallelism possible, since each state depends on the previous one as the keys and values come in. There’s also a lot of cost in updating the states (high I/O cost for state updates).
Chunk-wise parallelism: instead of computing every single state, we can compute only the states at chunk boundaries. The first chunk could go from $S_0$ to $S_3$, computing the in-between contributions in parallel: the whole chunk collapses into one matrix multiplication, $\sum_t v_t k_t^\top = V_{\text{chunk}}^\top K_{\text{chunk}}$, as in the sketch below. Hopefully it makes sense. And then we go from 3 to 6, and so on.
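Sketch of the chunk-boundary computation (chunk size and shapes are assumptions):

```python
import torch

def chunked_states(K, V, chunk=4):
    # Jump from S_0 to S_4 to S_8...: each chunk's contribution
    # sum_t v_t k_t^T is a single matmul V_c^T @ K_c, done in parallel
    # instead of one rank-1 update per step.
    seq, d = K.shape
    S = torch.zeros(d, d)
    states = []
    for c in range(0, seq, chunk):
        S = S + V[c:c+chunk].T @ K[c:c+chunk]  # whole chunk in one matmul
        states.append(S.clone())               # states at chunk boundaries only
    return states

K, V = torch.randn(16, 8), torch.randn(16, 8)
states = chunked_states(K, V)
```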
Next he talks about moving the key vector close to the value vector: treat the state as a memory that should map $k_t \mapsto v_t$, define the loss $\mathcal{L}_t(S) = \frac{1}{2}\lVert S k_t - v_t \rVert^2$, and apply an SGD update at step $t$. See the delta update rule for linear attention; that’s how he formulates it.
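Sketch of one delta-rule step, derived as an SGD update on that loss (the learning rate $\beta$ is an assumption):

```python
import torch

def delta_rule_step(S, k, v, beta=0.5):
    # grad_S of 1/2 * ||S k - v||^2 is (S k - v) k^T, so one SGD step gives
    # S <- S - beta * (S k - v) k^T = S (I - beta k k^T) + beta v k^T,
    # i.e., the delta update rule for linear attention.
    pred = S @ k  # the "value" the current state retrieves for key k
    return S - beta * torch.outer(pred - v, k)

S = torch.zeros(8, 8)
k, v = torch.randn(8), torch.randn(8)
S = delta_rule_step(S, k, v)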
RMSNorm
Look again over this and understand the concept: dividing by the root-mean-square constrains the activations to a sphere of radius $\sqrt{d}$ (a unit sphere, up to scale).
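A minimal RMSNorm sketch; the print at the end shows every output row landing at norm $\approx \sqrt{d}$ (with the gain $g$ at its initial value of 1):

```python
import torch

class RMSNorm(torch.nn.Module):
    # x / RMS(x) has norm sqrt(d) for every input, so all activations land
    # on the same sphere; the learned gain g then rescales each coordinate.
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.g = torch.nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.g

x = torch.randn(4, 512)
print(RMSNorm(512)(x).norm(dim=-1))  # ~sqrt(512) ≈ 22.6 for every row
```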
See OLMoE to understand the experts explanation he gave. From 2024 onward (GPT 3.5 nano), everything is a mixture of experts.