I.
Va por delante que esta entrada está basada en
esto.
Se trata, de hecho, de las notas que he extraído mientras profundizaba en la implementación que hace NumPyro de la inferencia variacional, el ELBO, etc.
Antes de nada, nos quitamos los requisitios de en medio:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from jax import random
from numpyro.infer import SVI, Predictive, Trace_ELBO, MCMC, NUTS
rng_key = random.PRNGKey(seed=42)
Definimos un consabidísimo modelo —tiradas de moneda con una priori $\text{Beta}(10,10)$— y unos datos —sesenta caras y cuarenta cruces—: