Autoregressive large language models are dominated by dense linear algebra in two parts of every decoder block: the self-attention sublayer and the MLP sublayer. As model size, context length, and batch size increase, a single GPU may not have enough memory bandwidth, compute throughput, or device memory to execute the model efficiently.
Tensor parallelism partitions large projection matrices across a tensor-parallel group of \(p\) GPUs. Each rank computes a shard of the output, and collective communication is used wherever sharded partial outputs must be reduced or a sharded attention context must be made visible to other ranks.
The purpose of this article is to study tensor parallelism using tensor shapes, matrix products, and communication volumes.
| Symbol | Definition |
|---|---|
| \(B\) | Global batch size: number of independent autoregressive sequences. |
| \(B_q\) | Batch size of activation shard \(q\), typically \(B_q=B/p\) for batch-sharded activations. |
| \(T\) | Context length or sequence length. |
| \(T_q\) | Query length. In decode, \(T_q=1\). |
| \(H\) | Residual-stream hidden dimension. |
| \(I\) | MLP intermediate dimension. |
| \(L\) | Number of decoder blocks. |
| \(\ell\) | Decoder-block index. |
| \(p\) | Tensor-parallel degree: number of ranks in one tensor-parallel group. |
| \(r\) | Tensor-parallel rank index, \(r\in\{1,\dots,p\}\). |
| \(q\) | Activation shard index. The shard may be along the batch dimension or sequence dimension. |
| \(i,j,k\) | Matrix-entry indices. |
| \(b\) | Batch index. |
| \(t,u\) | Token-position indices. |
| \(a\) | Attention-head index. |
| \(h\) | Number of query heads. |
| \(h_{kv}\) | Number of key/value heads. |
| \(d\) | Per-head dimension. |
| \(H_q\) | Total query projection dimension, \(H_q=hd\). |
| \(H_{kv}\) | Total key/value projection dimension, \(H_{kv}=h_{kv}d\). |
| \(s\) | Bytes per scalar element. For BF16 or FP16, \(s=2\). |
| \(X_\ell\) | Residual-stream activation at the input of decoder block \(\ell\). |
| \(\widetilde{X}_\ell\) | Residual stream after the self-attention sublayer of decoder block \(\ell\). |
| \(Q,K,V\) | Query, key, and value tensors. |
| \(S\) | Attention score tensor before softmax. |
| \(P\) | Attention probability tensor after softmax. |
| \(O\) | Self-attention output before the output projection. |
| \(Y\) | Self-attention output after the output projection. |
| \(A\) | MLP intermediate activation. |
| \(Z\) | MLP output after the down projection. |
| \(W_Q,W_K,W_V\) | Query, key, and value projection matrices. |
| \(W_O\) | Self-attention output projection matrix. |
| \(W_g,W_u\) | MLP gate and up projection matrices. |
| \(W_d\) | MLP down projection matrix. |
| \(\phi\) | Elementwise activation function. |
| \(\odot\) | Elementwise multiplication. |
| \(M_{\mathrm{causal}}\) | Causal mask for autoregressive self-attention. |
| \(M_H\) | Number of elements in a full hidden activation tensor, \(M_H=BTH\). |
| \(M_{KV}\) | Number of elements in one full key or value tensor, \(M_{KV}=BTH_{kv}\). |
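For generic matrices (in this definition only, \(A\) and \(B\) denote arbitrary matrices, not the MLP activation or the batch size)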
\[ A\in\mathbb{R}^{m\times n}, \qquad B\in\mathbb{R}^{n\times k}, \]
the product \(C=AB\) satisfies
\[ C\in\mathbb{R}^{m\times k}, \qquad C_{ij}=\sum_{k'=1}^{n}A_{ik'}B_{k'j}. \]
Elementwise multiplication is
\[ (A\odot B)_{ij}=A_{ij}B_{ij}. \]
For an elementwise nonlinearity \(\phi\),
\[ (\phi(A))_{ij}=\phi(A_{ij}). \]
Column-wise partitioning across \(p\) tensor-parallel ranks is written as
\[ W=[W_1,W_2,\dots,W_p]. \]
Row-wise partitioning is written as
\[ W= \begin{bmatrix} W_1\\ W_2\\ \vdots\\ W_p \end{bmatrix}. \]
The causal mask prevents position \(t\) from attending to future position \(u>t\):
\[ M_{\mathrm{causal}}(t,u)= \begin{cases} 0, & u\leq t,\\ -\infty, & u>t. \end{cases} \]
| Collective | Definition | Per-rank communication volume in bytes |
|---|---|---|
| AllReduce | Each rank starts with \(X_r\). All ranks receive \(X=\sum_{r=1}^{p}X_r\). | \(2\frac{p-1}{p}Ns\) |
| ReduceScatter | First reduce \(X=\sum_{r=1}^{p}X_r\), then rank \(q\) receives one shard \(X^{(q)}\). | \(\frac{p-1}{p}Ns\) |
| AllGather | Each rank starts with shard \(X^{(q)}\). All ranks receive \(X=[X^{(1)},X^{(2)},\dots,X^{(p)}]\). | \(\frac{p-1}{p}Ns\) |
Here \(N\) is the number of scalar elements in the full tensor involved in the collective. For hidden activations, \(N=M_H=BTH\). For one key or one value tensor, \(N=M_{KV}=BTH_{kv}\).
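The semantics of the three collectives can be illustrated with a single-process simulation. The sketch below is an assumption-laden toy, not a communication library: each "rank" is a Python list, the helper names `all_reduce`, `reduce_scatter`, and `all_gather` are hypothetical, and \(N\) is assumed to be divisible by \(p\).

```python
def all_reduce(per_rank):
    """Every rank receives the elementwise sum of all ranks' vectors."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

def reduce_scatter(per_rank):
    """Rank q receives shard q of the elementwise sum."""
    p = len(per_rank)
    total = [sum(vals) for vals in zip(*per_rank)]
    shard = len(total) // p  # assumes N is divisible by p
    return [total[q * shard:(q + 1) * shard] for q in range(p)]

def all_gather(shards):
    """Every rank receives the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

# p = 2 ranks, each holding a partial result with N = 4 elements.
partials = [[1, 2, 3, 4], [10, 20, 30, 40]]
reduced = all_reduce(partials)
scattered = reduce_scatter(partials)
regathered = all_gather(scattered)

print(reduced[0])   # [11, 22, 33, 44]
print(scattered)    # [[11, 22], [33, 44]]
```

In this toy, `regathered` equals `reduced`, which mirrors the fact that an AllReduce can be composed from a ReduceScatter followed by an AllGather.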
Communication costs are derived assuming bandwidth-optimal implementations of the collective primitives. Specifically, ReduceScatter and AllGather are assumed to use ring-based algorithms (or equivalents such as recursive halving and recursive doubling) whose communicated volume per rank is \((p-1)/p\) times the full tensor size, and AllReduce is assumed to be composed of a ReduceScatter followed by an AllGather, which doubles that volume.
With \(N\) the number of scalar elements in the full tensor and \(s\) the number of bytes per element, the communicated bytes per rank are:
\[ C_{AG}(N,p) = \frac{p-1}{p}Ns \]
\[ C_{RS}(N,p) = \frac{p-1}{p}Ns \]
\[ C_{AR}(N,p) = 2\frac{p-1}{p}Ns \]
where \(C_{AG}\), \(C_{RS}\), and \(C_{AR}\) denote the communicated bytes per rank for AllGather, ReduceScatter, and AllReduce respectively.
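As a quick numeric check, the three formulas can be evaluated directly. The following is a minimal sketch; the function names and the workload (\(B=8\), \(T=2048\), \(H=4096\), BF16, \(p=8\)) are hypothetical values chosen only for illustration.

```python
from fractions import Fraction

def c_ag(n, p, s=2):
    """AllGather bytes per rank: (p - 1) / p * N * s."""
    return Fraction(p - 1, p) * n * s

def c_rs(n, p, s=2):
    """ReduceScatter bytes per rank: (p - 1) / p * N * s."""
    return Fraction(p - 1, p) * n * s

def c_ar(n, p, s=2):
    """AllReduce bytes per rank: 2 * (p - 1) / p * N * s."""
    return 2 * Fraction(p - 1, p) * n * s

# Hypothetical workload: B = 8, T = 2048, H = 4096, BF16 (s = 2), p = 8.
n_hidden = 8 * 2048 * 4096
for name, fn in [("AllGather", c_ag), ("ReduceScatter", c_rs), ("AllReduce", c_ar)]:
    print(f"{name}: {float(fn(n_hidden, 8)) / 2**20:.0f} MiB per rank")
```

For this assumed workload the sketch reports 112 MiB per rank for AllGather and ReduceScatter and 224 MiB per rank for AllReduce.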
In prefill, the full prompt is processed. The residual-stream activation entering a decoder block is
\[ X_\ell\in\mathbb{R}^{B\times T\times H}. \]
The residual activation memory is
\[ \mathrm{Mem}_{\mathrm{act}}=BTHs. \]
For one query head, the attention score tensor has shape
\[ S\in\mathbb{R}^{B\times T\times T}. \]
The prefill self-attention cost scales as
\[ O(BT^2dh). \]
In decode, the model generates one new token per sequence. The current query length is
\[ T_q=1. \]
The current-step residual activation is
\[ X_{\ell,\mathrm{decode}}\in\mathbb{R}^{B\times 1\times H}. \]
The current-step activation memory is
\[ \mathrm{Mem}_{\mathrm{decode\text{-}act}}=BHs. \]
The KV cache for one decoder block is
\[ K,V\in\mathbb{R}^{B\times T\times H_{kv}}, \]
with memory
\[ \mathrm{Mem}_{\mathrm{KV,block}}=2BTH_{kv}s. \]
Across \(L\) decoder blocks, the KV-cache memory is
\[ \mathrm{Mem}_{\mathrm{KV,total}}=2LBTH_{kv}s. \]
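The memory formulas above are easy to evaluate for a concrete configuration. The sketch below uses hypothetical function names and a hypothetical model (\(L=32\), \(B=8\), \(T=4096\), \(H=4096\), \(h_{kv}=8\), \(d=128\), BF16); none of these values come from a specific model.

```python
def activation_bytes(B, T, H, s=2):
    """Residual activation memory B * T * H * s (use T = 1 for a decode step)."""
    return B * T * H * s

def kv_cache_bytes(B, T, h_kv, d, s=2, L=1):
    """KV-cache memory 2 * L * B * T * H_kv * s, with H_kv = h_kv * d."""
    return 2 * L * B * T * (h_kv * d) * s

# Hypothetical configuration, for illustration only.
B, T, H, h_kv, d, L = 8, 4096, 4096, 8, 128, 32

prefill_act = activation_bytes(B, T, H)        # one block's residual input
decode_act = activation_bytes(B, 1, H)         # T_q = 1
kv_block = kv_cache_bytes(B, T, h_kv, d)       # one decoder block
kv_total = kv_cache_bytes(B, T, h_kv, d, L=L)  # all L decoder blocks

print(f"prefill activation: {prefill_act / 2**20:.0f} MiB")
print(f"decode activation:  {decode_act / 2**10:.0f} KiB")
print(f"KV cache per block: {kv_block / 2**20:.0f} MiB")
print(f"KV cache total:     {kv_total / 2**30:.0f} GiB")
```

Under these assumptions the decode-step activation is tiny compared with the KV cache, which is why the cache rather than the current-step activation dominates decode memory.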
A pre-norm decoder block updates the residual stream as
\[ \widetilde{X}_\ell = X_\ell+ \mathrm{SelfAttention}_\ell(\mathrm{RMSNorm}(X_\ell)), \]
\[ X_{\ell+1} = \widetilde{X}_\ell+ \mathrm{MLP}_\ell(\mathrm{RMSNorm}(\widetilde{X}_\ell)). \]
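For normalized input \(X\) and projection matrices \(W_Q\in\mathbb{R}^{H\times H_q}\) and \(W_K,W_V\in\mathbb{R}^{H\times H_{kv}}\),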
\[ Q=XW_Q,\qquad K=XW_K,\qquad V=XW_V. \]
The masked attention scores, evaluated independently for each head on its \(d\)-dimensional slices of \(Q\) and \(K\), are
\[ S=\frac{QK^\top}{\sqrt d}+M_{\mathrm{causal}}. \]
The attention probabilities and attention output are
\[ P=\mathrm{softmax}(S), \qquad O=PV. \]
With the head outputs concatenated into \(O\) and \(W_O\in\mathbb{R}^{H_q\times H}\), the self-attention output projection is
\[ Y=OW_O. \]
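The attention core (scores, causal mask, softmax, and weighted sum) can be sketched for a single head and a single sequence in plain Python. This is a minimal illustration under stated assumptions: \(Q\), \(K\), and \(V\) are given directly as \(T\times d\) lists of rows, the projections and the output projection \(W_O\) are omitted, token positions are zero-based, and the helper names are hypothetical.

```python
import math

def matmul(a, b):
    """Plain matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax_rows(scores):
    """Row-wise softmax; masked entries (-inf) receive probability 0."""
    out = []
    for row in scores:
        m = max(row)  # subtract the row maximum for numerical stability
        e = [math.exp(val - m) for val in row]
        z = sum(e)
        out.append([val / z for val in e])
    return out

def causal_attention(q, k, v):
    """Return (P, O) for one head: P = softmax(Q K^T / sqrt(d) + M), O = P V."""
    T, d = len(q), len(q[0])
    k_t = [list(col) for col in zip(*k)]
    qk = matmul(q, k_t)
    s = [[qk[t][u] / math.sqrt(d) + (0.0 if u <= t else -math.inf)
          for u in range(T)] for t in range(T)]
    probs = softmax_rows(s)
    return probs, matmul(probs, v)

# Toy inputs: T = 3 tokens, d = 2.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
probs, o = causal_attention(q, k, v)
for row in probs:
    print([round(x, 3) for x in row])
```

The first token can attend only to itself, so its probability row is `[1.0, 0.0, 0.0]` and its output row equals the first row of `v`.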
For a gated MLP,
\[ A=\phi(XW_g)\odot(XW_u). \]
The down projection is
\[ Z=AW_d. \]
The matrix shapes are
\[ W_g,W_u\in\mathbb{R}^{H\times I}, \qquad W_d\in\mathbb{R}^{I\times H}. \]
Classic tensor parallelism shards projection matrices while keeping the residual-stream activation replicated across all tensor-parallel ranks:
\[ X_r=X. \]
The QKV projections are column-parallel:
\[ W_Q=[W_{Q,1},\dots,W_{Q,p}], \quad W_K=[W_{K,1},\dots,W_{K,p}], \quad W_V=[W_{V,1},\dots,W_{V,p}]. \]
Rank \(r\) computes
\[ Q_r=XW_{Q,r}, \qquad K_r=XW_{K,r}, \qquad V_r=XW_{V,r}. \]
Each rank owns a subset of heads (\(h/p\) query heads and \(h_{kv}/p\) key/value heads, assuming \(p\) divides both), so the attention computation is local to that head shard:
\[ S_r=\frac{Q_rK_r^\top}{\sqrt d}+M_{\mathrm{causal}}, \]
\[ P_r=\mathrm{softmax}(S_r), \qquad O_r=P_rV_r. \]
The output projection is row-parallel:
\[ W_O= \begin{bmatrix} W_{O,1}\\ \vdots\\ W_{O,p} \end{bmatrix}. \]
Rank \(r\) computes a partial projected output:
\[ Y_r=O_rW_{O,r}. \]
The full output is
\[ Y=\sum_{r=1}^{p}Y_r. \]
Classic TP materializes \(Y\) on every TP rank using AllReduce:
\[ Y=\mathrm{AllReduce}_r(Y_r). \]
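The column-parallel QKV projections, head-local attention, and row-parallel output projection can be checked numerically on a toy problem. The sketch below simulates the ranks in one process and makes several assumptions: one sequence, \(h=h_{kv}=2\) heads with \(d=2\), \(p=2\) ranks (one head per rank), arbitrary small weights, and an AllReduce modeled as an elementwise sum. The helper names are hypothetical.

```python
import math

def matmul(a, b):
    """Plain matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def cols(w, lo, hi):
    """Column slice w[:, lo:hi]."""
    return [row[lo:hi] for row in w]

def head_attention(q, k, v):
    """Causal attention for one head; q, k, v are T x d lists of rows."""
    d = len(q[0])
    out = []
    for t in range(len(q)):
        # Only positions u <= t are scored; masked positions get zero weight.
        s = [sum(a * b for a, b in zip(q[t], k[u])) / math.sqrt(d) for u in range(t + 1)]
        m = max(s)
        e = [math.exp(val - m) for val in s]
        z = sum(e)
        out.append([sum(e[u] / z * v[u][j] for u in range(t + 1)) for j in range(d)])
    return out

# Toy sizes (assumptions): T = 3 tokens, H = 2, h = h_kv = 2 heads, d = 2, p = 2.
d, h, p = 2, 2, 2
x = [[1.0, 0.5], [-0.5, 1.0], [0.25, -1.0]]
w_q = [[1.0, 0.0, 0.5, -1.0], [0.0, 1.0, 1.0, 0.5]]
w_k = [[0.5, 1.0, -1.0, 0.0], [1.0, -0.5, 0.0, 1.0]]
w_v = [[1.0, 2.0, 0.0, -1.0], [0.5, 0.0, 1.0, 1.0]]
w_o = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5], [-1.0, 2.0]]

# Reference: all heads on one device, concatenated, then projected by W_O.
q_full, k_full, v_full = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
heads = [head_attention(cols(q_full, a * d, (a + 1) * d),
                        cols(k_full, a * d, (a + 1) * d),
                        cols(v_full, a * d, (a + 1) * d)) for a in range(h)]
o_full = [sum((head[t] for head in heads), []) for t in range(len(x))]
y_ref = matmul(o_full, w_o)

# Classic TP: rank r owns d columns of W_Q, W_K, W_V and d rows of W_O.
partials = []
for r in range(p):
    lo, hi = r * d, (r + 1) * d
    o_r = head_attention(matmul(x, cols(w_q, lo, hi)),
                         matmul(x, cols(w_k, lo, hi)),
                         matmul(x, cols(w_v, lo, hi)))
    partials.append(matmul(o_r, w_o[lo:hi]))  # partial output Y_r

# AllReduce: every rank would receive this elementwise sum of the Y_r.
y_tp = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
print(y_tp)
```

In this toy, `y_tp` matches `y_ref` up to floating-point rounding, which is the identity \(Y=\sum_r O_rW_{O,r}\) in numeric form.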
The gate and up projections are column-parallel:
\[ W_g=[W_{g,1},\dots,W_{g,p}], \qquad W_u=[W_{u,1},\dots,W_{u,p}]. \]
Rank \(r\) computes
\[ A_r=\phi(XW_{g,r})\odot(XW_{u,r}). \]
The down projection is row-parallel:
\[ W_d= \begin{bmatrix} W_{d,1}\\ \vdots\\ W_{d,p} \end{bmatrix}. \]
Rank \(r\) computes
\[ Z_r=A_rW_{d,r}. \]
The full MLP output is
\[ Z=\sum_{r=1}^{p}Z_r = \mathrm{AllReduce}_r(Z_r). \]
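The same column-parallel and row-parallel pattern can be verified for the gated MLP. The sketch below is a single-process simulation under stated assumptions: \(\phi\) is taken to be SiLU (the text leaves \(\phi\) unspecified), the sizes are toy values (\(H=2\), \(I=4\), \(p=2\)), the weights are arbitrary, and the AllReduce is modeled as an elementwise sum.

```python
import math

def matmul(a, b):
    """Plain matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def silu(x):
    """SiLU activation, used here as an assumed choice of phi."""
    return x / (1.0 + math.exp(-x))

def gated_mlp(x, w_g, w_u, w_d):
    """Z = (phi(X W_g) * (X W_u)) W_d, with * the elementwise product."""
    gate, up = matmul(x, w_g), matmul(x, w_u)
    a = [[silu(g) * u for g, u in zip(g_row, u_row)] for g_row, u_row in zip(gate, up)]
    return matmul(a, w_d)

# Toy sizes (assumptions): 2 tokens, H = 2, I = 4, p = 2.
p = 2
x = [[1.0, -0.5], [0.5, 2.0]]
w_g = [[1.0, 0.0, -1.0, 2.0], [0.0, 1.0, 1.0, -1.0]]
w_u = [[0.5, 1.0, 0.0, -1.0], [1.0, -0.5, 2.0, 0.0]]
w_d = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]

# Reference: unsharded MLP on one device.
z_ref = gated_mlp(x, w_g, w_u, w_d)

# Rank r holds I / p columns of W_g and W_u and the matching rows of W_d.
n = len(w_d) // p
partials = []
for r in range(p):
    lo, hi = r * n, (r + 1) * n
    w_g_r = [row[lo:hi] for row in w_g]
    w_u_r = [row[lo:hi] for row in w_u]
    w_d_r = w_d[lo:hi]
    partials.append(gated_mlp(x, w_g_r, w_u_r, w_d_r))  # partial output Z_r

# AllReduce: every rank would receive this elementwise sum of the Z_r.
z_tp = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
print(z_tp)
```

No communication is needed between the gate/up projections and the down projection because the elementwise operations act on each intermediate column independently.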
Sharded-activation TP keeps the same weight partitioning as classic TP but avoids materializing the full residual activation on every rank after row-parallel projections. AllReduce is replaced by ReduceScatter.
After the self-attention output projection:
\[ Y^{(q)}=\mathrm{ReduceScatter}_r(Y_r). \]
After the MLP down projection:
\[ Z^{(q)}=\mathrm{ReduceScatter}_r(Z_r). \]
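The residual additions are applied shard by shard, so the next decoder block consumes a sharded residual stream:
\[ \widetilde{X}_\ell^{(q)}=X_\ell^{(q)}+Y_\ell^{(q)}, \qquad X_{\ell+1}^{(q)}=\widetilde{X}_\ell^{(q)}+Z_\ell^{(q)}. \]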
The activation shard \(q\) can be defined along the batch dimension or the sequence dimension. These two cases have different attention communication.
With batch-sharded activations,
\[ X^{(q)}\in\mathbb{R}^{B_q\times T\times H}, \qquad B_q=\frac{B}{p}. \]
Each shard contains complete sequences. Therefore, QKV tensors are local to those sequences:
\[ Q_r^{(q)}=X^{(q)}W_{Q,r}, \qquad K_r^{(q)}=X^{(q)}W_{K,r}, \qquad V_r^{(q)}=X^{(q)}W_{V,r}. \]
Self-attention requires no cross-rank K/V exchange:
\[ O_r^{(q)} = \mathrm{softmax} \left( \frac{Q_r^{(q)}K_r^{(q)\top}}{\sqrt d} + M_{\mathrm{causal}} \right)V_r^{(q)}. \]
Row-parallel projections still require ReduceScatter:
\[ Y^{(q)}=\mathrm{ReduceScatter}_r(O_r^{(q)}W_{O,r}), \]
\[ Z^{(q)}=\mathrm{ReduceScatter}_r(A_r^{(q)}W_{d,r}). \]
With sequence-sharded activations,
\[ X^{(q)}\in\mathbb{R}^{B\times T/p\times H}. \]
Each shard contains only a contiguous subset of the tokens of each sequence. Local QKV projections are
\[ Q_r^{(q)}=X^{(q)}W_{Q,r}, \qquad K_r^{(q)}=X^{(q)}W_{K,r}, \qquad V_r^{(q)}=X^{(q)}W_{V,r}. \]
A local query shard must attend over the complete context. Hence keys and values must be made visible across sequence shards:
\[ K_r=\mathrm{AllGather}_{\mathrm{seq}}(K_r^{(q)}), \qquad V_r=\mathrm{AllGather}_{\mathrm{seq}}(V_r^{(q)}). \]
The local output rows are then computed as
\[ O_r^{(q)} = \mathrm{softmax} \left( \frac{Q_r^{(q)}K_r^\top}{\sqrt d} + M_{\mathrm{causal}}^{(q)} \right)V_r. \]
Queries are not gathered because shard \(q\) produces only the output rows for its local query tokens.
The row-parallel projections are again completed with ReduceScatter:
\[ Y^{(q)}=\mathrm{ReduceScatter}_r(O_r^{(q)}W_{O,r}), \]
\[ Z^{(q)}=\mathrm{ReduceScatter}_r(A_r^{(q)}W_{d,r}). \]
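The sequence-sharded attention step can be illustrated for one head and one sequence. In the sketch below, which is a single-process simulation with hypothetical helper names, each shard keeps only its local query rows, keys and values are gathered along the sequence axis, and the shard-specific causal mask \(M_{\mathrm{causal}}^{(q)}\) is realized by offsetting each query's global position. Token positions are zero-based and \(T\) is assumed to be divisible by \(p\).

```python
import math

def attention_rows(q_rows, k, v, first_pos):
    """Causal attention for queries at global positions first_pos, first_pos + 1, ..."""
    d = len(k[0])
    out = []
    for i, q in enumerate(q_rows):
        t = first_pos + i  # global position of this query token
        s = [sum(a * b for a, b in zip(q, k[u])) / math.sqrt(d) for u in range(t + 1)]
        m = max(s)
        e = [math.exp(val - m) for val in s]
        z = sum(e)
        out.append([sum(e[u] / z * v[u][j] for u in range(t + 1)) for j in range(len(v[0]))])
    return out

# Toy single-head tensors: T = 4 tokens, d = 2, p = 2 sequence shards.
p = 2
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
k = [[0.5, 1.0], [1.0, -1.0], [0.0, 2.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
chunk = len(q) // p

# Each shard owns a contiguous block of token rows.
q_shards = [q[i * chunk:(i + 1) * chunk] for i in range(p)]
k_shards = [k[i * chunk:(i + 1) * chunk] for i in range(p)]
v_shards = [v[i * chunk:(i + 1) * chunk] for i in range(p)]

# AllGather along the sequence axis: every shard sees the full K and V.
k_all = [row for shard in k_shards for row in shard]
v_all = [row for shard in v_shards for row in shard]

# Each shard computes output rows only for its local queries.
o_shards = [attention_rows(q_shards[i], k_all, v_all, first_pos=i * chunk) for i in range(p)]
o_sharded = [row for shard in o_shards for row in shard]

# Reference: unsharded attention over the whole sequence.
o_full = attention_rows(q, k, v, first_pos=0)
print(o_sharded == o_full)
```

Because of the causal mask, the first shard in this toy never reads the keys and values of later shards, so gathering the full \(K\) and \(V\) on every shard is sufficient but not always necessary.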
| Execution model | Residual-stream activation per GPU (bytes) | KV cache per GPU per decoder block (bytes) | Attention-context exchange |
|---|---|---|---|
| Classic TP | \(BTHs\) | \(\frac{2BTH_{kv}s}{p}\) | No |
| Sharded-activation TP, batch shard | \(\frac{BTHs}{p}\) | \(\frac{2BTH_{kv}s}{p}\) | No |
| Sharded-activation TP, sequence shard | \(\frac{BTHs}{p}\) | \(\frac{2BTH_{kv}s}{p}\) | Yes, K/V exchange |
| Component | Classic TP | Batch-sharded activation TP | Sequence-sharded activation TP |
|---|---|---|---|
| QKV projections | Column-parallel, no collective | Column-parallel, no collective | Column-parallel, no collective |
| Self-attention core | No collective | No collective | \(\mathrm{AllGather}(K,V)\), \(2\frac{p-1}{p}M_{KV}s\) bytes |
| Attention output projection | \(\mathrm{AllReduce}\), \(2\frac{p-1}{p}M_Hs\) bytes | \(\mathrm{ReduceScatter}\), \(\frac{p-1}{p}M_Hs\) bytes | \(\mathrm{ReduceScatter}\), \(\frac{p-1}{p}M_Hs\) bytes |
| MLP gate/up projections | Column-parallel, no collective | Column-parallel, no collective | Column-parallel, no collective |
| MLP down projection | \(\mathrm{AllReduce}\), \(2\frac{p-1}{p}M_Hs\) bytes | \(\mathrm{ReduceScatter}\), \(\frac{p-1}{p}M_Hs\) bytes | \(\mathrm{ReduceScatter}\), \(\frac{p-1}{p}M_Hs\) bytes |
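Summing the entries of the table above gives a per-rank communication total for one decoder block. The sketch below simply adds the tabulated terms; the function name, the model labels, and the example workload (\(B=8\), \(T=2048\), \(H=4096\), \(H_{kv}=1024\), BF16, \(p=8\)) are hypothetical.

```python
from fractions import Fraction

def block_comm_bytes(model, B, T, H, H_kv, p, s=2):
    """Per-rank bytes for one decoder block, summing the table's collectives."""
    f = Fraction(p - 1, p)
    m_h, m_kv = B * T * H, B * T * H_kv
    if model == "classic":
        # Two AllReduce operations: attention output and MLP down projection.
        return 2 * (2 * f * m_h * s)
    if model == "batch":
        # Two ReduceScatter operations.
        return 2 * (f * m_h * s)
    if model == "sequence":
        # Two ReduceScatter operations plus AllGather of K and of V.
        return 2 * (f * m_h * s) + 2 * f * m_kv * s
    raise ValueError(f"unknown execution model: {model}")

# Hypothetical workload, for illustration only.
cfg = dict(B=8, T=2048, H=4096, H_kv=1024, p=8)
classic = block_comm_bytes("classic", **cfg)
batch = block_comm_bytes("batch", **cfg)
sequence = block_comm_bytes("sequence", **cfg)
for name, val in [("classic", classic), ("batch", batch), ("sequence", sequence)]:
    print(f"{name}: {float(val) / 2**20:.0f} MiB per rank per block")
```

For these assumed values the totals are 448 MiB, 224 MiB, and 280 MiB per rank per block, so the K/V AllGather adds 56 MiB on top of the batch-sharded figure.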
Classic tensor parallelism shards projection matrices but replicates residual-stream activations after row-parallel projections. Sharded-activation tensor parallelism replaces those AllReduce operations with ReduceScatter and keeps the residual stream partitioned across ranks. With batch-sharded activations, complete sequences remain local and self-attention requires no K/V exchange. With sequence-sharded activations, one context is split across ranks, so the self-attention sublayer requires global K/V visibility through AllGather or an equivalent distributed-attention algorithm.