self-notes · diffusion transformers

Conditioning by modulation: adaLN-Zero and SPADE

How DiT injects conditioning signals through adaptive normalization, why zero-initialization matters, and how the same idea generalizes to spatial conditioning with segmentation masks.

01

adaLN — adaptive layer norm

Standard LayerNorm normalizes activations, then applies fixed learned affine parameters. adaLN makes those parameters dynamic: the scale and shift are produced on the fly from a conditioning vector c — typically a timestep embedding plus a class embedding, passed through a small MLP.

\[ \gamma, \beta = \mathrm{MLP}(c) \qquad x_{\text{mod}} = \mathrm{LN}(x)\odot(1+\gamma) + \beta \]

The (1 + γ) form is deliberate: it centers the multiplicative scale at \(1\) (identity), so when the MLP outputs \(\gamma \approx 0\) training starts neutral instead of at zero scale. In DiT this happens twice per block (before attention and before the FFN), giving four modulation vectors per block, each of dimension \(D\): \(\gamma_1,\beta_1\) and \(\gamma_2,\beta_2\).
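A minimal NumPy sketch of the modulation step, assuming a parameter-free LayerNorm over the channel axis (eps = 1e-6 is an illustrative choice) and one \(\gamma,\beta\) pair per sample. It shows why the (1 + γ) form is neutral: at \(\gamma=\beta=0\) the output is plain LN(x).

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the channel axis: no learned affine,
    # because adaLN supplies scale and shift from the condition instead.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def adaln_modulate(x, gamma, beta):
    # x: [B, T, D] tokens; gamma, beta: [B, D], one vector per sample.
    # [:, None, :] broadcasts the same gamma/beta to every token.
    return layer_norm(x) * (1.0 + gamma[:, None, :]) + beta[:, None, :]

rng = np.random.default_rng(0)
B, T, D = 2, 5, 8
x = rng.normal(size=(B, T, D))

# Neutral start: gamma = beta = 0 leaves plain LayerNorm.
zeros = np.zeros((B, D))
neutral = adaln_modulate(x, zeros, zeros)

# A nonzero condition rescales and shifts every channel.
gamma = rng.normal(size=(B, D))
beta = rng.normal(size=(B, D))
modulated = adaln_modulate(x, gamma, beta)
```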

[Figure: condition c → SiLU + Linear → γ (scale), β (shift); input x → modulate x ⊙ (1+γ) + β → Attention / FFN → residual add out + x]
adaLN block — condition produces γ, β (global, shared across all tokens). No gate on the residual.
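# c: condition  [B, cond_dim]   x: tokens [B, T, D]
# assumed modules (PyTorch):
#   adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 4 * D))
#   norm1, norm2: nn.LayerNorm(D, elementwise_affine=False)  # affine comes from c
#   attn, ffn: sub-blocks mapping [B, T, D] -> [B, T, D]
mods = adaLN_modulation(c).unsqueeze(1)       # [B, 1, 4D]: broadcast over all T tokens
shift1, scale1, shift2, scale2 = mods.chunk(4, dim=-1)

x_mod = norm1(x) * (1 + scale1) + shift1
x = x + attn(x_mod)          # plain residual

x_mod = norm2(x) * (1 + scale2) + shift2
x = x + ffn(x_mod)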
limitation

The same \(\gamma,\beta\) are broadcast to every token, so conditioning is global: every image patch gets identical modulation. That is fine for class and timestep signals, but it cannot express spatially varying conditions such as a segmentation mask.
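To make the limitation concrete, a small NumPy sketch: give two tokens identical content at different positions and apply adaLN with one \((\gamma,\beta)\) per sample. Their outputs match exactly, so modulation alone cannot tell the positions apart. The per-token variant at the end is a hypothetical contrast, of the kind a spatial conditioner would have to produce.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(1)
B, T, D = 1, 4, 6
x = rng.normal(size=(B, T, D))
x[:, 3] = x[:, 0]                      # token 3 is a copy of token 0

# Global adaLN: one gamma/beta vector per sample, shared by all T tokens.
gamma = rng.normal(size=(B, 1, D))
beta = rng.normal(size=(B, 1, D))
global_out = layer_norm(x) * (1 + gamma) + beta

# Hypothetical per-token modulation [B, T, D], e.g. from a mask CNN.
gamma_tok = rng.normal(size=(B, T, D))
beta_tok = rng.normal(size=(B, T, D))
spatial_out = layer_norm(x) * (1 + gamma_tok) + beta_tok
```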

02

adaLN-Zero — adding the gate

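adaLN-Zero (Peebles & Xie, DiT 2023) adds one element: a gate \(\alpha\) per sub-block that scales that sub-block's output channel-wise before the residual add. Like \(\gamma\) and \(\beta\), \(\alpha\) is a \(D\)-dimensional vector regressed from \(c\). The final linear layer that produces all six modulation vectors, \(\alpha\) included, is initialized to zero.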

\[ \gamma_1,\beta_1,\alpha_1,\ \gamma_2,\beta_2,\alpha_2 = \mathrm{Linear}\big(\mathrm{SiLU}(c)\big) \] \[ x \leftarrow x + \alpha_1 \odot \mathrm{Attn}\!\big(\mathrm{LN}(x)\odot(1+\gamma_1)+\beta_1\big) \] \[ x \leftarrow x + \alpha_2 \odot \mathrm{FFN}\!\big(\mathrm{LN}(x)\odot(1+\gamma_2)+\beta_2\big) \]
[Figure: condition c → SiLU + Linear (zero-init) → γ, β, α; input x → modulate x ⊙ (1+γ) + β → Attention / FFN → scale by α → residual add (out ⊙ α) + x]
adaLN-Zero: the α gate scales each sub-block output. Zero-init makes α = 0, so every block begins as a pure residual pass-through.

The zero-initialization

self.adaLN_modulation = nn.Sequential(
    nn.SiLU(),
    nn.Linear(cond_dim, 6 * hidden_size),
)

# ★ the critical line
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
key insight

At step 0 every modulation output is zero, so \(\alpha = 0\) and each block is an identity — only the residual stream flows. The network starts as a stack of no-ops and learns to switch blocks on gradually. That is the "Zero" in adaLN-Zero.
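A NumPy sketch of one adaLN-Zero block checks the claim numerically. The attention and FFN sub-blocks are hypothetical stand-ins (a random linear map and a two-layer ReLU MLP, no real self-attention), since only the gating matters here. The modulation layer is zero-initialized as in the code above, and the chunk order follows the equations (\(\gamma,\beta,\alpha\) per sub-block).

```python
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def adaln_zero_block(x, c, W_mod, b_mod, attn, ffn):
    # x: [B, T, D]; c: [B, C]; W_mod: [C, 6D]; b_mod: [6D]
    mods = (silu(c) @ W_mod + b_mod)[:, None, :]      # [B, 1, 6D], shared across tokens
    g1, b1, a1, g2, b2, a2 = np.split(mods, 6, axis=-1)
    x = x + a1 * attn(layer_norm(x) * (1 + g1) + b1)  # gated residual
    x = x + a2 * ffn(layer_norm(x) * (1 + g2) + b2)
    return x

rng = np.random.default_rng(0)
B, T, D, C = 2, 4, 8, 16
W_attn = rng.normal(size=(D, D)) / np.sqrt(D)
W_up = rng.normal(size=(D, 4 * D)) / np.sqrt(D)
W_down = rng.normal(size=(4 * D, D)) / np.sqrt(4 * D)

def attn(h):
    return h @ W_attn                                 # stand-in for self-attention

def ffn(h):
    return np.maximum(h @ W_up, 0.0) @ W_down

x = rng.normal(size=(B, T, D))
c = rng.normal(size=(B, C))

# Zero-init: every modulation output is 0, so alpha = 0 and the block is an identity.
W_mod, b_mod = np.zeros((C, 6 * D)), np.zeros(6 * D)
out_init = adaln_zero_block(x, c, W_mod, b_mod, attn, ffn)

# Once the modulation weights move away from zero, the block starts doing work.
W_trained = rng.normal(size=(C, 6 * D)) * 0.1
out_trained = adaln_zero_block(x, c, W_trained, b_mod, attn, ffn)
```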

An SE-like side note

The \(\alpha\) pathway is structurally close to a Squeeze-and-Excitation block: a side branch that rescales the main features channel-wise. Omitting bias terms and SE's final sigmoid, and writing both in row-vector form:

\[ \underbrace{\big(c\odot\sigma(c)\big)\,W_\alpha}_{\text{adaLN-Zero's }\alpha} \quad\Longleftrightarrow\quad \underbrace{\mathrm{ReLU}(c\,W_1)\,W_2}_{\text{SE module}} \]

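Adding \(\alpha\) therefore buys an SE-like, condition-dependent channel gate. That expressivity gain exists even with default initialization, before zero-init enters the picture.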
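A NumPy sketch of the correspondence, with biases dropped as in the equation and all sizes and weights hypothetical. Both branches map a side input to one gain per channel and rescale the features channel-wise. The visible difference: SE's final sigmoid keeps its gain in (0, 1), while \(\alpha\) is unbounded, can be negative, and under adaLN-Zero starts at exactly 0.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def silu(v):
    return v * sigmoid(v)                  # c ⊙ σ(c)

rng = np.random.default_rng(0)
T, D, C, R = 6, 8, 16, 4                  # tokens, channels, condition dim, SE bottleneck
feats = rng.normal(size=(T, D))           # main-branch features for one sample

# adaLN-Zero's alpha branch: (c ⊙ σ(c)) W_alpha -> one gain per channel, [D]
c = rng.normal(size=C)
W_alpha = rng.normal(size=(C, D)) * 0.1
alpha = silu(c) @ W_alpha

# SE excitation on a pooled descriptor z: σ(ReLU(z W1) W2) -> [D]
z = feats.mean(axis=0)                    # "squeeze": global average over tokens
W1, W2 = rng.normal(size=(D, R)), rng.normal(size=(R, D))
se_gain = sigmoid(np.maximum(z @ W1, 0.0) @ W2)

gated_alpha = feats * alpha               # same per-channel gain for every token
gated_se = feats * se_gain
```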

03

Why zero-init wins

The ICLR 2025 analysis (OpenReview E4roJSM9RM) separated three candidate explanations for adaLN-Zero's advantage and ranked their impact:

| Factor | What it is | Impact |
|---|---|---|
| SE-like structure | Adding α at all (even with default init) | Moderate |
| Gradual update order | Zero-init delays gradients to some weights early on | Small |
| Zero-init position | Starting near the trained weight distribution | Dominant |
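The "gradual update order" factor can be checked with central finite differences on a toy gated sub-block, out = x + α ⊙ f(LN(x) ⊙ (1 + γ) + β), where f is a hypothetical linear stand-in for attention and the loss is an arbitrary sum of squares. At zero init, nudging the sub-block's own weights or the γ outputs of the modulation layer leaves the loss unchanged (zero gradient), while nudging the α outputs changes it, so only α can move at step 0.

```python
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def gated_sub_block(x, c, W_mod, W_f):
    # W_mod: [C, 3D] -> gamma, beta, alpha for one sub-block; W_f: stand-in for attention.
    g, b, a = np.split((silu(c) @ W_mod)[:, None, :], 3, axis=-1)
    return x + a * ((layer_norm(x) * (1 + g) + b) @ W_f)

def loss(x, c, W_mod, W_f):
    return float(np.sum(gated_sub_block(x, c, W_mod, W_f) ** 2))  # any scalar loss works

def fd_grad_norm(fn, W, eps=1e-5):
    # Central finite-difference gradient of fn w.r.t. every entry of W, returned as a norm.
    grad = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        grad[idx] = (fn(Wp) - fn(Wm)) / (2 * eps)
    return np.linalg.norm(grad)

rng = np.random.default_rng(0)
B, T, D, C = 2, 3, 4, 5
x, c = rng.normal(size=(B, T, D)), rng.normal(size=(B, C))
W_f = rng.normal(size=(D, D))
W_mod = np.zeros((C, 3 * D))                      # adaLN-Zero init

# Gradient norms at step 0 for: sub-block weights, gamma outputs, alpha outputs.
g_f = fd_grad_norm(lambda W: loss(x, c, W_mod, W), W_f)
g_gamma = fd_grad_norm(lambda W: loss(x, c, np.hstack([W, W_mod[:, D:]]), W_f), W_mod[:, :D])
g_alpha = fd_grad_norm(lambda W: loss(x, c, np.hstack([W_mod[:, :2 * D], W]), W_f), W_mod[:, 2 * D:])
```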

As training runs, the weights that produce \(\alpha\) drift from a spike at zero toward a Gaussian-like distribution centered at zero, so zero-init already starts at the center of that target. In entropy terms, adaLN-Zero's distribution only has to spread outward (entropy increases, the natural direction), while other initializations must first contract toward zero, working against that drift.
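The entropy intuition can be made concrete for the Gaussian case the weights drift toward. As an idealization, treat the weight distribution as \(\mathcal{N}(0,\sigma^2)\); its differential entropy grows monotonically with the spread:

\[ h\big(\mathcal{N}(0,\sigma^2)\big) = \tfrac{1}{2}\log\!\big(2\pi e\,\sigma^2\big), \qquad \frac{\partial h}{\partial \sigma} = \frac{1}{\sigma} > 0 \]

A spike at zero is the limit \(\sigma \to 0\), where \(h \to -\infty\), so widening from zero-init is an entropy-increasing move, whereas a wider starting distribution that must narrow toward the target moves the other way.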

bonus · adaLN-Gaussian

Since the weights end up Gaussian anyway, initialize them there directly and skip the trip. Only one line changes, and the paper reports an FID roughly 2 points better (lower) than adaLN-Zero.

# replace the zeros_ weight line with (the bias keeps its zeros_ init):
nn.init.normal_(self.adaLN_modulation[-1].weight, std=0.001)
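A quick NumPy check of what this one-line change does at step 0, using a toy gated sub-block (the linear stand-in for attention is hypothetical; std = 0.001 is the value from the snippet above). The gates are no longer exactly zero, so each block is only approximately an identity at initialization, and the sub-block weights receive a small but nonzero gradient from the first step.

```python
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def gated_sub_block(x, c, W_mod, W_f):
    # One modulated, gated sub-block; W_f is a linear stand-in for attention.
    g, b, a = np.split((silu(c) @ W_mod)[:, None, :], 3, axis=-1)
    return x + a * ((layer_norm(x) * (1 + g) + b) @ W_f)

rng = np.random.default_rng(0)
B, T, D, C = 2, 4, 8, 16
x, c = rng.normal(size=(B, T, D)), rng.normal(size=(B, C))
W_f = rng.normal(size=(D, D)) / np.sqrt(D)

inits = {
    "adaLN-Zero": np.zeros((C, 3 * D)),
    "adaLN-Gaussian": rng.normal(scale=0.001, size=(C, 3 * D)),
}
# Largest deviation from a pure identity at initialization, per scheme.
deviation = {name: np.abs(gated_sub_block(x, c, W, W_f) - x).max()
             for name, W in inits.items()}
```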
04

SPADE — spatial modulation

SPADE (Park et al., CVPR 2019; the normalization layer behind GauGAN) tackles a sibling problem: conditioning a generator on a segmentation mask without normalization erasing its spatial layout.

The problem

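If the mask is fed in as the network's input and then passed through convolution and normalization layers, the normalization tends to wash out its semantics. In the extreme case of a mask carrying a single label everywhere, a convolution yields a spatially constant map, instance normalization subtracts that constant, and the output is zero whichever label it was. The fix must keep the mask out of the normalization path and let it act spatially.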
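The degenerate case is easy to reproduce in NumPy: a mask with a single label everywhere, a 1×1 convolution (written as a per-pixel matrix product; the output width of 4 and the random weights are arbitrary), then parameter-free instance normalization. Whatever the label, the normalized output is all zeros.

```python
import numpy as np

def one_hot_mask(label, num_labels, H, W):
    # [num_labels, H, W] mask in which every pixel carries the same label.
    m = np.zeros((num_labels, H, W))
    m[label] = 1.0
    return m

def conv1x1(m, weight):
    # 1x1 convolution == the same linear map at every pixel: [C_in,H,W] -> [C_out,H,W]
    return np.einsum("oc,chw->ohw", weight, m)

def instance_norm(h, eps=1e-5):
    # Per-channel normalization over spatial positions, no affine parameters.
    mu = h.mean(axis=(1, 2), keepdims=True)
    var = h.var(axis=(1, 2), keepdims=True)
    return (h - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
num_labels, H, W = 3, 8, 8
W_conv = rng.normal(size=(4, num_labels))

# One normalized output per possible label; all of them collapse to zero.
outputs = [instance_norm(conv1x1(one_hot_mask(k, num_labels, H, W), W_conv))
           for k in range(num_labels)]
```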

The mechanism

Instead of a fixed affine layer after normalization, SPADE predicts \(\gamma\) and \(\beta\) per spatial location from the mask, via a small CNN:

\[ \gamma_{c,y,x}(\mathbf{m}),\ \beta_{c,y,x}(\mathbf{m}) = \mathrm{CNN}(\mathbf{m}) \] \[ \mathrm{SPADE}(h,\mathbf{m})_{c,y,x} = \gamma_{c,y,x}(\mathbf{m})\,\frac{h_{c,y,x}-\mu_c}{\sigma_c} + \beta_{c,y,x}(\mathbf{m}) \]
[Figure: mask m [B,C,H,W] → Conv + ReLU (shared feature) → two convs giving γ(y,x) and β(y,x); activation h [B,C,H,W] → normalize without affine → γ(y,x) ⊙ ĥ + β(y,x), a spatially varying affine]
SPADE — a CNN turns the mask into per-pixel γ and β maps that modulate the normalized activation. Same idea as adaLN, but the parameters vary across space.
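import torch.nn as nn
import torch.nn.functional as F

class SPADE(nn.Module):
    def __init__(self, ch, cond_ch):
        super().__init__()
        # parameter-free norm: all affine modulation comes from the mask
        # (the paper's equation uses batch statistics; InstanceNorm is a common substitute)
        self.norm   = nn.InstanceNorm2d(ch, affine=False)
        self.shared = nn.Sequential(nn.Conv2d(cond_ch, 128, 3, padding=1), nn.ReLU())
        self.gamma  = nn.Conv2d(128, ch, 3, padding=1)   # per-pixel scale map
        self.beta   = nn.Conv2d(128, ch, 3, padding=1)   # per-pixel shift map

    def forward(self, x, mask):
        # x: [B, ch, H, W]; mask: one-hot float label map [B, cond_ch, H', W']
        m = F.interpolate(mask, size=x.shape[2:], mode='nearest')  # match x's resolution
        f = self.shared(m)
        return self.gamma(f) * self.norm(x) + self.beta(f)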
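For contrast, the same uniform masks pushed through a NumPy version of this forward pass stay distinguishable, because the mask never goes through the normalization; it only sets γ and β. Assumptions, for brevity: the mask CNN is reduced to 1×1 convolutions (the module above uses 3×3), the hidden width is 16 instead of 128, there is no batch dimension, and the weights are random.

```python
import numpy as np

def conv1x1(x, weight, bias):
    # Per-pixel linear map: [C_in, H, W] -> [C_out, H, W]
    return np.einsum("oc,chw->ohw", weight, x) + bias[:, None, None]

def instance_norm(h, eps=1e-5):
    mu = h.mean(axis=(1, 2), keepdims=True)
    return (h - mu) / np.sqrt(h.var(axis=(1, 2), keepdims=True) + eps)

def spade(h, mask, p):
    # gamma, beta are predicted per pixel from the mask; h is normalized without affine.
    f = np.maximum(conv1x1(mask, p["W_s"], p["b_s"]), 0.0)   # shared conv + ReLU
    gamma = conv1x1(f, p["W_g"], p["b_g"])
    beta = conv1x1(f, p["W_b"], p["b_b"])
    return gamma * instance_norm(h) + beta

rng = np.random.default_rng(0)
num_labels, ch, hidden, H, W = 3, 4, 16, 8, 8
p = {
    "W_s": rng.normal(size=(hidden, num_labels)), "b_s": rng.normal(size=hidden),
    "W_g": rng.normal(size=(ch, hidden)), "b_g": rng.normal(size=ch),
    "W_b": rng.normal(size=(ch, hidden)), "b_b": rng.normal(size=ch),
}
h = rng.normal(size=(ch, H, W))          # activation from the previous layer

outs = []
for k in range(num_labels):
    mask = np.zeros((num_labels, H, W))
    mask[k] = 1.0                         # uniform mask: label k everywhere
    outs.append(spade(h, mask, p))
```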
paper

Park, Liu, Wang, Zhu — Semantic Image Synthesis with Spatially-Adaptive Normalization, CVPR 2019 (Oral).

05

Synthesis — spatial adaLN-Zero

To condition a DiT on a binary mask while staying inside the adaLN family: replace the modulation MLP with a CNN that emits spatially-varying versions of all six parameters (including the \(\alpha\) gates), then modulate each patch token with the parameters at its own grid cell. This is SPADE's spatial idea carried into the transformer, with the gate retained.

honesty note

This combination isn't a published, named method I can cite. SPADE is the citation for the spatial-modulation idea; the DiT paper is the citation for the gate and zero-init. The fusion below is a sensible design, not an established result.

\[ \big[\gamma_1,\beta_1,\alpha_1,\gamma_2,\beta_2,\alpha_2\big] = \mathrm{CNN}(\mathbf{m}) \in \mathbb{R}^{B\times 6D\times \frac{H}{P}\times\frac{W}{P}} \;\xrightarrow{\text{flatten}}\; \mathbb{R}^{B\times T\times 6D}, \qquad T = \frac{HW}{P^2} \]
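The flatten step is a transpose and a reshape. A NumPy sketch with hypothetical sizes, assuming patch tokens are ordered row-major (cell \((i,j)\) becomes token \(i\cdot\frac{W}{P} + j\)), which is the order a Conv2d patch embedding followed by a row-major flatten produces:

```python
import numpy as np

B, D, P, H, W = 2, 8, 4, 16, 24
h, w = H // P, W // P                     # modulation grid = patch grid
T = h * w

rng = np.random.default_rng(0)
mod_map = rng.normal(size=(B, 6 * D, h, w))       # CNN(m): [B, 6D, H/P, W/P]

# [B, 6D, h, w] -> [B, h, w, 6D] -> [B, T, 6D], row-major over (i, j)
mod_tokens = mod_map.transpose(0, 2, 3, 1).reshape(B, T, 6 * D)

# Six per-token parameter sets, each [B, T, D], in the order of the equation
gamma1, beta1, alpha1, gamma2, beta2, alpha2 = np.split(mod_tokens, 6, axis=-1)
```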

1 · keep zero-init on the final conv

self.out_conv = nn.Conv2d(256, 6 * hidden_size, 1)
nn.init.zeros_(self.out_conv.weight)   # preserves the adaLN-Zero property
nn.init.zeros_(self.out_conv.bias)

2 · fuse timestep additively in feature space

spatial = self.mask_cnn(mask)                     # [B,256,H/P,W/P]
time    = self.time_proj(t_emb)[..., None, None] # [B,256,1,1]
fused   = spatial + time                          # broadcast -> global denoise signal
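Putting steps 1 and 2 together, a NumPy sketch of the whole spatial modulation path. Everything here is hypothetical and simplified: the mask CNN is a P×P average pool followed by a per-cell linear layer (standing in for self.mask_cnn), time_proj is a single linear map, a SiLU precedes the zero-initialized 1×1 output conv, and the feature width is 32 instead of 256. At initialization all six parameter maps are zero, so every gate \(\alpha\) is zero and each block starts as an identity, as in adaLN-Zero.

```python
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def avg_pool(m, P):
    # [B, C, H, W] -> [B, C, H/P, W/P]: one cell per image patch
    B, C, H, W = m.shape
    return m.reshape(B, C, H // P, P, W // P, P).mean(axis=(3, 5))

def conv1x1(x, weight, bias):
    # [B, C_in, h, w] -> [B, C_out, h, w]
    return np.einsum("oc,bchw->bohw", weight, x) + bias[None, :, None, None]

rng = np.random.default_rng(0)
B, H, W, P, D, F, E = 2, 16, 16, 4, 8, 32, 12   # F: feature width, E: timestep-embedding dim
mask = (rng.random((B, 1, H, W)) > 0.5).astype(float)   # binary mask
t_emb = rng.normal(size=(B, E))

W_mask, b_mask = rng.normal(size=(F, 1)), np.zeros(F)         # stand-in for mask_cnn
W_time, b_time = rng.normal(size=(E, F)), np.zeros(F)         # stand-in for time_proj
W_out, b_out = np.zeros((6 * D, F)), np.zeros(6 * D)          # step 1: zero-init output conv

spatial = conv1x1(avg_pool(mask, P), W_mask, b_mask)          # [B, F, H/P, W/P]
time = (t_emb @ W_time + b_time)[:, :, None, None]            # [B, F, 1, 1]
fused = spatial + time                                        # step 2: broadcast add
mods = conv1x1(silu(fused), W_out, b_out)                     # [B, 6D, H/P, W/P]
mod_tokens = mods.transpose(0, 2, 3, 1).reshape(B, -1, 6 * D) # [B, T, 6D]
```

The additive fusion gives every grid cell the same timestep offset, while the mask term varies from cell to cell.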

3 · spatial alignment is implicit

If the mask CNN downsamples by the same factor \(P\) as the patch embedding and both grids are flattened in the same (row-major) order, then CNN cell \((i,j)\) modulates image token \((i,j)\) exactly. No extra positional encoding is needed on the modulation maps.
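A NumPy check of the alignment claim, assuming (hypothetically) that the image is split into non-overlapping P×P patches in row-major order and the mask path reduces each P×P block to one cell (average pooling here), also flattened row-major. Token k and modulation cell k then cover the same pixels:

```python
import numpy as np

def patchify(img, P):
    # [C, H, W] -> [T, C*P*P], non-overlapping patches in row-major order
    C, H, W = img.shape
    x = img.reshape(C, H // P, P, W // P, P).transpose(1, 3, 0, 2, 4)
    return x.reshape((H // P) * (W // P), C * P * P)

def mask_cells(mask, P):
    # [H, W] -> [T]: one averaged value per P x P block, same row-major order
    H, W = mask.shape
    return mask.reshape(H // P, P, W // P, P).mean(axis=(1, 3)).reshape(-1)

rng = np.random.default_rng(0)
P, H, W = 4, 12, 16
mask = (rng.random((H, W)) > 0.5).astype(float)
img = mask[None]                       # 1-channel image equal to the mask, for checking

tokens = patchify(img, P)              # [T, P*P]
cells = mask_cells(mask, P)            # [T]
```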

scope reminder

What's spatial here is the modulation, not the attention. Self-attention still mixes tokens globally. To make the mask gate attention itself, you'd need attention bias or cross-attention — a different mechanism.

Choosing an approach

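| Approach | Spatial | Compute (relative) | Best for |
|---|---|---|---|
| adaLN-Zero | no | minimal | class / timestep |
| SPADE | yes | low | conv decoder + mask |
| Spatial adaLN-Zero | yes | medium | DiT + binary/semantic mask |
| Cross-attention to mask | yes | high | layout drives content |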