A Beginner's Guide to Tensor Parallelism in Autoregressive LLM Inference

1. Motivation

The cost of running an autoregressive large language model is dominated by dense linear algebra in two parts of every decoder block: the self-attention sublayer and the MLP sublayer. As model size, context length, and batch size increase, a single GPU may not have enough memory bandwidth, compute throughput, or device memory to execute the model efficiently.

Tensor parallelism partitions large projection matrices across a tensor-parallel group of \(p\) GPUs. Each rank computes a shard of the output, and collective communication is needed at two kinds of points: where partial outputs from different ranks must be summed, and where keys and values held on other ranks must be gathered so that attention can see the full context.

This article studies tensor parallelism through tensor shapes, matrix products, and communication volumes.

2. Notation

| Symbol | Definition |
| --- | --- |
| \(B\) | Global batch size: number of independent autoregressive sequences. |
| \(B_q\) | Batch size of activation shard \(q\), typically \(B_q=B/p\) for batch-sharded activations. |
| \(T\) | Context length or sequence length. |
| \(T_q\) | Query length. In decode, \(T_q=1\). |
| \(H\) | Residual-stream hidden dimension. |
| \(I\) | MLP intermediate dimension. |
| \(L\) | Number of decoder blocks. |
| \(\ell\) | Decoder-block index. |
| \(p\) | Tensor-parallel degree: number of ranks in one tensor-parallel group. |
| \(r\) | Tensor-parallel rank index, \(r\in\{1,\dots,p\}\). |
| \(q\) | Activation shard index. The shard may be along the batch dimension or sequence dimension. |
| \(i,j,k\) | Matrix-entry indices. |
| \(b\) | Batch index. |
| \(t,u\) | Token-position indices. |
| \(a\) | Attention-head index. |
| \(h\) | Number of query heads. |
| \(h_{kv}\) | Number of key/value heads. |
| \(d\) | Per-head dimension. |
| \(H_q\) | Total query projection dimension, \(H_q=hd\). |
| \(H_{kv}\) | Total key/value projection dimension, \(H_{kv}=h_{kv}d\). |
| \(s\) | Bytes per scalar element. For BF16 or FP16, \(s=2\). |
| \(X_\ell\) | Residual-stream activation at the input of decoder block \(\ell\). |
| \(\widetilde{X}_\ell\) | Residual stream after the self-attention sublayer of decoder block \(\ell\). |
| \(Q,K,V\) | Query, key, and value tensors. |
| \(S\) | Attention score tensor before softmax. |
| \(P\) | Attention probability tensor after softmax. |
| \(O\) | Self-attention output before the output projection. |
| \(Y\) | Self-attention output after the output projection. |
| \(A\) | MLP intermediate activation. |
| \(Z\) | MLP output after the down projection. |
| \(W_Q,W_K,W_V\) | Query, key, and value projection matrices. |
| \(W_O\) | Self-attention output projection matrix. |
| \(W_g,W_u\) | MLP gate and up projection matrices. |
| \(W_d\) | MLP down projection matrix. |
| \(\phi\) | Elementwise activation function. |
| \(\odot\) | Elementwise multiplication. |
| \(M_{\mathrm{causal}}\) | Causal mask for autoregressive self-attention. |
| \(M_H\) | Number of elements in a full hidden activation tensor, \(M_H=BTH\). |
| \(M_{KV}\) | Number of elements in one full key or value tensor, \(M_{KV}=BTH_{kv}\). |

3. Linear Algebra and Communication Primitives

3.1 Matrix Multiplication

For

\[ A\in\mathbb{R}^{m\times n}, \qquad B\in\mathbb{R}^{n\times k}, \]

the product \(C=AB\) satisfies

\[ C\in\mathbb{R}^{m\times k}, \qquad C_{ij}=\sum_{k'=1}^{n}A_{ik'}B_{k'j}. \]

3.2 Elementwise Operations

Elementwise multiplication is

\[ (A\odot B)_{ij}=A_{ij}B_{ij}. \]

For an elementwise nonlinearity \(\phi\),

\[ (\phi(A))_{ij}=\phi(A_{ij}). \]

3.3 Matrix Partitioning

Column-wise partitioning across \(p\) tensor-parallel ranks is written as

\[ W=[W_1,W_2,\dots,W_p]. \]

Row-wise partitioning is written as

\[ W= \begin{bmatrix} W_1\\ W_2\\ \vdots\\ W_p \end{bmatrix}. \]
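
The two partitionings behave differently inside a matrix product: column shards of \(W\) each yield a column block of \(XW\), while row shards of \(W\) pair with column blocks of \(X\) and each yield a full-shape partial sum. The following is a minimal NumPy sketch of that fact; the sizes, the random data, and all variable names are illustrative assumptions, not part of the article's setup.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 4                              # illustrative tensor-parallel degree
X = rng.standard_normal((3, 8))    # illustrative input
W = rng.standard_normal((8, 12))   # illustrative weight

# Column-wise partition W = [W_1, ..., W_p]:
# each shard produces a column block of XW, and the blocks are concatenated.
col_shards = np.split(W, p, axis=1)
Y_col = np.concatenate([X @ W_r for W_r in col_shards], axis=1)

# Row-wise partition W = [W_1; ...; W_p]:
# the input is split along its columns to match, and each shard
# produces a full-shape partial sum of XW.
row_shards = np.split(W, p, axis=0)
X_cols = np.split(X, p, axis=1)
Y_row = sum(X_r @ W_r for X_r, W_r in zip(X_cols, row_shards))

col_ok = np.allclose(Y_col, X @ W)
row_ok = np.allclose(Y_row, X @ W)
print(col_ok, row_ok)
```

Concatenation needs no arithmetic across ranks, whereas the row-wise case needs a sum across ranks. That difference is why the collectives in Section 3.5 appear only after row-partitioned products.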

3.4 Causal Mask

The causal mask prevents position \(t\) from attending to future position \(u>t\):

\[ M_{\mathrm{causal}}(t,u)= \begin{cases} 0, & u\leq t,\\ -\infty, & u>t. \end{cases} \]
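
As a small illustration, the mask can be built directly from the two position indices. This is a sketch in NumPy; the function name `causal_mask` is a hypothetical helper introduced here.

```python
import numpy as np

def causal_mask(T):
    """Return a (T, T) array that is 0 where u <= t and -inf where u > t."""
    t = np.arange(T)[:, None]   # query positions, one per row
    u = np.arange(T)[None, :]   # key positions, one per column
    return np.where(u <= t, 0.0, -np.inf)

M = causal_mask(4)
print(M)
```

Adding this array to the attention scores leaves allowed entries unchanged and drives future entries to \(-\infty\), so they receive zero probability after softmax.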

3.5 Collectives

| Collective | Definition | Per-rank communication volume in bytes |
| --- | --- | --- |
| AllReduce | Each rank starts with \(X_r\). All ranks receive \(X=\sum_{r=1}^{p}X_r\). | \(2\frac{p-1}{p}Ns\) |
| ReduceScatter | Each rank starts with \(X_r\). The sum \(X=\sum_{r=1}^{p}X_r\) is formed, and rank \(q\) receives one shard \(X^{(q)}\) of it. | \(\frac{p-1}{p}Ns\) |
| AllGather | Each rank starts with shard \(X^{(q)}\). All ranks receive \(X=[X^{(1)},X^{(2)},\dots,X^{(p)}]\). | \(\frac{p-1}{p}Ns\) |

Here \(N\) is the number of scalar elements in the full tensor involved in the collective. For hidden activations, \(N=M_H=BTH\). For one key or one value tensor, \(N=M_{KV}=BTH_{kv}\).
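
The semantics of the three collectives can be checked in a single process by letting a Python list stand in for the ranks. The sketch below is only a simulation of what each rank ends up holding; it is not a communication library, and the function names are hypothetical.

```python
import numpy as np

def all_reduce(xs):
    """Every rank receives the elementwise sum of all per-rank inputs."""
    total = sum(xs)
    return [total.copy() for _ in xs]

def reduce_scatter(xs, axis=0):
    """Sum all per-rank inputs; rank q keeps only shard q of the sum."""
    total = sum(xs)
    return np.split(total, len(xs), axis=axis)

def all_gather(shards, axis=0):
    """Every rank receives the concatenation of all shards."""
    full = np.concatenate(shards, axis=axis)
    return [full.copy() for _ in shards]

rng = np.random.default_rng(0)
p = 4
inputs = [rng.standard_normal((8, 6)) for _ in range(p)]  # X_r on each rank

reduced = all_reduce(inputs)
two_step = all_gather(reduce_scatter(inputs))
same = all(np.allclose(a, b) for a, b in zip(reduced, two_step))
print(same)
```

The check shows that AllReduce delivers the same result as ReduceScatter followed by AllGather, which matches the table: the AllReduce volume is the sum of the other two.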

3.6 Communication Model

Communication costs are derived assuming bandwidth-optimal implementations of the collectives, such as ring-based algorithms. Under that assumption, AllGather and ReduceScatter each communicate \((p-1)/p\) times the full message size per rank, and AllReduce, which can be realized as a ReduceScatter followed by an AllGather, communicates twice that amount.

Let \(N\) denote the number of scalar elements in the full tensor involved in the collective, as in Section 3.5, and let \(s\) denote the number of bytes per element.

The communicated bytes per rank are defined as:

\[ C_{AG}(N,p) = \frac{p-1}{p}Ns \]

\[ C_{RS}(N,p) = \frac{p-1}{p}Ns \]

\[ C_{AR}(N,p) = 2\frac{p-1}{p}Ns \]

where \(C_{AG}\), \(C_{RS}\), and \(C_{AR}\) denote the communicated bytes per rank for AllGather, ReduceScatter, and AllReduce respectively.
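
The three formulas translate directly into code. The sketch below evaluates them for a hypothetical decode step; the configuration values (\(B=32\), \(H=8192\), \(p=8\), BF16) are assumptions chosen only to make the arithmetic concrete.

```python
def c_all_gather(n, p, s=2):
    """Bytes communicated per rank for an AllGather of an n-element tensor."""
    return (p - 1) / p * n * s

def c_reduce_scatter(n, p, s=2):
    """Bytes communicated per rank for a ReduceScatter of an n-element tensor."""
    return (p - 1) / p * n * s

def c_all_reduce(n, p, s=2):
    """Bytes communicated per rank for an AllReduce of an n-element tensor."""
    return 2 * (p - 1) / p * n * s

# Hypothetical decode step: B=32 sequences, T_q=1 token each, H=8192, BF16 (s=2), p=8.
n = 32 * 1 * 8192
print(c_all_reduce(n, 8))      # 917504.0
print(c_reduce_scatter(n, 8))  # 458752.0
```

Because the factor \((p-1)/p\) approaches 1 as \(p\) grows, the per-rank volume is nearly independent of the tensor-parallel degree; it is set mainly by the tensor size \(N\) and the element width \(s\).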

4. Autoregressive Inference Phases

4.1 Prefill

In prefill, the full prompt is processed. The residual-stream activation entering a decoder block is

\[ X_\ell\in\mathbb{R}^{B\times T\times H}. \]

The residual activation memory is

\[ \mathrm{Mem}_{\mathrm{act}}=BTHs. \]

For one query head, the attention score tensor has shape

\[ S\in\mathbb{R}^{B\times T\times T}. \]

The prefill self-attention cost scales as

\[ O(BT^2dh). \]

4.2 Decode

In decode, the model generates one new token per sequence. The current query length is

\[ T_q=1. \]

The current-step residual activation is

\[ X_{\ell,\mathrm{decode}}\in\mathbb{R}^{B\times 1\times H}. \]

The current-step activation memory is

\[ \mathrm{Mem}_{\mathrm{decode-act}}=BHs. \]

The KV cache for one decoder block is

\[ K,V\in\mathbb{R}^{B\times T\times H_{kv}}, \]

with memory

\[ \mathrm{Mem}_{\mathrm{KV,block}}=2BTH_{kv}s. \]

Across \(L\) decoder blocks, the KV-cache memory is

\[ \mathrm{Mem}_{\mathrm{KV,total}}=2LBTH_{kv}s. \]

Prefill is sensitive to long-sequence activation and attention compute. Decode has small current-token activations, but KV-cache memory grows with \(B\), \(T\), \(H_{kv}\), and \(L\).
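
To make the contrast concrete, the memory formulas of this section can be evaluated for one configuration. The numbers below (\(L=32\), \(B=16\), \(T=4096\), \(H=4096\), \(h_{kv}=8\), \(d=128\), BF16) are hypothetical and do not describe any particular model.

```python
def activation_bytes(B, T_q, H, s=2):
    """Residual-stream activation bytes for B sequences of T_q query tokens."""
    return B * T_q * H * s

def kv_cache_bytes(L, B, T, H_kv, s=2):
    """KV-cache bytes: keys and values for every block, sequence, and position."""
    return 2 * L * B * T * H_kv * s

# Hypothetical configuration, BF16 (s=2).
L, B, T, H = 32, 16, 4096, 4096
h_kv, d = 8, 128
H_kv = h_kv * d

prefill_act = activation_bytes(B, T, H)      # T_q = T in prefill
decode_act = activation_bytes(B, 1, H)       # T_q = 1 in decode
kv_total = kv_cache_bytes(L, B, T, H_kv)

print(prefill_act)  # 536870912  (512 MiB)
print(decode_act)   # 131072     (128 KiB)
print(kv_total)     # 8589934592 (8 GiB)
```

In this example the decode activation is a factor of \(T\) smaller than the prefill activation, while the KV cache is the largest of the three quantities.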

5. Decoder Block

A pre-norm decoder block updates the residual stream as

\[ \widetilde{X}_\ell = X_\ell+ \mathrm{SelfAttention}_\ell(\mathrm{RMSNorm}(X_\ell)), \]

\[ X_{\ell+1} = \widetilde{X}_\ell+ \mathrm{MLP}_\ell(\mathrm{RMSNorm}(\widetilde{X}_\ell)). \]

5.1 Self-Attention Sublayer

For normalized input \(X\),

\[ Q=XW_Q,\qquad K=XW_K,\qquad V=XW_V. \]

The masked attention scores are

\[ S=\frac{QK^\top}{\sqrt d}+M_{\mathrm{causal}}. \]

The attention probabilities and attention output are

\[ P=\mathrm{softmax}(S), \qquad O=PV. \]

The self-attention output projection is

\[ Y=OW_O. \]
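
The equations of this subsection can be traced step by step for a single head and a single sequence. The NumPy sketch below makes those simplifying assumptions; the helper names and the random test data are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(X, W_Q, W_K, W_V, W_O):
    """Single-head causal self-attention for X of shape (T, H)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    T, d = Q.shape
    mask = np.where(np.tril(np.ones((T, T), dtype=bool)), 0.0, -np.inf)
    S = Q @ K.T / np.sqrt(d) + mask   # masked scores
    P = softmax(S)                    # attention probabilities
    O = P @ V                         # attention output
    return O @ W_O, P                 # projected output Y, and P for inspection

rng = np.random.default_rng(0)
T, H, d = 4, 8, 8
X = rng.standard_normal((T, H))
W_Q, W_K, W_V = (rng.standard_normal((H, d)) for _ in range(3))
W_O = rng.standard_normal((d, H))

Y, P = causal_attention(X, W_Q, W_K, W_V, W_O)
print(Y.shape)          # (4, 8)
print(np.triu(P, k=1))  # strictly upper triangle is all zeros
```

Each row of \(P\) sums to one, and every entry above the diagonal is exactly zero because the mask sets those scores to \(-\infty\) before the softmax.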

5.2 MLP Sublayer

For a gated MLP,

\[ A=\phi(XW_g)\odot(XW_u). \]

The down projection is

\[ Z=AW_d. \]

The matrix shapes are

\[ W_g,W_u\in\mathbb{R}^{H\times I}, \qquad W_d\in\mathbb{R}^{I\times H}. \]

6. Classic Tensor Parallelism

Classic tensor parallelism shards projection matrices while keeping the residual-stream activation replicated across all tensor-parallel ranks:

\[ X_r=X. \]

6.1 Self-Attention in Classic TP

The QKV projections are column-parallel:

\[ W_Q=[W_{Q,1},\dots,W_{Q,p}], \quad W_K=[W_{K,1},\dots,W_{K,p}], \quad W_V=[W_{V,1},\dots,W_{V,p}]. \]

Rank \(r\) computes

\[ Q_r=XW_{Q,r}, \qquad K_r=XW_{K,r}, \qquad V_r=XW_{V,r}. \]

Each rank owns a subset of heads, so the attention computation is local to that head shard:

\[ S_r=\frac{Q_rK_r^\top}{\sqrt d}+M_{\mathrm{causal}}, \]

\[ P_r=\mathrm{softmax}(S_r), \qquad O_r=P_rV_r. \]

The output projection is row-parallel:

\[ W_O= \begin{bmatrix} W_{O,1}\\ \vdots\\ W_{O,p} \end{bmatrix}. \]

Rank \(r\) computes a partial projected output:

\[ Y_r=O_rW_{O,r}. \]

The full output is

\[ Y=\sum_{r=1}^{p}Y_r. \]

Classic TP materializes \(Y\) on every TP rank using AllReduce:

\[ Y=\mathrm{AllReduce}_r(Y_r). \]
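
The claim that head-sharded attention followed by a sum reproduces the unsharded result can be checked numerically. The sketch below simulates the ranks with a loop in one process and assumes a single sequence, \(h=h_{kv}\), and heads stored contiguously in the columns of the projection matrices; these are assumptions of the example, and the helper names are hypothetical.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_heads(X, W_Q, W_K, W_V, d):
    """Causal attention for the heads packed in W_Q, W_K, W_V.

    X has shape (T, H); the result has shape (T, n_heads * d).
    """
    T = X.shape[0]
    n_heads = W_Q.shape[1] // d
    Q = (X @ W_Q).reshape(T, n_heads, d).transpose(1, 0, 2)
    K = (X @ W_K).reshape(T, n_heads, d).transpose(1, 0, 2)
    V = (X @ W_V).reshape(T, n_heads, d).transpose(1, 0, 2)
    mask = np.where(np.tril(np.ones((T, T), dtype=bool)), 0.0, -np.inf)
    P = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d) + mask)
    return (P @ V).transpose(1, 0, 2).reshape(T, n_heads * d)

rng = np.random.default_rng(0)
T, H, h, d, p = 5, 16, 4, 4, 2   # illustrative sizes; each rank owns h/p = 2 heads
X = rng.standard_normal((T, H))
W_Q, W_K, W_V = (rng.standard_normal((H, h * d)) for _ in range(3))
W_O = rng.standard_normal((h * d, H))

# Unsharded reference.
Y_ref = attention_heads(X, W_Q, W_K, W_V, d) @ W_O

# Classic TP: column shards of W_Q, W_K, W_V hold whole heads; W_O is row-sharded.
Y_partials = []
for W_Qr, W_Kr, W_Vr, W_Or in zip(np.split(W_Q, p, axis=1),
                                  np.split(W_K, p, axis=1),
                                  np.split(W_V, p, axis=1),
                                  np.split(W_O, p, axis=0)):
    O_r = attention_heads(X, W_Qr, W_Kr, W_Vr, d)  # local to the head shard
    Y_partials.append(O_r @ W_Or)                  # partial output Y_r

Y_tp = sum(Y_partials)   # the value AllReduce would deliver to every rank
attn_matches = np.allclose(Y_tp, Y_ref)
print(attn_matches)
```

The only cross-rank arithmetic is the final sum, which is the single AllReduce of the attention sublayer.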

6.2 MLP in Classic TP

The gate and up projections are column-parallel:

\[ W_g=[W_{g,1},\dots,W_{g,p}], \qquad W_u=[W_{u,1},\dots,W_{u,p}]. \]

Rank \(r\) computes

\[ A_r=\phi(XW_{g,r})\odot(XW_{u,r}). \]

The down projection is row-parallel:

\[ W_d= \begin{bmatrix} W_{d,1}\\ \vdots\\ W_{d,p} \end{bmatrix}. \]

Rank \(r\) computes

\[ Z_r=A_rW_{d,r}. \]

The full MLP output is

\[ Z=\sum_{r=1}^{p}Z_r = \mathrm{AllReduce}_r(Z_r). \]
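
The same check works for the MLP. In the sketch below the ranks are simulated with a loop, and \(\phi\) is taken to be SiLU purely as an assumption of the example, since the article leaves \(\phi\) unspecified.

```python
import numpy as np

def silu(x):
    """SiLU activation, used here as an assumed choice for phi."""
    return x / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
T, H, I, p = 5, 8, 32, 4   # illustrative sizes
X = rng.standard_normal((T, H))
W_g = rng.standard_normal((H, I))
W_u = rng.standard_normal((H, I))
W_d = rng.standard_normal((I, H))

# Unsharded reference: Z = (phi(X W_g) * (X W_u)) W_d.
Z_ref = (silu(X @ W_g) * (X @ W_u)) @ W_d

# Classic TP: gate/up are column-sharded, down is row-sharded.
Z_partials = []
for W_gr, W_ur, W_dr in zip(np.split(W_g, p, axis=1),
                            np.split(W_u, p, axis=1),
                            np.split(W_d, p, axis=0)):
    A_r = silu(X @ W_gr) * (X @ W_ur)   # elementwise ops stay inside the shard
    Z_partials.append(A_r @ W_dr)       # partial output Z_r

Z_tp = sum(Z_partials)   # the value AllReduce would deliver to every rank
mlp_matches = np.allclose(Z_tp, Z_ref)
print(mlp_matches)
```

The column-then-row ordering matters: because \(\phi\) and \(\odot\) act elementwise, they can be applied to each column shard independently, so no collective is needed between the up and down projections.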

7. Tensor Parallelism with Sharded Activations

Sharded-activation TP keeps the same weight partitioning as classic TP but avoids materializing the full residual activation on every rank after row-parallel projections. AllReduce is replaced by ReduceScatter.

After the self-attention output projection:

\[ Y^{(q)}=\mathrm{ReduceScatter}_r(Y_r). \]

After the MLP down projection:

\[ Z^{(q)}=\mathrm{ReduceScatter}_r(Z_r). \]

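The next decoder block consumes the sharded residual stream. Because the residual addition is elementwise, it is applied shard by shard:

\[ X_{\ell+1}^{(q)}=\widetilde{X}_\ell^{(q)}+Z_\ell^{(q)}. \]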

The activation shard \(q\) can be defined along the batch dimension or the sequence dimension. These two cases have different attention communication.

7.1 Batch-Sharded Activations

With batch-sharded activations,

\[ X^{(q)}\in\mathbb{R}^{B_q\times T\times H}, \qquad B_q=\frac{B}{p}. \]

Each shard contains complete sequences. Therefore, QKV tensors are local to those sequences:

\[ Q_r^{(q)}=X^{(q)}W_{Q,r}, \qquad K_r^{(q)}=X^{(q)}W_{K,r}, \qquad V_r^{(q)}=X^{(q)}W_{V,r}. \]

Self-attention requires no cross-rank K/V exchange:

\[ O_r^{(q)} = \mathrm{softmax} \left( \frac{Q_r^{(q)}K_r^{(q)\top}}{\sqrt d} + M_{\mathrm{causal}} \right)V_r^{(q)}. \]

Row-parallel projections still require ReduceScatter:

\[ Y^{(q)}=\mathrm{ReduceScatter}_r(O_r^{(q)}W_{O,r}), \]

\[ Z^{(q)}=\mathrm{ReduceScatter}_r(A_r^{(q)}W_{d,r}). \]

7.2 Sequence-Sharded Activations

With sequence-sharded activations,

\[ X^{(q)}\in\mathbb{R}^{B\times T/p\times H}. \]

Each shard contains only a contiguous subset of tokens from the same sequence. Local QKV projections are

\[ Q_r^{(q)}=X^{(q)}W_{Q,r}, \qquad K_r^{(q)}=X^{(q)}W_{K,r}, \qquad V_r^{(q)}=X^{(q)}W_{V,r}. \]

A local query shard must attend over the complete context. Hence keys and values must be made visible across sequence shards:

\[ K_r=\mathrm{AllGather}_{\mathrm{seq}}(K_r^{(q)}), \qquad V_r=\mathrm{AllGather}_{\mathrm{seq}}(V_r^{(q)}). \]

The local output rows are then computed as

\[ O_r^{(q)} = \mathrm{softmax} \left( \frac{Q_r^{(q)}K_r^\top}{\sqrt d} + M_{\mathrm{causal}}^{(q)} \right)V_r. \]

Queries are not gathered because shard \(q\) produces only the output rows for its local query tokens.

The row-parallel projections are again completed with ReduceScatter:

\[ Y^{(q)}=\mathrm{ReduceScatter}_r(O_r^{(q)}W_{O,r}), \]

\[ Z^{(q)}=\mathrm{ReduceScatter}_r(A_r^{(q)}W_{d,r}). \]
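
The attention step of the sequence-sharded case can be verified numerically: each shard keeps its own query rows, uses the gathered keys and values, and applies only its own rows of the causal mask. The sketch below assumes a single head and a single sequence, and simulates the AllGather with a concatenation in one process; all names and sizes are illustrative.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T, H, d, p = 8, 16, 4, 4   # illustrative sizes; each shard holds T/p = 2 tokens
X = rng.standard_normal((T, H))
W_Q, W_K, W_V = (rng.standard_normal((H, d)) for _ in range(3))

# Unsharded reference.
full_mask = np.where(np.tril(np.ones((T, T), dtype=bool)), 0.0, -np.inf)
O_ref = softmax((X @ W_Q) @ (X @ W_K).T / np.sqrt(d) + full_mask) @ (X @ W_V)

# Each shard projects only its own contiguous tokens.
X_shards = np.split(X, p, axis=0)
K_local = [X_q @ W_K for X_q in X_shards]
V_local = [X_q @ W_V for X_q in X_shards]

# Simulated AllGather along the sequence dimension.
K_full = np.concatenate(K_local, axis=0)
V_full = np.concatenate(V_local, axis=0)

O_shards = []
for q, X_q in enumerate(X_shards):
    Q_q = X_q @ W_Q                              # queries are never gathered
    rows = slice(q * T // p, (q + 1) * T // p)
    mask_q = full_mask[rows]                     # M_causal^(q), shape (T/p, T)
    O_shards.append(softmax(Q_q @ K_full.T / np.sqrt(d) + mask_q) @ V_full)

seq_matches = np.allclose(np.concatenate(O_shards, axis=0), O_ref)
print(seq_matches)
```

The shard-local mask \(M_{\mathrm{causal}}^{(q)}\) is simply the block of rows of the full mask that belongs to the shard's query positions, so later shards attend to more of the gathered context than earlier ones.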

8. Memory Analysis

| Execution model | Residual-stream activation per GPU (bytes) | KV cache per GPU per decoder block (bytes) | Attention-context exchange |
| --- | --- | --- | --- |
| Classic TP | \(BTHs\) | \(\frac{2BTH_{kv}s}{p}\) | No |
| Sharded-activation TP, batch shard | \(\frac{BTHs}{p}\) | \(\frac{2BTH_{kv}s}{p}\) | No |
| Sharded-activation TP, sequence shard | \(\frac{BTHs}{p}\) | \(\frac{2BTH_{kv}s}{p}\) | Yes, K/V exchange |
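
The table entries can be evaluated for a concrete case. The sketch below uses a hypothetical configuration (\(B=16\), \(T=4096\), \(H=4096\), \(H_{kv}=1024\), \(p=8\), BF16); the function name and dictionary keys are introduced here for illustration only.

```python
def per_gpu_memory(B, T, H, H_kv, p, s=2):
    """Per-GPU bytes for one decoder block under the three execution models."""
    kv = 2 * B * T * H_kv * s / p          # same KV-cache entry in all three rows
    full_act = B * T * H * s               # replicated residual stream
    sharded_act = B * T * H * s / p        # residual stream split across p ranks
    return {
        "classic": {"activation": full_act, "kv_cache": kv},
        "batch_shard": {"activation": sharded_act, "kv_cache": kv},
        "sequence_shard": {"activation": sharded_act, "kv_cache": kv},
    }

# Hypothetical configuration, BF16 (s=2).
mem = per_gpu_memory(B=16, T=4096, H=4096, H_kv=1024, p=8)
for name, entry in mem.items():
    print(name, entry["activation"] / 2**20, "MiB act,",
          entry["kv_cache"] / 2**20, "MiB KV")
```

For this configuration the replicated activation is 512 MiB per GPU under classic TP and 64 MiB per GPU under either sharded-activation variant, while the KV-cache entry is 32 MiB per GPU per block in all three cases.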

9. Communication Analysis

| Component | Classic TP | Batch-sharded activation TP | Sequence-sharded activation TP |
| --- | --- | --- | --- |
| QKV projections | Column-parallel, no collective | Column-parallel, no collective | Column-parallel, no collective |
| Self-attention core | No collective | No collective | \(\mathrm{AllGather}(K,V)\), \(2\frac{p-1}{p}M_{KV}s\) bytes |
| Attention output projection | \(\mathrm{AllReduce}\), \(2\frac{p-1}{p}M_Hs\) bytes | \(\mathrm{ReduceScatter}\), \(\frac{p-1}{p}M_Hs\) bytes | \(\mathrm{ReduceScatter}\), \(\frac{p-1}{p}M_Hs\) bytes |
| MLP gate/up projections | Column-parallel, no collective | Column-parallel, no collective | Column-parallel, no collective |
| MLP down projection | \(\mathrm{AllReduce}\), \(2\frac{p-1}{p}M_Hs\) bytes | \(\mathrm{ReduceScatter}\), \(\frac{p-1}{p}M_Hs\) bytes | \(\mathrm{ReduceScatter}\), \(\frac{p-1}{p}M_Hs\) bytes |
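
Summing each column of the table gives the per-rank communication for one decoder block. The sketch below does this for a hypothetical configuration (\(B=16\), \(T=4096\), \(H=4096\), \(H_{kv}=1024\), \(p=8\), BF16); the function name and dictionary keys are assumptions of the example.

```python
def per_block_comm_bytes(B, T, H, H_kv, p, s=2):
    """Per-rank bytes communicated in one decoder block, summed from the table."""
    frac = (p - 1) / p
    M_H, M_KV = B * T * H, B * T * H_kv
    all_reduce = 2 * frac * M_H * s
    reduce_scatter = frac * M_H * s
    kv_gather = 2 * frac * M_KV * s   # one AllGather for K and one for V
    return {
        # attention output projection + MLP down projection
        "classic": 2 * all_reduce,
        "batch_shard": 2 * reduce_scatter,
        "sequence_shard": 2 * reduce_scatter + kv_gather,
    }

# Hypothetical configuration, BF16 (s=2).
comm = per_block_comm_bytes(B=16, T=4096, H=4096, H_kv=1024, p=8)
for name, nbytes in comm.items():
    print(name, nbytes / 2**20, "MiB")
```

Under the table's accounting, the sharded-activation variants communicate half as much as classic TP for the row-parallel projections, and the sequence-sharded variant adds a K/V term that scales with \(H_{kv}\) rather than \(H\).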

10. Summary

Classic tensor parallelism shards projection matrices but replicates residual-stream activations after row-parallel projections. Sharded-activation tensor parallelism replaces those AllReduce operations with ReduceScatter and keeps the residual stream partitioned across ranks. With batch-sharded activations, complete sequences remain local and self-attention requires no K/V exchange. With sequence-sharded activations, one context is split across ranks, so the self-attention sublayer requires global K/V visibility through AllGather or an equivalent distributed-attention algorithm.
