Números aleatorios, estado interno y su relación con el paralelismo

I.

En primer lugar, no voy a hablar de números aleatorios sino seudoaleatorios. Resumiéndolo todo mucho, un generador de números seudoaleatorios (PRNG en lo que sigue) es una función que a partir de una secuencia fácilmente adivinable (p.e., 0, 1, 2,…) genera otra de números con apariencia aleatoria.

Los números de la secuencia adivinable constituirían los distintos estados del PRNG. En R, Python y otros lenguajes populares, el generador de números aleatorios hace dos cosas: generar un número aleatorio y actualizar el estado.

[Estos generadores de números aleatorios no son pues lo que algunos llaman funciones puras en tanto que tienen efectos secundarios.]

II.

En R, el estado original se determina especificando la semilla:

set.seed(1)

El verdadero estado del PRNG por defecto de R puede consultarse así

.Random.seed

y es un vector largo (longitud 626) de enteros.

Cuando se hace

runif(1)

suceden dos cosas, una explícita y otra implícita. La explícita es que se obtiene el número 0.2655087. La implícita es que se modifica el estado.

[Nota técnica: el vector .Random.seed solo cambia cada 600 llamadas y pico a runif. Supongo que este funcionamiento se debe a algún tipo de optimización interna y que el estado verdadero incluye el estado más algún tipo de puntero o contador.]

III.

Esa manera de proceder es cómoda para el usuario, pero se convierte en un problema cuando se quiere paralelizar reproduciblemente código que usa números aleatorios: los distintos subprocesos tendrían que acceder al registro central del estado de manera coordinada, etc.

Para evitar ese problema, en JAX, la generación de números aleatorios es algo más engorrosa: es el usuario el que tiene que encargarse de actualizar el estado del PRNG explícitamente. De la generación de números aleatorios se encargan dos funciones puras (sin efectos secundarios) distintas, $r$ y $u$. La segunda actualiza el estado:

$$u(e_i) = e_{i+1},$$

mientras que la primera necesita un estado explícito. Generar números aleatorios en JAX, por tanto, vendría a hacerse así:

e0 <- seed(0)
r0 <- r(e0)
e1 <- u(e0)
r1 <- r(e1)
...