## At a Glance
- **Definition:** A dropout variant that learns the per-layer dropout probability $p$ by gradient descent rather than grid-search, using a continuous relaxation of the discrete Bernoulli mask to make $p$ differentiable.
- **Formula:** $\mathcal{L}(\theta) = -\tfrac{1}{M}\sum_{i \in S}\log p(y_i \mid f^{\omega}(x_i)) + \tfrac{1}{N}\,\mathrm{KL}\big(q_\theta(\omega)\,\|\,p(\omega)\big)$
- **Range / output:** Returns a trained scalar $p \in (0, 1)$ per layer; at inference, repeated forward passes through the Concrete-masked network produce a sample ensemble whose variance is the epistemic uncertainty estimate.
- **Also known as:** Learnable dropout; gradient-tuned Bayesian dropout. The relaxation used is the Concrete distribution, also called Gumbel-Softmax in the generative-modelling literature.[^2][^3]
## Why It Exists
The obvious way to set a dropout probability is grid search: try $p \in \{0.1, 0.2, \ldots, 0.5\}$ and pick the value that maximises validation log-likelihood. This breaks in two ways. First, large vision models (10 M+ parameters) take days to train per run, so a grid search over per-layer configurations, whose size is exponential in the number of layers, is computationally prohibitive.[^1] Second, in reinforcement learning the dataset grows with every episode, so the optimal $p$ should shrink continuously as data accumulates; grid search would require retraining from scratch with each new batch of experience, which is infeasible in an online setting.[^1] Concrete Dropout fixes both problems by treating $p$ as a variational parameter optimised end-to-end alongside the weights. The only obstacle is that the Bernoulli mask is discrete and has no gradient with respect to $p$, so the paper replaces it with the Concrete distribution: a continuous relaxation that concentrates near 0 and 1 but is differentiable everywhere, which lets the pathwise derivative estimator carry gradients to $p$.[^1]
## Demo!
<div class="ml-widget" data-algo="concrete-dropout"></div>
**Tips:**
- The curve is $\mathcal{L}(p)$: the objective above viewed as a function of $p$ alone, with $N$ and $K$ held fixed by the sliders. The minimum is exactly where $\tfrac{d\mathcal{L}}{dp} = 0$, the point at which the likelihood pressure pushing $p$ toward 0 is balanced by the regularisation terms (the entropy term $-KH(p)$ pulls toward 0.5, the weight term toward 1). Watch the minimum slide leftward as you increase $N$: the $\tfrac{1}{N}$ factor in front of the KL shrinks its influence, letting the data term win at a lower $p$.
- Set $N$ small and drag $K$ to its maximum, then drop the ball at $p = 0.05$. It rolls right, past what feels like the obvious low-dropout answer, and settles near 0.5. That is the entropy term $-KH(p)$, which grows with $K$, dominating: a large model with little data genuinely should mask aggressively, and the landscape says so. The three faint dashed curves show each additive force separately so you can see which term is moving the minimum.
## Formalization
The loss curve you just manipulated comes from dropout's variational interpretation.[^1] Treating the network's stochastic weight matrices $\omega = \{W_l\}_{l=1}^L$ as latent variables, dropout inference approximates the posterior $p(\omega \mid \mathcal{D})$ with a variational distribution $q_\theta(\omega)$ parameterised by mean weight matrices $M_l$ and dropout probabilities $p_l$. Minimising the KL divergence between the approximate and true posterior is equivalent to maximising the evidence lower bound (ELBO). Negating the ELBO, dividing by $N$, and estimating the likelihood part on a mini-batch $S$ of size $M$ with one sampled $\omega$ gives the objective:
$\mathcal{L}_{\mathrm{MC}}(\theta) = -\frac{1}{M}\sum_{i \in S}\log p(y_i \mid f^{\omega}(x_i)) + \frac{1}{N}\,\mathrm{KL}\big(q_\theta(\omega)\,\|\,p(\omega)\big)$
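A minimal sketch of how the two terms combine on one mini-batch, assuming the per-example log-likelihoods and the layer-summed KL have already been computed; the function name `mc_objective` is hypothetical. Note that the likelihood term averages over the $M$ examples in the batch, while the KL is divided by the full dataset size $N$. (The `ConcreteDropout` wrapper under Implementation Details folds the $1/N$ into its `weight_regularizer` and `dropout_regularizer` constants instead.)

```python
import torch

def mc_objective(log_liks: torch.Tensor, kl: torch.Tensor, n_data: int) -> torch.Tensor:
    """Mini-batch estimate of L_MC(theta).

    log_liks: shape [M], log p(y_i | f^omega(x_i)) under one sampled omega each.
    kl:       scalar KL(q_theta(omega) || p(omega)), summed over layers.
    n_data:   N, the size of the full training set (not the mini-batch size M).
    """
    nll = -log_liks.mean()      # -(1/M) * sum over the mini-batch S
    return nll + kl / n_data    # + (1/N) * KL

loss = mc_objective(torch.tensor([-1.0, -2.0, -3.0]), torch.tensor(50.0), n_data=100)
# 2.0 + 50 / 100 = 2.5
```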
The KL term decomposes per layer and, using the discretised Gaussian prior from Gal (2016),[^4] approximates to:
$\mathrm{KL}\big(q_M(W)\,\|\,p(W)\big) \propto \frac{l^2(1-p)}{2}\|M\|^2 - K\,H(p)$
where the Bernoulli entropy is:
$H(p) := -p\log p - (1-p)\log(1-p)$
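Differentiating the entropy term shows where its pull changes direction (a short worked step, using $\tfrac{d}{dp}[p\log p] = \log p + 1$):
$\frac{dH}{dp} = -(\log p + 1) + \log(1-p) + 1 = \log\frac{1-p}{p}, \qquad \frac{d}{dp}\big[-K\,H(p)\big] = -K\log\frac{1-p}{p}$
The gradient of $-KH(p)$ is negative for $p < 0.5$ (descent increases $p$), positive for $p > 0.5$ (descent decreases $p$), and zero at $p = 0.5$; its magnitude grows linearly in $K$.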
To obtain $\tfrac{\partial \mathcal{L}}{\partial p}$, the paper replaces the discrete Bernoulli mask with the Concrete distribution.[^2][^3] For the binary case this reduces to a sigmoid-transformed uniform noise:
$\tilde{z} = \sigma\!\left(\frac{1}{t}\!\left[\log p - \log(1-p) + \log u - \log(1-u)\right]\right), \quad u \sim \mathrm{Unif}(0,1)$
This $\tilde{z}$ is a relaxed *drop* indicator: it lies in $(0, 1)$, concentrates near 0 and 1 for small temperature $t$ (landing near 1 with probability approaching $p$ as $t \to 0$), and is differentiable with respect to $p$ everywhere, so standard backpropagation can update $p$ alongside the weights.
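The relaxation can be checked directly in a few lines of PyTorch. This is a minimal sketch, not the paper's code; `concrete_drop_mask` is a hypothetical helper name. It samples $\tilde{z}$ for $p = 0.2$, confirms that roughly a fraction $p$ of samples fall on the "drop" side, and shows that a gradient reaches the logit of $p$. Since $\tilde{z} > 0.5$ exactly when $u > 1 - p$, that fraction equals $p$ for any temperature.

```python
import torch

def concrete_drop_mask(p_logit: torch.Tensor, shape: tuple, temperature: float = 0.1,
                       eps: float = 1e-7) -> torch.Tensor:
    """Relaxed drop indicator z~ in (0, 1) from the binary Concrete formula."""
    u = torch.rand(shape).clamp(eps, 1.0 - eps)           # u ~ Unif(0, 1)
    logistic_noise = torch.log(u) - torch.log(1.0 - u)     # log u - log(1 - u)
    # p_logit = log p - log(1 - p), so the argument matches the formula above.
    return torch.sigmoid((p_logit + logistic_noise) / temperature)

torch.manual_seed(0)
p_logit = torch.logit(torch.tensor(0.2)).requires_grad_()   # p = 0.2
z = concrete_drop_mask(p_logit, (100_000,))

drop_fraction = (z > 0.5).float().mean().item()   # P(z~ > 0.5) equals p for any t
z.mean().backward()                               # pathwise gradient reaches p_logit
print(drop_fraction, p_logit.grad.item())
```

Parameterising through the logit rather than $p$ itself is a design choice that keeps $p$ strictly inside $(0, 1)$ during optimisation.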
**Variables:**
- **$\theta = \{M_l, p_l\}_{l=1}^L$** `[scalar p per layer]`: variational parameters, namely the mean weight matrices and the dropout probabilities being learned.
- **$f^{\omega}(x_i)$** `[output_dim]`: network output on input $x_i$ under one sampled weight realisation $\omega$ (one Concrete-masked forward pass).
- **$N$** `[scalar]`: total dataset size. Scales the KL term so that $p \to 0$ as $N \to \infty$ — more data reduces epistemic uncertainty.
- **$M$** `[scalar]`: mini-batch size used for the Monte Carlo estimate of the likelihood term.
- **$l$** `[scalar]`: prior length-scale. Sets the weight-regularisation strength: `weight_regularizer` $= l^2/(\tau N)$, where $\tau$ is the model precision (inverse observation noise).
- **$K$** `[scalar]`: number of input units dropped (the input dimensionality of the layer). Scales the entropy term, so wider layers are pushed more strongly toward $p = 0.5$.
- **$H(p)$** `[scalar]`: entropy of a $\mathrm{Bernoulli}(p)$ random variable; maximised at $p = 0.5$, which creates the pull toward the centre of the loss landscape.
- **$t$** `[scalar]`: Concrete temperature. Lower $t$ makes $\tilde{z}$ more binary (closer to standard dropout); $t$ can itself be treated as a variational parameter.
- **$\tilde{z}$** `[input_dim]`: the Concrete-relaxed *drop* indicator, valued in $(0, 1)$. The layer's input is multiplied by $1 - \tilde{z}$ and then rescaled by $1/(1-p)$ to preserve its expected magnitude.
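The two regulariser constants depend only on the length-scale, the precision and the dataset size. A small helper (the name `regularizer_constants` is hypothetical, and the numbers below are illustrative, not taken from the paper) makes the dependence on the full dataset size $N$ explicit, using `weight_regularizer` $= l^2/(\tau N)$ and `dropout_regularizer` $= 2/(\tau N)$:

```python
def regularizer_constants(length_scale: float, tau: float, n_data: int) -> tuple:
    """Return (weight_regularizer, dropout_regularizer) = (l**2 / (tau*N), 2 / (tau*N)).

    n_data is the FULL training-set size N, not the mini-batch size M.
    """
    weight_regularizer = length_scale ** 2 / (tau * n_data)
    dropout_regularizer = 2.0 / (tau * n_data)
    return weight_regularizer, dropout_regularizer

# Illustrative values only: l = 1e-2, tau = 1, N = 10,000.
wr, dr = regularizer_constants(length_scale=1e-2, tau=1.0, n_data=10_000)
print(wr, dr)
```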
## Building the Intuition (Deep Dive)
### The tug-of-war inside $\mathcal{L}(p)$
The widget makes the loss landscape feel like a bowl with a movable minimum. To see why it has that shape, break the objective into its three additive forces.
The **weight regularisation term** (the $\|M\|^2$ term of the KL) is minimised at $p = 1$: mask everything and the expected weight magnitude goes to zero. It pulls the minimum rightward. The **entropy term** $-KH(p)$ is minimised (most negative) at $p = 0.5$, where entropy is highest; it pulls rightward from small $p$ and leftward from large $p$, so it centres. The **likelihood term** (the first term of $\mathcal{L}_{\mathrm{MC}}$) penalises any masking that degrades predictions; it pulls leftward toward $p = 0$. Because both regularisation forces carry the $\tfrac{1}{N}$ factor while the likelihood term does not, the balance point moves left as $N$ grows (data overwhelms the KL) and toward 0.5 as $K$ grows (the entropy term, scaled by $K$, strengthens).[^1]
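The tug-of-war can be reproduced with a toy one-dimensional objective. This is a sketch under stated assumptions, not the paper's loss: the likelihood term is replaced by a hypothetical linear proxy $a\,p$ (more masking costs more fit), the weight term by $c\,(1-p)$, and both regularisation pieces carry the $\tfrac{1}{N}$ factor. Setting the derivative of $a\,p + \tfrac{1}{N}\big[c(1-p) - K\,H(p)\big]$ to zero gives $\log\tfrac{1-p}{p} = (aN - c)/K$, i.e. the closed-form minimiser $p^* = \sigma\big((c - aN)/K\big)$, which a grid search recovers:

```python
import numpy as np

def toy_objective(p, n_data, k, a=1.0, c=1.0):
    """Toy L(p) = a*p + (1/N) * (c*(1-p) - K*H(p)); a and c are assumed constants."""
    entropy = -p * np.log(p) - (1 - p) * np.log(1 - p)
    likelihood_proxy = a * p          # pulls p toward 0
    weight_term = c * (1 - p)         # pulls p toward 1
    entropy_term = -k * entropy       # pulls p toward 0.5
    return likelihood_proxy + (weight_term + entropy_term) / n_data

def grid_argmin(n_data, k, a=1.0, c=1.0):
    ps = np.linspace(0.001, 0.999, 9_999)
    return ps[np.argmin(toy_objective(ps, n_data, k, a, c))]

def closed_form(n_data, k, a=1.0, c=1.0):
    return 1.0 / (1.0 + np.exp((a * n_data - c) / k))   # sigmoid((c - aN) / K)

for n_data, k in [(200, 500), (2_000, 500), (200, 5_000)]:
    print(n_data, k, round(grid_argmin(n_data, k), 3), round(closed_form(n_data, k), 3))
```

With $a = c = 1$, the minimiser moves from about 0.40 at $(N, K) = (200, 500)$ to about 0.02 at $N = 2{,}000$, and to about 0.49 at $K = 5{,}000$: more data pulls $p$ toward 0, a wider layer pulls it toward 0.5.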
### Why the Bernoulli gradient is blocked — and how Concrete fixes it
Standard gradient descent on $p$ fails because sampling from $\mathrm{Bernoulli}(p)$ is a non-differentiable step function. The score-function (REINFORCE) estimator does produce an unbiased gradient, but its variance is so high in practice that optimisation stalls.[^1] The Concrete distribution sidesteps this by relaxing the hard 0/1 boundary: instead of $z \sim \mathrm{Bernoulli}(p)$, sample $\tilde{z}$ from the sigmoid-of-logistic-noise reparameterisation above. Because $\tilde{z}$ is a deterministic, differentiable function of the noise $u$ and the parameter $p$, the pathwise derivative $\partial \tilde{z}/\partial p$ exists everywhere and can be backpropagated directly.[^2][^3] At temperature $t \to 0$, $\tilde{z}$ converges in distribution to the Bernoulli — so a small fixed temperature (the paper uses $t = 0.1$) gives the best of both worlds: near-binary behaviour at inference and differentiable gradients at training.[^1]
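The convergence toward near-binary samples as $t \to 0$ can be checked numerically. A sketch in NumPy (the helper name `relaxed_samples` is hypothetical) that fixes $p = 0.3$ and measures how many relaxed samples land within 0.05 of 0 or 1 as the temperature shrinks:

```python
import numpy as np

def relaxed_samples(p: float, t: float, n: int, rng: np.random.Generator) -> np.ndarray:
    """Draw n binary-Concrete samples z~ = sigmoid((logit p + logit u) / t)."""
    u = rng.uniform(1e-7, 1.0 - 1e-7, size=n)
    logits = np.log(p) - np.log(1.0 - p) + np.log(u) - np.log(1.0 - u)
    # sigmoid(x) = 0.5 * (1 + tanh(x / 2)) avoids overflow in exp for tiny t.
    return 0.5 * (1.0 + np.tanh(logits / (2.0 * t)))

rng = np.random.default_rng(0)
fractions = {}
for t in [1.0, 0.5, 0.1, 0.01]:
    z = relaxed_samples(0.3, t, 100_000, rng)
    fractions[t] = float(np.mean((z < 0.05) | (z > 0.95)))
    print(f"t = {t:<4}  fraction within 0.05 of 0 or 1: {fractions[t]:.3f}")
```

The fraction climbs from about 0.13 at $t = 1$ to about 0.88 at $t = 0.1$, and exceeds 0.95 at $t = 0.01$.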
### The input-layer pattern
In the paper's experiments, the input layer's dropout probability converges to nearly zero while deeper layers retain moderate $p$.[^1] The objective suggests one reason: the entropy term $-KH(p)$ scales with $K$, the layer's input width. When the input layer is narrow relative to the hidden layers (as with low-dimensional tabular inputs), its entropy term is the weakest, so the likelihood pressure dominates and drives $p \to 0$; wider layers have stronger entropy terms and settle at higher $p$. This is consistent with the practitioner rule of using small dropout early and larger dropout deeper in the network.
### Epistemic vs. aleatoric uncertainty
Concrete Dropout is a tool for **epistemic uncertainty** — uncertainty that shrinks as data accumulates. The dropout probability $p$ directly controls the variance of the function ensemble: larger $p$ produces more varied forward passes and wider predictive intervals. At $N \to \infty$ the optimised $p \to 0$ and the epistemic uncertainty collapses to zero, leaving only **aleatoric uncertainty** (irreducible noise), which is modelled separately via the likelihood precision $\tau$.[^1] These two sources are cleanly separable in the variational framework, and both are estimable by Concrete Dropout — epistemic from the mask-induced variance, aleatoric from the learned $\tau$.
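A minimal sketch of MC prediction with this decomposition, assuming a regression model with a homoscedastic Gaussian likelihood of known precision `tau`. `mc_predict` and `toy_forward` are hypothetical names; `toy_forward` uses a plain Bernoulli input mask purely as a stand-in for any network whose masks are resampled on every call (the `ConcreteDropout` wrapper under Implementation Details also resamples on every call).

```python
import torch

def mc_predict(stochastic_forward, x, n_samples=50, tau=1.0):
    """MC prediction: returns (mean, epistemic variance, total variance).

    stochastic_forward: any callable that resamples its dropout masks on every call.
    tau: assumed homoscedastic precision, so the aleatoric variance is 1 / tau.
    """
    with torch.no_grad():
        samples = torch.stack([stochastic_forward(x) for _ in range(n_samples)])
    mean = samples.mean(dim=0)
    epistemic = ((samples - mean) ** 2).mean(dim=0)   # spread across sampled masks
    aleatoric = torch.full_like(mean, 1.0 / tau)       # irreducible observation noise
    return mean, epistemic, epistemic + aleatoric

# Toy stand-in network (hypothetical): one linear map with a fresh input mask per call.
torch.manual_seed(0)
w = torch.randn(8, 1)

def toy_forward(x, p=0.2):
    keep = (torch.rand_like(x) > p).float()
    return (x * keep / (1.0 - p)) @ w

x = torch.randn(4, 8)
mean, epistemic, total = mc_predict(toy_forward, x, n_samples=200, tau=4.0)
```

Here the epistemic part comes only from mask-to-mask disagreement, while the aleatoric part is the constant $1/\tau = 0.25$ added on top.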
## Failure Modes & Gotchas
- **$p$ stuck near 0.5 with small data:** If the dataset is small relative to model size, the entropy term dominates and the optimiser parks $p$ near 0.5. This is mathematically correct — the model genuinely has high epistemic uncertainty — but can feel like the network is refusing to train. The fix is not to lower $p$ by hand; it is to collect more data or reduce model capacity.
- **Gradient instability near $p = 0$ or $p = 1$:** The Concrete reparameterisation involves $\log p$ and $\log(1-p)$, both of which blow up at the boundaries. In practice, parameterise $p$ as the sigmoid of an unconstrained logit (as in the paper's Keras implementation and the PyTorch wrapper below) so it can never exactly reach 0 or 1, and clamp the logit range or add a small $\varepsilon$ if numerical issues appear.
- **Miscalibrated `weight_regularizer` / `dropout_regularizer`:** The two regulariser hyperparameters must satisfy `weight_regularizer` $= l^2/(\tau N)$ and `dropout_regularizer` $= 2/(\tau N)$, where $\tau$ is the model precision (inverse observation noise) and $N$ is the dataset size. Getting $N$ wrong (for example, using the mini-batch size instead of the full dataset size) gives the wrong $p$ at convergence and breaks uncertainty calibration.[^1] This is an easy mistake to make.
- **Pixel-wise losses without adjusting the regulariser:** The paper's regulariser derivation assumes the loss is summed over data points and the KL scales by $1/N$. For pixel-wise losses (e.g. semantic segmentation), the effective $N$ is $N_{\text{images}} \times H \times W$. The paper's computer-vision experiments use `dropout_regularizer` $= 0.01 \times N \times H \times W$ for exactly this reason. Forgetting the spatial scaling drives $p$ to near zero for all layers.[^1]
- **Temperature too high:** At large $t$ the Concrete samples are pulled toward intermediate values (toward 0.5 as $t \to \infty$) instead of concentrating near 0 and 1, so the network experiences smooth attenuation rather than dropout-like masking, and uncertainty estimates become poorly calibrated. Keep $t$ at or below 0.1 in practice; the paper uses $t = 0.1$ throughout.[^1]
- **Evaluating without MC sampling:** Concrete Dropout's uncertainty estimates require MC sampling at inference (multiple stochastic forward passes). A single deterministic forward pass (standard dropout inference mode) yields no uncertainty estimate at all, and should not be expected to reproduce the accuracy the paper reports with MC sampling.
## Implementation Details & Code
- **Framework Equivalents:** PyTorch: wrap each layer with the `ConcreteDropout` wrapper shown below and add the sum of every wrapper's `get_kl_loss()` to your training loss. Keras: the paper's Appendix C provides a ~20-line `Wrapper` subclass.[^1] No built-in equivalent exists in `torch.nn` or `torch.nn.functional`.
- **Complexity:** Training adds one scalar parameter $p_\text{logit}$ per wrapped layer (negligible). The KL computation is $O(d)$ in the layer's weight count. Inference cost for $T$ MC samples is $T \times$ single-forward-pass cost — typically $T = 10$–$50$.
- **Practical Heuristics:** Initialise $p$ around 0.1–0.2 (the paper draws $\operatorname{logit}(p)$ uniformly from $[-2, 0]$, corresponding to $p$ between roughly 0.12 and 0.5). Set `weight_regularizer` $= l^2/(\tau N)$ and `dropout_regularizer` $= 2/(\tau N)$; the paper finds $l = 10^{-2}$ works well across UCI benchmarks. For pixel-wise losses multiply `dropout_regularizer` by $H \times W$. Temperature $t = 0.1$ is the recommended default. Concrete Dropout is tolerant to the initial value of $p$: the paper shows convergence to similar optima from any initialisation in $[0.05, 0.5]$.[^1]
```python
import math

import torch
import torch.nn as nn


class ConcreteDropout(nn.Module):
    """
    Concrete Dropout wrapper (Gal, Hron & Kendall 2017).

    Wraps any nn.Module layer and learns the dropout probability p
    by gradient descent via the Concrete relaxation.

    Usage:
        layer = ConcreteDropout(nn.Linear(in_dim, out_dim), input_dim=in_dim)
        # In training loop:
        out = layer(x)
        kl = layer.get_kl_loss()
        loss = nll_loss(out, y) + kl
    """

    def __init__(
        self,
        layer: nn.Module,
        input_dim: int,
        weight_regularizer: float = 1e-6,   # = l² / (τN)
        dropout_regularizer: float = 1e-5,  # = 2 / (τN)
        init_p: float = 0.1,
        temperature: float = 0.1,
    ):
        super().__init__()
        self.layer = layer
        self.input_dim = input_dim
        self.weight_regularizer = weight_regularizer
        self.dropout_regularizer = dropout_regularizer
        self.temperature = temperature
        # p is parameterised as sigmoid(p_logit) so it never hits 0 or 1.
        self.p_logit = nn.Parameter(torch.tensor(math.log(init_p / (1.0 - init_p))))

    @property
    def p(self) -> torch.Tensor:
        return torch.sigmoid(self.p_logit)

    def get_kl_loss(self) -> torch.Tensor:
        """KL regularisation term; add this to your NLL loss at each step."""
        p = self.p
        eps = 1e-7
        # Weight term: (l²(1−p)/2)‖M‖². The layer sees inputs rescaled by 1/(1−p),
        # so its stored weight is W = (1−p)M and the term becomes ∝ ‖W‖² / (1−p).
        w_sum = sum(
            param.pow(2).sum()
            for name, param in self.layer.named_parameters()
            if "weight" in name
        )
        weight_reg = self.weight_regularizer * w_sum / (1.0 - p + eps)
        # Entropy term: −(2/τN) · K · H(p). Minimising −H(p) pushes p toward 0.5.
        H = -(p * torch.log(p + eps) + (1 - p) * torch.log(1 - p + eps))
        dropout_reg = -self.dropout_regularizer * self.input_dim * H
        return weight_reg + dropout_reg

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        p = self.p
        eps = 1e-7
        # Concrete relaxation of the drop indicator:
        # z = sigmoid((log p − log(1−p) + log u − log(1−u)) / t), with p_logit = log p − log(1−p).
        u = torch.rand_like(x).clamp(eps, 1.0 - eps)
        z = torch.sigmoid(
            (self.p_logit + torch.log(u) - torch.log(1.0 - u)) / self.temperature
        )
        # Keep inputs where z ≈ 0, then rescale to preserve the expected magnitude.
        # Masks are resampled on every call, in train and eval mode alike (MC dropout).
        return self.layer(x * (1.0 - z) / (1.0 - p + eps))
```
## References
[^1]: Gal, Y., Hron, J. and Kendall, A. (2017) 'Concrete Dropout', _Advances in Neural Information Processing Systems_, 30. Available at: https://arxiv.org/abs/1705.07832
[^2]: Maddison, C.J., Mnih, A. and Teh, Y.W. (2017) 'The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables', _International Conference on Learning Representations_. Available at: https://arxiv.org/abs/1611.00712
[^3]: Jang, E., Gu, S. and Poole, B. (2017) 'Categorical Reparameterization with Gumbel-Softmax', _International Conference on Learning Representations_. Available at: https://arxiv.org/abs/1611.01144
[^4]: Gal, Y. (2016) 'Uncertainty in Deep Learning', PhD thesis, University of Cambridge. Available at: http://mlg.eng.cam.ac.uk/yarin/thesis/thesis.pdf
<!-- ========================================================================== WIDGET TECHNICAL SPECIFICATION (Phase 3) Hidden from the published note. ========================================================================== ### SPECIFICATION: concrete-dropout 1. STATE SCHEMA: - N: number // dataset size; default 200; range [10,10000] log-slider - K: number // model capacity proxy; default 500; range [50,5000] - ballP: number // ball position as raw p ∈ (P_MIN,P_MAX); default 0.5 - isDragging: bool // true while pointer held on ball - isDescending: bool // true while ball rolls to minimum after release - stepAccumMs: number // accumulated ms for descent pacing Constants: P_MIN=0.01, P_MAX=0.99, CURVE_PTS=200, ETA=0.003, W_SCALE=0.8, BALL_R=14, HIT_R=22, STEP_MS=42 2. MATH ENGINE ADDITIONS (publish.js Layer 2): - MLMath.cdObjective(p, cLik, cW, cE): L(p) = cLik·p + cW·(1−p)² − cE·H(p), H(p) = −p·ln(p) − (1−p)·ln(1−p) - MLMath.cdGradient(p, cLik, cW, cE): dL/dp = cLik − 2·cW·(1−p) − cE·ln((1−p)/p) - MLMath.argmin: REUSE existing body. Coefficients: cLik=1/N, cW=W_SCALE, cE=K/N 3. INTERACTION MAPPING: - pointerdown: hit-test ball (radius HIT_R*scale); hit→isDragging=true, isDescending=false; miss→place ballP at pointer p, isDescending=false - pointermove: if isDragging→update ballP from eventToModel x - pointerup/cancel: isDragging=false, isDescending=true - Reset button: ballP=0.5, isDragging=false, isDescending=false, stepAccumMs=0 All coords via eventToModel(e)→mx∈[-1,1]→p=(mx+1)/2, clamped to (P_MIN,P_MAX) 4. RENDER LOOP: - buildControls(): N log-slider (10^v), K linear slider, Reset button - init(): defaults; 4 pointer listeners via this.addListener - onFrame(dt): if isDescending and !isDragging: accumulate dt, step by ETA*cdGradient, clamp; stop when |grad|<1e-5 - draw(ctx,W,H): 1. Sample Lvals[200], compute Lmin/Lmax 2. Axes + tick labels 3. Three faint dashed component terms in same y-frame as L(p) 4. Main L(p) curve (accent colour) 5. 
p* dashed vertical + label (argmin of Lvals) 6. Ball at curve y, fill=accent if descending else textCol 7. setStatus [p, L(p), p*, state string] =========================================================================== -->