
jax 1

rooflines

We need to do 1e12 FLOPs, and the GPU can manage 9.89e14 FLOPs/s

Next, say the problem needs to process 1e9 bytes, and the GPU can process 1e10 bytes/s

Now for the program, we can calculate the FLOPs/byte (arithmetic intensity):

Finally, we can compare this to our computation and memory speeds:

Lastly, we can compare the arithmetic intensity directly:
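Plugging in the numbers above:

$$T_{\text{math}} = \frac{10^{12}}{9.89 \times 10^{14}} \approx 1\ \text{ms}, \qquad T_{\text{comms}} = \frac{10^{9}}{10^{10}} = 100\ \text{ms}$$

$$\text{Intensity}_{\text{program}} = \frac{10^{12}}{10^{9}} = 1000\ \text{FLOPs/byte} \;\ll\; \frac{9.89 \times 10^{14}}{10^{10}} \approx 9.9 \times 10^{4}\ \text{FLOPs/byte} = \text{Intensity}_{\text{accelerator}}$$

Either comparison says the same thing: T_{comms} dominates, so this program is memory-bound.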

example

dot product:

x * y: bf16[N], bf16[N] -> bf16[1]

Need to load 2 * 2 bytes per element, for a total of 4 * N bytes.

And we need 2 * N FLOPs (N multiplies and N adds).

So arithmetic intensity is 2N / 4N = 0.5 FLOPs/byte.

Given that the peak arithmetic intensity of the TPU v5e MXU is about 240 FLOPs/byte, our program is memory-bound.
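A minimal JAX sketch of the same kernel (the array size is arbitrary):

```python
import jax.numpy as jnp

N = 1_000_000
x = jnp.ones((N,), dtype=jnp.bfloat16)
y = jnp.ones((N,), dtype=jnp.bfloat16)

# Loads 4N bytes (two bf16 arrays) and does ~2N FLOPs (N multiplies, N adds),
# so arithmetic intensity is about 2N / 4N = 0.5 FLOPs/byte -- far below ~240,
# hence memory-bound.
result = jnp.dot(x, y)
```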

Intuition:

Another idea:

matrix multiplication

X * Y -> Z: bf16[B, D], bf16[D, F] -> bf16[B, F]

Load: 2BD bytes for X and 2DF bytes for Y; write back 2BF bytes for Z.

Each element of Z is the dot product of a row of X and a column of Y, which is 2D FLOPs; with BF output elements, that's 2BDF FLOPs total.

Arithmetic intensity: 2BDF / (2BD + 2DF + 2BF) = BDF / (BD + DF + BF) ≈ B when B ≪ D, F.

So if B > 240, then the program is compute-bound.
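The same bookkeeping as a small helper (the helper and the shapes below are made up for illustration; 240 FLOPs/byte is the hardware intensity used above):

```python
def matmul_roofline(B, D, F, bytes_per_elem=2, hw_intensity=240):
    """Rough roofline check for bf16[B, D] @ bf16[D, F] -> bf16[B, F]."""
    flops = 2 * B * D * F                                   # 2D FLOPs per output element, B*F outputs
    bytes_moved = bytes_per_elem * (B * D + D * F + B * F)  # load X, load Y, store Z
    intensity = flops / bytes_moved
    return intensity, "compute-bound" if intensity > hw_intensity else "memory-bound"

print(matmul_roofline(B=128, D=4096, F=4096))  # intensity ~121 -> memory-bound
print(matmul_roofline(B=512, D=4096, F=4096))  # intensity ~410 -> compute-bound
```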

shard matrix multiplication

X * Y -> Z: bf16[B, D], bf16[D, F] -> bf16[B, F]

We split along the D dimension: the columns of X and the rows of Y:

The output on each device is a partial sum of the full output matrix, so we still need to copy the partial sum over to the other device and add them together.

Let’s find T_{math}: each device now does 2 · B · (D/2) · F = BDF FLOPs, so T_{math} = BDF / C, where C is the per-device FLOPs/s.

For T_{comms}: we have to move the 2BF bytes of one partial sum over the interconnect, so T_{comms} = 2BF / W_{ici}, where W_{ici} is the interconnect bandwidth.

Arithmetic intensity of the network for the TPU: C / W_{ici}, i.e. how many FLOPs the chip can do per byte moved over the interconnect.

Compute bound if T_{comms} < T_{math}: 2BF / W_{ici} < BDF / C, i.e. D > 2C / W_{ici}.

Now we see that the sharded matmul is compute bound based on D instead of B.
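A minimal JAX sketch of the idea; the arrays are split by hand on one host rather than placed on a real two-device mesh, just to show the partial-sum structure (shapes are made up):

```python
import jax.numpy as jnp

B, D, F = 8, 1024, 512
x = jnp.ones((B, D), dtype=jnp.bfloat16)
y = jnp.ones((D, F), dtype=jnp.bfloat16)

# Shard the contraction dimension D: device 0 gets the first half of X's
# columns and Y's rows, device 1 gets the second half.
x0, x1 = jnp.split(x, 2, axis=1)
y0, y1 = jnp.split(y, 2, axis=0)

partial0 = x0 @ y0   # bf16[B, F] partial sum on "device 0"
partial1 = x1 @ y1   # bf16[B, F] partial sum on "device 1"

# The communication step from the text: move one partial sum across the link
# (2BF bytes) and add. In real sharded code this add is a collective such as psum.
z = partial0 + partial1
```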

matrix dimensions

Let’s take a step back: what is B, the batch size?

Each row is a sample. If we were doing image classification, each row would be an image.

D, then, is the number of features per sample. For a raw image, D would be all the pixels.

LLMs are a bit different. Take this for example:

"Hello world!" → 
["Hello", "world", "!"] → 
[457, 1234, 89] → 
Three vectors of dimension D

So for LLMs, we might actually be looking at something like this, where S is the sequence length (the second line is the batch-size-1 case with the batch dimension dropped):

X * Y -> Z: bf16[B, S, D], bf16[D, F] -> bf16[B, S, F]
X' * Y' -> Z': bf16[S, D], bf16[D, F] -> bf16[S, F]
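One way to see the shapes' roles: the matmul treats every token as an independent row, so B and S fold into one effective batch of B·S tokens. A tiny check (shapes made up):

```python
import jax.numpy as jnp

B, S, D, F = 2, 128, 1024, 4096
x = jnp.ones((B, S, D), dtype=jnp.bfloat16)
w = jnp.ones((D, F), dtype=jnp.bfloat16)

# [B, S, D] @ [D, F] is the same computation as flattening (B, S) into one
# token dimension of size B*S and doing an ordinary [B*S, D] @ [D, F] matmul.
z  = jnp.einsum("bsd,df->bsf", x, w)
z2 = (x.reshape(B * S, D) @ w).reshape(B, S, F)
assert jnp.allclose(z, z2)
```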

Now we can see that when sharding efficiently, we actually want a larger embedding dimension to become compute bound.

This is intriguing: larger D means we can fully utilize our hardware and it theoretically gives better model performance. That’s a cool convergence!

Why is it compute bound on D? Well, the math time depends on BDF while the network time only depends on BF. So as D increases, the computation time grows, but the network time stays fixed.
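Writing that out with the quantities from the sharded example above (C is per-device FLOPs/s, W_{ici} the interconnect bandwidth):

$$\frac{T_{\text{math}}}{T_{\text{comms}}} = \frac{BDF / C}{2BF / W_{\text{ici}}} = \frac{D \, W_{\text{ici}}}{2C},$$

which grows linearly in D and does not involve B at all, so a large enough D makes the sharded matmul compute-bound.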

q1

A[B, D] * B[D, F] -> C[B, F]

with int8 precision, 1 byte per parameter.

(1)

How many bytes need to be loaded from memory? How many need to be written back to memory?

Loaded: BD + DF bytes (1 byte per int8 element). Written back: BF bytes.

(2)

How many total OPs are performed?

Technically these aren't floating-point ops at all: int8 is an integer format, so we're really counting integer OPs rather than FLOPs, which is why the question says OPs.

Each output matrix element is 2D operations. There are BF output elements, so 2BDF.

(3)

What is the arithmetic intensity?

The intensity is 2BDF / (BD + DF + BF); using the “B small” approximation (B ≪ D, F), this is ≈ 2BDF / DF = 2B.

(4)

Then we are compute bound if 2B > 486 (the v5e's int8 OPs-per-byte intensity), i.e. B > 243.

Ah, misunderstood this a bit. How can we estimate T_{math}?

So if we want to estimate the total time:
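One way to do it, assuming a TPU v5e with roughly 3.94e14 int8 OPs/s and 8.1e11 bytes/s of HBM bandwidth (these specific hardware numbers are my assumption; they are consistent with the 486 OPs/byte threshold above, since 3.94e14 / 8.1e11 ≈ 486):

$$T_{\text{math}} = \frac{2BDF}{3.94 \times 10^{14}}, \qquad T_{\text{comms}} = \frac{BD + DF + BF}{8.1 \times 10^{11}}, \qquad T \approx \max(T_{\text{math}},\, T_{\text{comms}}),$$

where the max assumes compute and memory traffic overlap well; otherwise their sum is an upper bound.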

q2

bfloat16[B, D] * int8[D, F] -> bfloat16[B, F]

The int8 weights are there for memory efficiency (e.g. during training), but the op itself is done in bfloat16.

Bytes: 2BD + DF + 2BF (bf16 activations loaded and written, int8 weights loaded).

Ops: 2BDF bfloat16 FLOPs.

Arithmetic intensity: 2BDF / (2BD + DF + 2BF).

Using the “B small” approximation (B ≪ D, F), this is ≈ 2BDF / DF = 2B.

So compute bound if 2B > 243, i.e. roughly B > 120.

Okay, so let’s think this example through. When we use int8 weights, we are compute bound when B > 120 instead of B > 240.

So what does this mean if we have a program where B = 200? With bfloat16 weights, our program would be comms bound, with T_{comms} > T_{math}.

With int8 weights, we are now compute bound, since B = 200 > 120. T_{math} did not change, since we have to up-convert int8 to bfloat16 to compute, but T_{comms} went down. That means the overall execution time went down too.
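Concretely, for B = 200 with B ≪ D, F, the memory traffic is dominated by the weights (W_{hbm} is HBM bandwidth):

$$T_{\text{comms}}^{\text{bf16 weights}} \approx \frac{2DF}{W_{\text{hbm}}}, \qquad T_{\text{comms}}^{\text{int8 weights}} \approx \frac{DF}{W_{\text{hbm}}},$$

so quantizing the weights roughly halves T_{comms}, while T_{math} = 2BDF / C_{bf16} stays the same.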

q3

But doesn’t the roofline model only depend on B?

q4

int8[B, D]_D * int8[B, D, F]_D -> int8[B, F]

Load and store: BD + BDF + BF

For OPs: if we think about the first row of int8[B,D], it's a 1xD vector, and we are multiplying it by a [D,F] matrix, the first "column" of [B, D, F]. That gives an output of a 1xF vector.

Now we do that for B-1 more rows, each with a unique [D,F] matrix, and that gives us a [B,F] output.

The difference is that instead of taking the 1xD input row from [B,D] and doing a dot product with a Dx1 column to get a scalar, we use a whole matrix to transform it into a 1xF vector.

Instead of mapping each input row to a single number, we map it to a different vector. Another way to think about it is that we can encode more information in the transformation: instead of D weights, there are DF weights.
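For concreteness, a tiny sketch of that per-row transform (shapes made up; the cast to int32 is just so the contraction runs cleanly, the dtype isn't the point here):

```python
import jax.numpy as jnp

B, D, F = 4, 16, 8
x = jnp.ones((B, D), dtype=jnp.int8)
w = jnp.ones((B, D, F), dtype=jnp.int8)

# In an ordinary matmul every row of x would share one [D, F] matrix.
# Here row b gets its own weight matrix w[b], so there are D*F weights per row.
z = jnp.einsum("bd,bdf->bf", x.astype(jnp.int32), w.astype(jnp.int32))  # -> [B, F]
```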

So before I do the math, my instinct is that this will be compute bound at a much lower number, since we are doing many more operations per byte.

In particular, for each row: a 1xD by [D,F] matmul is 2DF OPs, and there are B rows, so 2BDF OPs total.

Ah, actually the OPs are exactly the same! That actually makes sense thinking about it a little more, because we are still doing the same number of operations per output element. In other words, the compute-bound threshold should then be higher, because we have to load more bytes even though the ops are the same.

Compute bound if:

2BDF / (BD + BDF + BF) = 2DF / (D + DF + F) > 486

This makes sense, because now most of the data is in the per-element [D, F] weight matrices.

From solution:

Since it’s constant, and 2 is nowhere near 486, we are always comms bound. Our TPU would have to have basically the same peak OPs/s as bytes/s to be compute bound, which is impossible.
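A quick way to see that the intensity is essentially constant:

$$\frac{2BDF}{BD + BDF + BF} = \frac{2}{\frac{1}{F} + 1 + \frac{1}{D}} \;\longrightarrow\; 2 \quad \text{as } D, F \to \infty.$$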

q5

Let’s look at fp16 Ops, since that’s what we were using:

Arithmetic intensity:

So we are compute bound if B > 214.

For SXM: