How DiT injects conditioning signals through adaptive normalization, why zero-initialization matters, and how the same idea generalizes to spatial conditioning with segmentation masks.
Standard LayerNorm normalizes activations, then applies fixed learned affine parameters. adaLN makes those parameters dynamic: the scale and shift are produced on the fly from a conditioning vector c — typically a timestep embedding plus a class embedding, passed through a small MLP.
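Written out, adaLN computes \(\mathrm{adaLN}(x, c) = (1 + \gamma(c)) \odot \mathrm{LayerNorm}(x) + \beta(c)\). The \((1 + \gamma)\) form is deliberate: it centers the multiplicative scale at \(1\) (identity), so a near-zero \(\gamma\) leaves the normalized activations unchanged instead of scaling them to zero. In DiT this happens twice per block (before attention, before the FFN), giving four parameter vectors per block: \(\gamma_1,\beta_1\) and \(\gamma_2,\beta_2\).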
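# pseudocode: MLP, LayerNorm, Attention, FFN stand for the block's submodules
# c: condition [B, cond_dim]    x: tokens [B, T, D]
shift1, scale1, shift2, scale2 = MLP(SiLU(c)).chunk(4, dim=-1)   # each [B, D]
# unsqueeze(1) makes each parameter [B, 1, D] so it broadcasts over the T tokens
x_mod = LayerNorm(x) * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
x = x + Attention(x_mod)   # plain residual
x_mod = LayerNorm(x) * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
x = x + FFN(x_mod)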
The same \(\gamma,\beta\) are broadcast to every token, so the conditioning is global: every image patch receives identical modulation. That is fine for class and timestep signals, but it cannot carry any spatial structure in the condition.
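The pseudocode above can be turned into a small runnable module. This is a minimal sketch, not DiT's actual implementation: the class name `AdaLNBlock`, the use of `nn.MultiheadAttention`, the 4x GELU FFN, and all sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

class AdaLNBlock(nn.Module):
    """Transformer block with adaLN conditioning (no gate). Illustrative only."""
    def __init__(self, dim, cond_dim, heads=4):
        super().__init__()
        # LayerNorm without its own affine parameters: scale/shift come from c
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.mod = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 4 * dim))

    def forward(self, x, c):
        # [B, 4D] -> [B, 1, 4D] -> four [B, 1, D] tensors shared by all T tokens
        shift1, scale1, shift2, scale2 = self.mod(c).unsqueeze(1).chunk(4, dim=-1)
        h = self.norm1(x) * (1 + scale1) + shift1
        x = x + self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + scale2) + shift2
        return x + self.ffn(h)

torch.manual_seed(0)
block = AdaLNBlock(dim=32, cond_dim=16)
x = torch.randn(2, 10, 32)   # [B, T, D]
c = torch.randn(2, 16)       # [B, cond_dim]
y = block(x, c)
print(y.shape)  # torch.Size([2, 10, 32])
```

The singleton token dimension on the modulation tensors is what makes the conditioning global: there is exactly one scale and shift per sample and channel, whatever the token index.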
adaLN-Zero (Peebles & Xie, DiT 2023) adds one element: a gate \(\alpha\), one per sub-block and predicted per channel from \(c\), that multiplies each sub-block's output before the residual add. The linear layer producing \(\alpha\) (together with \(\gamma\) and \(\beta\)) is initialized to zero.
self.adaLN_modulation = nn.Sequential(
    nn.SiLU(),
    nn.Linear(cond_dim, 6 * hidden_size),
)
# ★ the critical line
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
At step 0 every modulation output is zero, so \(\alpha = 0\) and each block is an identity — only the residual stream flows. The network starts as a stack of no-ops and learns to switch blocks on gradually. That is the "Zero" in adaLN-Zero.
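The identity-at-initialization claim is easy to check numerically. The block below is a sketch under assumptions: `AdaLNZeroBlock`, the `nn.MultiheadAttention` layer, the FFN shape, and the sizes are illustrative stand-ins, and only the zero-initialized six-way modulation follows the snippet above.

```python
import torch
import torch.nn as nn

class AdaLNZeroBlock(nn.Module):
    """adaLN-Zero style block: shift, scale, and gate for each sub-block. Illustrative only."""
    def __init__(self, hidden_size, cond_dim, heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(hidden_size, heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 6 * hidden_size))
        # zero-init the final projection: all shifts, scales, and gates start at 0
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, c):
        # six [B, 1, D] tensors: shift, scale, gate for attention, then for the FFN
        shift1, scale1, gate1, shift2, scale2, gate2 = (
            self.adaLN_modulation(c).unsqueeze(1).chunk(6, dim=-1)
        )
        h = self.norm1(x) * (1 + scale1) + shift1
        x = x + gate1 * self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + scale2) + shift2
        return x + gate2 * self.ffn(h)

torch.manual_seed(0)
zero_block = AdaLNZeroBlock(hidden_size=32, cond_dim=16)
x = torch.randn(2, 10, 32)
c = torch.randn(2, 16)
with torch.no_grad():
    out = zero_block(x, c)
print(torch.equal(out, x))  # True: with both gates at 0, only the residual stream survives
```

Because the gates multiply the sub-block outputs, the attention and FFN weights can keep their usual random initialization; the zero gates alone are enough to make the block a no-op at step 0.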
The \(\alpha\) pathway is structurally close to a Squeeze-and-Excitation block — a side branch that rescales the main features channel-wise. Omitting bias terms:
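One way to write the two side branches in a shared notation is sketched here. The symbols are assumptions introduced for this comparison, not necessarily those of the cited analysis: \(z\) is the globally pooled feature vector, \(\delta\) is ReLU, \(\sigma\) is the sigmoid, \(F\) is the attention or FFN sub-block, and \(W_1, W_2, W_\alpha\) are the branch weights. In both cases a learned per-channel vector rescales the main path; SE computes it from the features themselves and bounds it with a sigmoid, while \(\alpha\) comes from the external condition \(c\) and is unbounded.

\[
\text{SE:}\quad s = \sigma\big(W_2\,\delta(W_1 z)\big), \qquad y = s \odot x
\]

\[
\text{adaLN-Zero gate:}\quad \alpha = W_\alpha\,\mathrm{SiLU}(c), \qquad y = x + \alpha \odot F\big((1+\gamma)\odot \mathrm{LayerNorm}(x) + \beta\big)
\]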
This buys extra expressivity even before zero-init enters the picture.
The ICLR-2025 analysis (OpenReview E4roJSM9RM) decoupled three candidate explanations and ranked them:
| Factor | What it is | Impact |
|---|---|---|
| SE-like structure | Adding α at all (even default init) | Moderate |
| Gradual update order | Zero-init delays gradients to some weights early on | Small |
| Zero-init position | Starting near the trained weight distribution | Dominant |
As training runs, the \(\alpha\) weights drift from a spike at zero toward a Gaussian-like distribution centered at zero. Zero-init drops you near the center of that target already. By an entropy argument, adaLN-Zero's distribution simply spreads outward (entropy increases — the natural direction), while other inits must first contract toward zero, fighting the gradient.
Since the weights end up Gaussian anyway, initialize there directly and skip the trip. Only one line changes; the paper reports an improvement of ≈2 FID points over adaLN-Zero.
# replace the zeros_ weight line with:
nn.init.normal_(self.adaLN_modulation[-1].weight, std=0.001)
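To see what that one-line change does in practice, the sketch below wraps both initializations in a hypothetical helper, `init_modulation` (a name invented here), and inspects the first gate. The layer sizes are arbitrary; keeping the bias at zero follows the instruction above to replace only the weight line.

```python
import torch
import torch.nn as nn

def init_modulation(mod: nn.Sequential, std: float = 0.0) -> None:
    """std == 0 gives adaLN-Zero; std > 0 gives the small-Gaussian variant."""
    last = mod[-1]
    if std > 0:
        nn.init.normal_(last.weight, std=std)
    else:
        nn.init.zeros_(last.weight)
    nn.init.zeros_(last.bias)  # the bias stays zero in both cases

torch.manual_seed(0)
mod = nn.Sequential(nn.SiLU(), nn.Linear(16, 6 * 32))
init_modulation(mod, std=0.001)

c = torch.randn(8, 16)
with torch.no_grad():
    gate = mod(c).chunk(6, dim=-1)[2]   # third chunk: the attention gate
print(gate.abs().max())  # tiny but nonzero, so each block starts close to identity
```

With `std=0.001` the block is no longer an exact no-op at step 0, only approximately one; that is the trade the Gaussian variant makes.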
SPADE (Park et al., CVPR 2019 — a.k.a. GauGAN) tackles a sibling problem: conditioning a generator on a segmentation mask without normalization erasing its spatial layout.
Run a mask through an encoder, then normalize in the decoder, and the semantic signal washes out — regions of the same label collapse to one normalized value, so spatial structure disappears. The fix must be spatially aware.
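A toy check makes the wash-out concrete. It normalizes raw label maps directly, which is a simplification: in a real network a convolution sits in between, but a uniform input still produces a uniform output away from the borders. The tensor names and label values are made up for illustration.

```python
import torch
import torch.nn as nn

norm = nn.InstanceNorm2d(1, affine=False)

# two uniform "segmentation maps" that differ only in which label fills the image
sky = torch.full((1, 1, 8, 8), 3.0)
grass = torch.full((1, 1, 8, 8), 7.0)

sky_n, grass_n = norm(sky), norm(grass)
# after normalization both maps are numerically zero, so the label is unrecoverable
print(sky_n.abs().max().item(), grass_n.abs().max().item())
```

Subtracting the per-instance mean removes exactly the quantity that distinguished the two inputs, which is why the conditioning has to re-enter after the normalization rather than before it.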
Instead of a fixed affine layer after normalization, SPADE predicts \(\gamma\) and \(\beta\) per spatial location from the mask, via a small CNN:
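import torch.nn as nn
import torch.nn.functional as F

class SPADE(nn.Module):
    def __init__(self, ch, cond_ch):
        super().__init__()
        # parameter-free normalization; the affine part is predicted from the mask
        self.norm = nn.InstanceNorm2d(ch, affine=False)
        self.shared = nn.Sequential(nn.Conv2d(cond_ch, 128, 3, padding=1), nn.ReLU())
        self.gamma = nn.Conv2d(128, ch, 3, padding=1)
        self.beta = nn.Conv2d(128, ch, 3, padding=1)

    def forward(self, x, mask):
        # mask: float tensor [B, cond_ch, H, W], resized to the feature resolution
        m = F.interpolate(mask, size=x.shape[2:], mode='nearest')
        f = self.shared(m)
        # gamma and beta are [B, ch, H, W]: one scale and shift per channel and pixel
        return self.gamma(f) * self.norm(x) + self.beta(f)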
Park, Liu, Wang, Zhu — Semantic Image Synthesis with Spatially-Adaptive Normalization, CVPR 2019 (Oral).
adaLN-Zero and SPADE are the same move at different spatial granularity: normalize to strip the activations' own statistics, then re-inject conditioning through a learned affine transform.
| Property | adaLN / adaLN-Zero | SPADE |
|---|---|---|
| condition | vector (timestep, class) | image (segmentation mask) |
| predictor | MLP on vector | CNN on mask |
| params | γ, β (shared over tokens) | γ(y,x), β(y,x) per site |
| granularity | global | local / spatial |
| backbone | transformer (DiT) | conv decoder (GAN) |
| gate α | yes (Zero variant) | no |
Both are conditional affine modulation of a normalized activation. Apart from the optional gate, they differ only in what the condition is (vector vs. image) and how finely the parameters are applied (one per channel vs. one per channel and site).
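The equivalence can be shown with a single modulation function in which broadcasting alone decides the granularity. This is an illustrative sketch: `modulate` is a hypothetical helper, and it uses the \((1+\gamma)\) convention for both cases even though the SPADE code above multiplies by \(\gamma\) directly.

```python
import torch

def modulate(x_norm, gamma, beta):
    """Conditional affine modulation; the shapes of gamma/beta set the granularity."""
    return x_norm * (1 + gamma) + beta

torch.manual_seed(0)
B, C, H, W = 2, 8, 4, 4
x_norm = torch.randn(B, C, H, W)   # stands in for an already-normalized activation

# adaLN-style: one (gamma, beta) per channel, shared over all sites
g_vec, b_vec = torch.randn(B, C, 1, 1), torch.randn(B, C, 1, 1)
# SPADE-style: full maps, here filled with the same value at every site
g_map = g_vec.expand(B, C, H, W).clone()
b_map = b_vec.expand(B, C, H, W).clone()

same = torch.allclose(modulate(x_norm, g_vec, b_vec), modulate(x_norm, g_map, b_map))
print(same)  # True: a spatially constant map reproduces the global case

g_map[:, :, 0, 0] += 1.0           # let one site deviate
local = modulate(x_norm, g_map, b_map)
```

Once a single site of the map deviates, only that site's output changes, which is the extra freedom spatial modulation has over the global form.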
To condition a DiT on a binary mask while staying inside the adaLN family: replace the modulation MLP with a CNN that emits spatially-varying versions of all six parameters (including the \(\alpha\) gates), then modulate each patch token with the parameters at its own grid cell. This is SPADE's spatial idea carried into the transformer, with the gate retained.
This combination isn't a published, named method I can cite. SPADE is the citation for the spatial-modulation idea; the DiT paper is the citation for the gate and zero-init. The fusion below is a sensible design, not an established result.
# in __init__
self.out_conv = nn.Conv2d(256, 6 * hidden_size, 1)
nn.init.zeros_(self.out_conv.weight)   # preserves the adaLN-Zero property
nn.init.zeros_(self.out_conv.bias)

# in forward
spatial = self.mask_cnn(mask)                    # [B,256,H/P,W/P]
time = self.time_proj(t_emb)[..., None, None]    # [B,256,1,1]
fused = spatial + time                           # time broadcasts over the grid: the global denoising signal
If the mask CNN downsamples by the same factor \(P\) used to patchify the image, then CNN cell \((i,j)\) modulates image token \((i,j)\) exactly. No extra positional encoding is needed on the modulation maps.
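Putting the fragments together, a complete block might look like the sketch below. It is one possible design under stated assumptions, not an established method: `SpatialAdaLNZeroBlock`, the single strided convolution used as `mask_cnn`, the `nn.MultiheadAttention` layer, and all sizes are hypothetical. Tokens are assumed to be in row-major grid order so that flattening the modulation maps lines cell \((i,j)\) up with token \(i \cdot (W/P) + j\).

```python
import torch
import torch.nn as nn

class SpatialAdaLNZeroBlock(nn.Module):
    """adaLN-Zero with per-token modulation predicted from a mask. Illustrative only."""
    def __init__(self, hidden_size, patch, t_dim, heads=4, width=256):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(hidden_size, heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
                                 nn.Linear(4 * hidden_size, hidden_size))
        # hypothetical mask encoder: a strided conv maps the mask onto the token grid
        self.mask_cnn = nn.Sequential(nn.Conv2d(1, width, patch, stride=patch), nn.SiLU())
        self.time_proj = nn.Linear(t_dim, width)
        self.out_conv = nn.Conv2d(width, 6 * hidden_size, 1)
        nn.init.zeros_(self.out_conv.weight)   # keeps the identity-at-init property
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, mask, t_emb):
        # x: [B, T, D] with T = (H/P)*(W/P); mask: [B, 1, H, W]; t_emb: [B, t_dim]
        fused = self.mask_cnn(mask) + self.time_proj(t_emb)[..., None, None]
        maps = self.out_conv(fused)                  # [B, 6D, H/P, W/P]
        tokens = maps.flatten(2).transpose(1, 2)     # [B, T, 6D], row-major grid order
        shift1, scale1, gate1, shift2, scale2, gate2 = tokens.chunk(6, dim=-1)
        h = self.norm1(x) * (1 + scale1) + shift1
        x = x + gate1 * self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + scale2) + shift2
        return x + gate2 * self.ffn(h)

torch.manual_seed(0)
blk = SpatialAdaLNZeroBlock(hidden_size=32, patch=4, t_dim=16)
mask = (torch.rand(2, 1, 32, 32) > 0.5).float()   # binary mask
x = torch.randn(2, 64, 32)                         # 8 x 8 token grid
t_emb = torch.randn(2, 16)
with torch.no_grad():
    out = blk(x, mask, t_emb)
print(out.shape, torch.equal(out, x))  # torch.Size([2, 64, 32]) True
```

The only structural change from the global block is that each modulation tensor has shape `[B, T, D]` instead of `[B, 1, D]`, so every token gets its own shift, scale, and gate.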
What's spatial here is the modulation, not the attention. Self-attention still mixes tokens globally. To make the mask gate attention itself, you'd need attention bias or cross-attention — a different mechanism.
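For contrast, here is a sketch of the attention-bias route. Everything in it is an assumption made for illustration: the helper `region_attention_bias`, the rule that tokens in different mask regions receive an additive penalty, and the penalty value. Attention is written out by hand so that the place where the bias enters is visible.

```python
import torch

def region_attention_bias(labels, penalty=-4.0):
    """labels: [B, T], one region id per token.
    Returns [B, 1, T, T]: 0 where query and key share a region, `penalty` elsewhere."""
    same = labels[:, :, None] == labels[:, None, :]
    bias = torch.zeros(same.shape, dtype=torch.float32)
    bias = bias.masked_fill(~same, penalty)
    return bias[:, None]   # singleton head dimension, broadcast over heads

torch.manual_seed(0)
B, heads, T, d = 2, 4, 16, 8
q, k, v = (torch.randn(B, heads, T, d) for _ in range(3))
labels = torch.randint(0, 2, (B, T))   # binary mask reduced to one label per token

bias = region_attention_bias(labels)
scores = q @ k.transpose(-2, -1) / d ** 0.5 + bias   # the mask now shapes who attends to whom
attn = scores.softmax(dim=-1)
out = attn @ v
print(out.shape)  # torch.Size([2, 4, 16, 8])
```

Unlike the modulation maps, this changes the token-mixing pattern itself, which is why it counts as a separate mechanism rather than a variant of adaLN.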
| Approach | Spatial | Compute | Best for |
|---|---|---|---|
| adaLN-Zero | no | min | class / timestep |
| SPADE | yes | low | conv decoder + mask |
| Spatial adaLN-Zero | yes | med | DiT + binary/semantic mask |
| Cross-attention to mask | yes | high | layout drives content |