SVI, ELBO y todas esas cosas: un ejemplo básico
I.
Va por delante que esta entrada está basada en esto. Se trata, de hecho, de las notas que he extraído mientras profundizaba en la implementación que hace NumPyro de la inferencia variacional, el ELBO, etc.
Antes de nada, nos quitamos los requisitios de en medio:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from jax import random
from numpyro.infer import SVI, Predictive, Trace_ELBO, MCMC, NUTS
rng_key = random.PRNGKey(seed=42)
Definimos un consabidísimo modelo —tiradas de moneda con una priori $\text{Beta}(10,10)$— y unos datos —sesenta caras y cuarenta cruces—:
def model(data):
f = numpyro.sample("z", dist.Beta(10, 10))
with numpyro.plate("N", data.shape[0]):
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
data = jnp.concatenate([jnp.ones(60), jnp.zeros(40)])
II. MCMC
Podemos ajustar el modelo en Numpyro usando MCMC/NUTS:
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=1000)
mcmc.run(rng_key, data)
Podemos examinar la distribución a posteriori de $z$ de diversas maneras. Por ejemplo, haciendo mcmc.print_summary()
se obtiene
mean std median 5.0% 95.0% n_eff r_hat
z 0.58 0.05 0.58 0.50 0.66 317.54 1.00
Number of divergences: 0
o hacer
samples = mcmc.get_samples()
fig, ax = plt.subplots()
ax.hist(samples['z'], bins = 50)
ax.set(
title="Posterior distribution of z",
);
para obtener
Todos sabemos que debería tratarse de una muestra de una $\text{Beta}(70, 50)$, aunque los parámetros de la beta resultante, estimados por el método de los momentos, son más bien 61.8 y 44.9.
III. SVI
Sin embargo, en problemas más complicados:
- No está clara la naturaleza de la distribución a posteriori.
- El método anterior puede resultar excesivamente lento.
Los métodos variacionales (véase el enlace al principio de la entrada) pueden resultar útiles en dichos casos. Sus ventajas y sus inconvenientes son consecuencia de su característica distintiva: comienzan postulando la solucíón.
Como aquel borracho que había perdido las llaves de su casa y las buscaba debajo de la farola, que es donde más y mejor se ve, los métodos variacionales no buscan la distribución exacta sino la más parecida a ella dentro de una colección de distribuciones parametrizadas. En concreto, estos métodos, para aproximar la distribución a posteriori $p(z)$ de $z$, ponen encima de la mesa una familia de distribuciones $q_\theta(z)$ parametrizadas por $\theta$ y tratan de determinar el valor adecuado de $\theta$ que minimiza la distancia entre $p(z)$ y $q_\theta(z)$.
Para el problema de hoy, la familia de distribuciones en cuestión se puede considerar igual a la de las distribuciones beta. Así, además, contienen verdaderamente la distribución a posteriori real, cosa que por supuesto, no cabe esperar en casi cualquier problema realmente real:
def guide(data):
alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
beta_q = numpyro.param("beta_q", 15., constraint=constraints.positive)
numpyro.sample("z", dist.Beta(alpha_q, beta_q))
La definición anterior merece algunos comentarios:
guide
es el nombre como en NumPyro y algunos otros lenguajes probabilísticos se conoce a la familia parametrizada de distribuciones candidatas.- Por algún motivo extraño, la función guía exige los mismos argumentos —
data
en este caso— que la del modelo, aunque ni los usa ni, conceptualmente, tiene nada que ver con ellos. - Los parámetros subyacentes
alpha_q
ybeta_q
tienen asignados valores por defecto de 15, pero no es estrictamente (o conceptualmente) necesario. - Estos valores, además, no tienen nada que ver con prioris de ningún tipo.
- Es posible obtener muestras de las distribuciones de esta familia haciendo cosas como
my_params = {'alpha_q': 30., 'beta_q': 45.}
preds = Predictive(model=guide, params = my_params, num_samples = 2000, return_sites=["z"])
rng_key, rng_subkey = random.split(rng_key)
guide_samples = preds(rng_subkey, data)
guide_samples
es en ese caso un diccionario y guide_samples['z']
contiene 2000 muestras de una distribución $\text{Beta}(30, 45)$.
Ahora, es posible encontrar la mejor guía (la más parecida a la distribución a posteriori) en NumPyro haciendo
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 2000, data)
y obtener los parámetros de esa guía óptima con svi_result.params
:
{'alpha_q': Array(18.633896, dtype=float32),
'beta_q': Array(13.33334, dtype=float32)}
(Que quedan un poco lejos, me temo, de los teóricos, que deberían ser, como se ha indicado arriba, 70 y 50.)
Se puede muestrear la posteriori como antes, pero reemplazando my_params
por svi_result.params
o haciendo
preds = Predictive(model=model, guide = guide, params=svi_result.params, num_samples = 2000, return_sites=["z"])
rng_key, rng_subkey = random.split(rng_key)
posterior_samples = preds(rng_subkey, data)
fig, ax = plt.subplots()
ax.hist(posterior_samples['z'], bins = 50)
ax.set(
title="Approximate Posterior Distribution of z",
);
para obtener, esencialmente, lo mismo:
La distribución es un poco más ancha de lo que cabría esperar. Se puede comparar, además, con la obtenida más arriba:
IV. Comentarios finales
Es aventurado criticar una técnica antes de haberla usado durante un buen tiempo en serio. Pero el resultado de este pequeño ejemplo ilustrativo de hoy es representativo de un fenómeno que he visto por ahí en más de una ocasión:
- Por motivos similares al mío, hay gente que utiliza el método ELBO y SVI en problemas sencillos con soluciones conocidas.
- Las soluciones obtenidas quedan sustancialmente lejos de los esperados.
- Se justifica la diferencia diciendo que, bueno, qué otra cosa puede esperarse de un método aproximado, etc.
- Aacto seguido, la técnica se aplica en problemas complejos en los que no contamos con soluciones contra las que validar los resultados.
¿Puro efecto Gell-Mann?