๐ชJAX - ์ธ๋ถ์ ์ธ ํน์ง๋ค (JAX - The Sharp Bits)๐ช#
์ ์ : Anselm Levskaya & Mattew Johnson
์ญ์ : ์ฅ์ง์ฐ(wdfrty1234@gmail.com)
๊ฒ์ : ์ด์๋น, ๋ฐ์ ํ
์ดํ๋ฆฌ์ ๋ณ๋๋ฆฌ๋ฅผ ๊ฑท๋ค๋ณด๋ฉด, ์ฌ๋๋ค์ด
JAX์ ๋ํด *โuna anima di pura programmazione funzionale(์์ํ ํจ์ํ ํ๋ก๊ทธ๋๋ฐ์ ์ํผ)โ*์ด๋ผ๋ ์ฉ์ด๋ก ๋ฌ์ฌํ๋ ์ฌ๋๋ค์ด ๋ง๋ค๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
JAX๋ ์์นํด์ ํ๋ก๊ทธ๋จ์ ๋ณํ์ ์ํ ์ธ์ด๋ก, CPU ํน์ ๊ฐ์๊ธฐ(GPU/TPU)์์ ์์นํ ํ๋ก๊ทธ๋จ์ ์ปดํ์ผํ์ฌ ๋์์ํฌ ์ ์์ต๋๋ค. ํน์ ํ ์ ์ฝ์กฐ๊ฑด์ ์ถฉ์กฑํ๋ ๊ฒฝ์ฐ, JAX๋ ์์น ๋ฐ ๊ณผํ ํ๋ก๊ทธ๋จ์์ ์ ๋์ํฉ๋๋ค!
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
๐ช์์ ํจ์ (Pure functions)#
JAX๋ (๋ชจ๋ ์
๋ ฅ ๋ฐ์ดํฐ๊ฐ ํจ์์ ๋งค๊ฐ๋ณ์๋ฅผ ํตํด ์ ๋ฌ๋๊ณ , ๋ชจ๋ ์ถ๋ ฅ ๊ฒฐ๊ณผ๊ฐ ํจ์์ ๊ฒฐ๊ณผ๋ฅผ ํตํด์ ๋์ค๋) ์์ ํจ์์์ ์ ๋์ํ๋๋ก ์ค๊ณ๋์ด ์์ต๋๋ค.
(์ญ์ ์ฃผ) ์์ ํจ์๋ ๊ฐ์ ์ ๋ ฅ์ด ์ฃผ์ด์ง๋ค๋ฉด ํญ์ ๊ฐ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํํ๋ ํจ์๋ฅผ ์๋ฏธํฉ๋๋ค.
ํ์ง๋ง, ํ์ง๋ง ์์ํจ์๊ฐ ์๋ ๊ฒฝ์ฐ JAX๋ Python ์ธํฐํ๋ฆฌํฐ์ ๋ค๋ฅด๊ฒ ๋์ํ ์ ์์ต๋๋ค. ์๋๋ ๊ทธ ์์ ์
๋๋ค. ์ด์ ๊ฐ์ ์๋ JAX ์์คํ
์์์ ๋์์ด ๋ณด์ฅ๋์ง ์์ต๋๋ค.
๋ฐ๋ผ์, JAX๋ฅผ ์ฌ์ฉํ๋ ๊ฐ์ฅ ์ ์ ํ ๋ฐฉ๋ฒ์ ํจ์์ ์ผ๋ก ์์ํ Pythonํจ์์ ๋ํด์๋ง ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค.
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
g = 0.
def impure_uses_globals(x):
return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
g = 0.
def impure_saves_global(x):
global g
g = x
return x
# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g) # Saved global has an internal JAX value
๋ง์ฝ Python ํจ์๊ฐ ์ค์ ๋ก ์คํ ์ดํธํ ๊ฐ์ฒด๋ฅผ ๋ด๋ถ์ ์ผ๋ก ์ฌ์ฉํ๋๋ผ๋, ์ด๋ฅผ ์ธ๋ถ์์ ์ฝ๊ฑฐ๋ ์ฐ์ง ์๋๋ค๋ฉด ํจ์์ ์ผ๋ก ์์ํ๋ค๊ณ ๋ณผ ์ ์์ต๋๋ค.
def pure_uses_internal_state(x):
state = dict(even=0, odd=0)
for i in range(10):
state['even' if i % 2 == 0 else 'odd'] += x
return state['even'] + state['odd']
print(jit(pure_uses_internal_state)(5.))
jit์ ์ฌ์ฉํ๋ ค๋ JAX ํจ์๋ ์ด๋ค ์ ์ด ํ๋ฆ ๊ตฌ์ฑ์์์์ iterators๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๊ถ์ฅ๋์ง ์์ต๋๋ค. ๊ทธ ์ด์ ๋ iterator(๋ฐ๋ณต์)๊ฐ ๋ค์ ์์๋ฅผ ๊ฐ์ ธ์ค๊ธฐ ์ํ ์ํ(state)๋ฅผ ์ฐพ๊ธฐ ์ํ ํ์ด์ฌ ๊ฐ์ฒด์ด๊ธฐ ๋๋ฌธ์
๋๋ค. ์ด๋ฌํ ์ด์ ๋ก, iterators๋ JAX์ ํจ์์ ํ๋ก๊ทธ๋๋ฐ ๋ชจ๋ธ๊ณผ ํธํ๋์ง ์์ต๋๋ค. ์๋ ์ฝ๋๋ JAX์์ iterators๋ฅผ ์ฌ์ฉํ๋ ค๋ ๋ถ์ ์ ํ ์๋๋ค์ ๋ํ ์์๋ฅผ ๋ณด์ฌ์ค๋๋ค. ์ด๋ฌํ ์์ ๋ค ๋๋ถ๋ถ์ ์ค๋ฅ๋ฅผ ๋ฐํํ์ง๋ง, ์ด๋ค ๊ฒฝ์ฐ์๋ ์์์น ๋ชปํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค ์๋ ์์ต๋๋ค.
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0
# lax.scan
def func11(arr, extra):
ones = jnp.ones(arr.shape)
def body(carry, aelems):
ae1, ae2 = aelems
return (carry + ae1 * ae2 + extra, carry)
return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
๐ชIn-Place ์ ๋ฐ์ดํธ (In-Place Updates)#
Numpy๋ฅผ ์ฌ์ฉํ ๋, ์ฌ๋ฌ๋ถ๋ค์ ์ข ์ข ์ด๋ ๊ฒ ์ฌ์ฉํ ๊ฒ์ ๋๋ค.
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
ํ์ง๋ง, ๋ง์ฝ JAX device array์ in-place๋ก ์ ๋ฐ์ดํธ๋ฅผ ์๋ํ๋ฉด, ์ค๋ฅ๊ฐ ๋ฐ์ํ๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. (โ_โ)
%xmode Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
๋ณ์์ in-place ๋ณํ์ ํ์ฉํ๋ ๊ฒ์ ํ๋ก๊ทธ๋จ์ ๋ถ์๊ณผ ๋ณํ์ ์ด๋ ต๊ฒ ๋ง๋๋ ์์ธ์ ๋๋ค. JAX์์๋ ํ๋ก๊ทธ๋จ์ด ์์ ํจ์์ฌ์ผ ํ๋ค๋ ๊ฒ์ ๊ธฐ์ตํฉ์๋ค.
JAX์์๋ in-place ๊ธฐ๋ฒ ๋์ ์, JAX์์๋ JAX array์ .at ์์ฑ์ ์ด์ฉํ์ฌ ํจ์์ ๋ฐฐ์ด ์
๋ฐ์ดํธ๋ฅผ ์ํํฉ๋๋ค. (.at property on JAX arrays.)
โ ๏ธ jit๋ ์ฝ๋ ์ lax.while_loop ๋๋ lax.fori_loop ๋ด๋ถ์์ ์ฌ๋ผ์ด์ค์ ํฌ๊ธฐ๋ ์ธ์ ๊ฐ์ ํจ์๊ฐ ์๋๋ผ ์ธ์ ํํ์ ํจ์์ฌ์ผ ๊ฐ๋ฅํฉ๋๋ค. - ์ฌ๋ผ์ด์ค ์์ ์ธ๋ฑ์ค์๋ ๊ทธ๋ฌํ ์ ํ์ด ์์ต๋๋ค. ์๋์ ์ ์ด ํ๋ฆ ๋ถ๋ถ์์ ์ด๋ฌํ ์ ์ฝ์ ๋ํ ์ ๋ณด๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
๋ฐฐ์ด ์
๋ฐ์ดํธ : x.at[idx].set(y)
์๋ฅผ ๋ค์ด, ์ ๋ฐ์ดํธ๋ ์๋์ ๊ฐ์ด ์์ฑํ ์ ์์ต๋๋ค.
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
JAX์ ์ ๋ฐ์ดํธ ํจ์๋ NumPy์๋ ๋ค๋ฅด๊ฒ out-of-place๋ก ๋์ํฉ๋๋ค. ์ฆ, ์ ๋ฐ์ดํธ๋ ๋ฐฐ์ด์ ์ ๋ฐฐ์ด๋ก ๋ฐํ๋๋ฉฐ ์๋ ๋ฐฐ์ด์ ์ ๋ฐ์ดํธ๋ก ์์ ๋์ง ์์ต๋๋ค.
print("original array unchanged:\n", jax_array)
ํ์ง๋ง, jit์ผ๋ก ์ปดํ์ผ ๋ ์ฝ๋ ๋ด์์ x.at[idx].set(y)์ ์
๋ ฅ ๊ฐ x๊ฐ ์ฌ์ฌ์ฉ๋์ง ์๋ ๊ฒฝ์ฐ, ์ปดํ์ผ๋ฌ๋ in-place๋ก ๋ฐฐ์ด์ด ์
๋ฐ์ดํธ ๋๋๋ก ์ต์ ํํ ์ ์์ต๋๋ค.
๋ค๋ฅธ ์ฐ์ฐ๊ณผ ํจ๊ป ๋ฐฐ์ด ์ ๋ฐ์ดํธ#
์ธ๋ฑ์ค๊ฐ ์ง์ ๋ ๋ฐฐ์ด์ ์ ๋ฐ์ดํธ๋ ๋จ์ํ ๊ฐ์ ๋ฎ์ด์ฐ๋ ๊ฒ์๋ง ์ ํ๋์ง๋ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ์๋์ ์์์ ๊ฐ์ด ์ธ๋ฑ์ค์ ๋ง์ ์ ํ๋ ์ฐ์ฐ๋ ์ํํ ์ ์์ต๋๋ค.
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
๋ณด๋ค ๋ ์์ธํ ์ธ๋ฑ์ค๋ ๋ฐฐ์ด์ ์
๋ฐ์ดํธ์ ๊ด๋ จํ์ฌ์๋, ํด๋น ๋ฌธ์๋ฅผ ์ฐธ๊ณ ํด์ฃผ์ธ์. (documentatiaon for the .at property)
๐ช ๋ฒ์๋ฅผ ๋ฒ์ด๋ ์ธ๋ฑ์ฑ (Out-of-Bound Indexing)#
NumPy์์๋ ์ฌ๋ฌ๋ถ์ด ์ธ๋ฑ์ค ๋ฐฐ์ด์ ์ธ๋ฑ์ค ๋ฒ์๋ฅผ ๋ฒ์ด๋๋ ๋์์ ์ํํ๋ฉด ์๋์ ๊ฐ์ ์๋ฌ๋ฅผ ๋ณผ ์ ์์ต๋๋ค.
np.arange(10)[11]
ํ์ง๋ง, ๊ฐ์๊ธฐ์์ ๋์ํ๋ ์ฝ๋๋ก๋ถํฐ ์๋ฌ๋ฅผ ๋ฐ์์ํค๋ ๊ฒ์ ์ด๋ ต๊ฑฐ๋ ์ฌ์ง์ด๋ ๋ถ๊ฐ๋ฅํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ฏ๋ก, JAX๋ ๋ฐฐ์ด์ ๋ฒ์๋ฅผ ๋ฒ์ด๋๋ ์ธ๋ฑ์ฑ์ ๋ํด์ ์ค๋ฅ๊ฐ ์๋ ๋์์ ์ ํํด์ผ ํฉ๋๋ค. (์ ํจํ์ง ์์ ๋ถ๋ ์์์ ์ ์ฐ์ ์ ๊ฒฐ๊ณผ๊ฐ NaN์ด ๋๋ ๊ฒ๊ณผ ์ ์ฌํฉ๋๋ค.). ๋ง์ฝ ์ธ๋ฑ์ฑ ์์
์ด ๋ฐฐ์ด ์ธ๋ฑ์ค ์
๋ฐ์ดํธ(์: index_add ๋๋ scatter-์ ์ฌํ ๊ธฐ๋ณธ ์์)์ธ ๊ฒฝ์ฐ, ๋ฒ์๋ฅผ ๋ฒ์ด๋ ์ธ๋ฑ์ค์ ์
๋ฐ์ดํธ๋ ๊ฑด๋๋๋๋ค. ์์
์ด ๋ฐฐ์ด ์ธ๋ฑ์ค ๊ฒ์(์: NumPy ์ธ๋ฑ์ฑ ๋๋ gather-์ ์ฌ ๊ธฐ๋ณธ ์์)์ธ ๊ฒฝ์ฐ, ๋ฌด์ธ๊ฐ๋ฅผ ๋ฐํํด์ผ ํ๋ฏ๋ก ์ธ๋ฑ์ค๊ฐ ๋ฐฐ์ด์ ๋ฒ์์ ๊ณ ์ ๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์๋์ ์ธ๋ฑ์ฑ ๋์์์๋ ๋ฐฐ์ด์ ๋ง์ง๋ง ๊ฐ์ด ๋ฐํ๋ ๊ฒ์
๋๋ค.
jnp.arange(10)[11]
์ธ๋ฑ์ค ๊ฒ์์ ๋ํ ์ด๋ฌํ ๋์์ผ๋ก ์ธํด jnp.nanargmin ๋ฐ jnp.nanargmax์ ๊ฐ์ ํจ์๋ NaN์ผ๋ก ๊ตฌ์ฑ๋ ์ฌ๋ผ์ด์ค์ ๋ํด -1์ ๋ฐํํ์ง๋ง Numpy๋ ์ค๋ฅ๋ฅผ ๋ฐ์์ํต๋๋ค.
์์์ ์ค๋ช ํ ๋ ๊ฐ์ง ๋์์ด ์๋ก ์ญ์ ๊ด๊ณ๊ฐ ์๋๊ธฐ ๋๋ฌธ์, ์ญ๋ฐฉํฅ ์๋ ๋ฏธ๋ถ(์ธ๋ฑ์ค ์ ๋ฐ์ดํธ๋ฅผ ์ธ๋ฑ์ค ๊ฒ์์ผ๋ก ๋ณํํ๊ณ ๊ทธ ๋ฐ๋๋ก ์ ํ)์ ๋ฒ์๋ฅผ ๋ฒ์ด๋ ์ธ๋ฑ์ฑ์ ์๋ฏธ๋ฅผ ๋ณด์กดํ์ง ์์ต๋๋ค. ๋ฐ๋ผ์ JAX์ ๋ฒ์๋ฅผ ๋ฒ์ด๋ ์ธ๋ฑ์ฑ์ ์ ์๋์ง ์์ ๋์์ผ๋ก ์๊ฐํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
๐ช ๋น๋ฐฐ์ด ์ ๋ ฅ: NumPy vs. Jax (Non-array inputs: NumPy vs. JAX)#
NumPy๋ ์ผ๋ฐ์ ์ผ๋ก Python์ ๋ฆฌ์คํธ ๋๋ ํํ์ API ํจ์์ ๋ํ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํฉ๋๋ค.
np.sum([1, 2, 3])
JAX๋ ์ด์ ๋ค๋ฅด๊ฒ ์ผ๋ฐ์ ์ผ๋ก ์ ์ฉํ ์ค๋ฅ๋ฅผ ๋ฐํํฉ๋๋ค.
jnp.sum([1, 2, 3])
์ด๋ ์๋์ ์ผ๋ก ์ค๊ณ๋ ๊ฒฐ๊ณผ์ ๋๋ค. ๊ทธ ์ด์ ๋ ์ถ์ ๋ ํจ์์ ๋ฆฌ์คํธ๋ ํํ์ ์ ๋ฌํ๊ฒ ๋๋ฉด ๊ฐ์งํ๊ธฐ ์ด๋ ค์ด ์กฐ์ฉํ ์ฑ๋ฅ์ ์ ํ๊ฐ ๋ฐ์ํ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
์๋ฅผ ๋ค์ด, ๋ฆฌ์คํธ์ ์
๋ ฅ์ ํ์ฉํ๋ ๋ค์ ๋ฒ์ ์ jnp.sum์ ๊ณ ๋ คํด๋ด
์๋ค.
def permissive_sum(x):
return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x)
๊ฒฐ๊ณผ๋ ์์ํ ๋๋ก์ด์ง๋ง ์ฌ๊ธฐ์๋ ์ ์ฌ์ ์ธ ์ฑ๋ฅ ๋ฌธ์ ๊ฐ ์จ๊ฒจ์ ธ ์์ต๋๋ค. JAX์ ์ถ์ ๋ฐ JIT ์ปดํ์ผ ๋ชจ๋ธ์์ Python์ ๋ฆฌ์คํธ ํน์ ํํ์ ๊ฐ ์์๊ฐ ๋ณ๋์ JAX ๋ณ์๋ก ์ทจ๊ธ๋๋ฉฐ, ์ด๋ฌํ ๋ณ์๋ค์ ๊ฐ๋ณ์ ์ผ๋ก ์ฒ๋ฆฌ๋์ด ๋๋ฐ์ด์ค๋ก ์ ์ก๋ฉ๋๋ค. ์ด๋ ์์ permissive_sum ํจ์์ ๋ํ jaxpr์์ ๋ณผ ์ ์์ต๋๋ค.
make_jaxpr(permissive_sum)(x)
๋ฆฌ์คํธ์ ๊ฐ ํญ๋ชฉ์ ๋ณ๋์ ์ ๋ ฅ์ผ๋ก ์ฒ๋ฆฌ๋๋ฏ๋ก, ๋ฆฌ์คํธ์ ํฌ๊ธฐ์ ๋ฐ๋ผ ์ ํ์ ์ผ๋ก ์ฆ๊ฐํ๋ ์ถ์ ๋ฐ ์ปดํ์ผ ์ค๋ฒํค๋๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ๋ฐฉ์งํ๊ธฐ ์ํด, JAX๋ ๋ฆฌ์คํธ ๋ฐ ํํ์ ๋ฐฐ์ด๋ก ์์์ ์ผ๋ก ๋ณํํ๋ ๊ฒ์ ํผํฉ๋๋ค.
๋ฐ๋ผ์, ํํ ๋๋ ๋ฆฌ์คํธ๋ฅผ JAX ํจ์์ ์ ๋ฌํ๋ ค๋ฉด ๋จผ์ ๋ช ์์ ์ผ๋ก ๋ฐฐ์ด๋ก ๋ณํํ ํ ์ ๋ฌํด์ผ ํฉ๋๋ค.
jnp.sum(jnp.array(x))
๐ช ๋์ (Random Numbers)#
rand()๋ก ์ธํด ๊ฒฐ๊ณผ๊ฐ ์์ฌ์ค๋ฌ์ด ๋ชจ๋ ๊ณผํ ๋ ผ๋ฌธ๋ค์ด ๋์๊ด ์ฑ ์ฅ์์ ์ฌ๋ผ์ง๋ค๋ฉด ๊ฐ ์ฑ ์ฅ์๋ ์ฃผ๋จน๋งํ ๊ฐ๊ฒฉ์ด ์๊ธธ ๊ฒ๋๋ค. - Numerical Recipes
RNGs์ State
์ฌ๋ฌ๋ถ๋ค์ NumPy ๋ฐ ๊ธฐํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์คํ ์ดํธํ ์์ฌ ๋์ ์์ฑ๊ธฐ(PRNG)์ ์ต์ํ ๊ฒ์ ๋๋ค. ์ด๋ฌํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ค์ ์์ฌ ๋์์ ์์ค๋ฅผ ์ ๊ณตํ๊ธฐ ์ํด ๋ง์ ์ธ๋ถ ์ ๋ณด๋ค์ ๋ฐฑ๊ทธ๋ผ์ด๋์์ ์ ์ฉํ๊ฒ ์จ๊น๋๋ค.
print(np.random.random())
print(np.random.random())
print(np.random.random())
๋ฐฑ๊ทธ๋ผ์ด๋์์ numpy๋ Mersenne Twister PRNG๋ฅผ ์ฌ์ฉํ์ฌ ์์ฌ ๋์ ๊ธฐ๋ฅ์ ๊ฐํํฉ๋๋ค. PRNG์ ์ฃผ๊ธฐ๋ \(2^{19937} - 1\)์ด๊ณ ์ด๋ ์์ ์์๋ 624๊ฐ์ 32๋นํธ ๋ถํธ ์๋ ์ ์์ ์ด โ์ํธ๋กํผโ๊ฐ ์ผ๋ง๋ ๋ง์ด ์ฌ์ฉ๋์๋์ง์ ๋ํ ์์น๋ก ์ค๋ช ํ ์ ์์ต๋๋ค.
np.random.seed(0)
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
์ด ์์ฌ ๋์ ์ํ ๋ฒกํฐ๋ ๋์๊ฐ ํ์ํ ๋๋ง๋ค ๋ฐฑ๊ทธ๋ผ์ด๋์์ ์๋์ ์ผ๋ก ์ ๋ฐ์ดํธ๋์ด Mersenne twister ์ํ ๋ฒกํฐ์ uint32 ์ค 2๊ฐ๋ฅผ โ์๋นโํฉ๋๋ค.
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
Magic PRNG ์ํ์ ๋ฌธ์ ๋ ์๋ก ๋ค๋ฅธ ์ค๋ ๋, ํ๋ก์ธ์ค ๋ฐ ์ฅ์น์์ ์ฌ์ฉ ๋ฐ ์ ๋ฐ์ดํธ ๋๋ ๋ฐฉ์์ ๋ํด ์ถ๋ก ํ๊ธฐ ์ด๋ ต๊ณ ์ํธ๋กํผ ์์ฑ ๋ฐ ์๋น์ ๋ํ ์ธ๋ถ ์ ๋ณด๊ฐ ์ต์ข ์ฌ์ฉ์์๊ฒ ์จ๊ฒจ์ ธ ์์ ๋ ๋ฌธ์ ๋ฅผ ์ผ์ผํค๊ธฐ ๋งค์ฐ ์ฝ๋ค๋ ๊ฒ์ ๋๋ค.
๋ํ, Mersenne Twister PRNG๋ ๋ง์ ๋ฌธ์ ๊ฐ ์๋ ๊ฒ์ผ๋ก ์๋ ค์ ธ ์์ผ๋ฉฐ, 2.5Kb์ ํฐ ์ํ ํฌ๊ธฐ๋ฅผ ๊ฐ์ง๊ณ ์์ด ์ด๊ธฐํ ๋ฌธ์ ๋ฅผ ์ผ๊ธฐํ ์ ์์ต๋๋ค. ๋ํ, ์ต์ BigCrush ํ ์คํธ๋ฅผ ๋ง์กฑํ์ง ๋ชปํ๊ณ ์ผ๋ฐ์ ์ผ๋ก ๋๋ฆฌ๋ค๋ ๋จ์ ์ด ์์ต๋๋ค.
JAX PRNG
JAX๋ ๋์ PRNG ์ํ๋ฅผ ๋ช ์์ ์ผ๋ก ์ ๋ฌํ๊ณ ๋ฐ๋ณตํ์ฌ ์ํธ๋กํผ ์์ฑ ๋ฐ ์๋น๋ฅผ ์ฒ๋ฆฌํ๋ ๋ช ์์ PRNG๋ฅผ ๊ตฌํํ์ต๋๋ค. JAX๋ ๋ถํ ๊ฐ๋ฅํ ์ต์ Threefry counter ๊ธฐ๋ฐ PRNG๋ฅผ ์ฌ์ฉํฉ๋๋ค(Threefry counter-based PRNG). ์ฆ, ์ด๋ฌํ ์ค๊ณ๋ฅผ ํตํด PRNG ์ํ๋ฅผ ๋ณ๋ ฌ ํ๋ฅ ์ ์์ฑ์ ์ํด ์ฌ์ฉํ๊ธฐ ์ํด ์๋ก์ด PRNG๋ก ๋ถ๊ธฐํ ์ ์์ต๋๋ค.
๋ฌด์์ ์ํ๋ ํค(key)๋ผ๊ณ ๋ถ๋ฅด๋ ๋ ๊ฐ์ unsigned-int32๋ก ์ค๋ช
๋ฉ๋๋ค.
from jax import random
key = random.PRNGKey(0)
key
JAX์ ์์ ํจ์๋ PRNG ์ํ์์ ์์ฌ ๋์๋ฅผ ์์ฑํ์ง๋ง ์ํ๋ฅผ ๋ณ๊ฒฝํ์ง๋ ์์ต๋๋ค!
๋์ผํ ์ํ๋ฅผ ์ฌ์ฌ์ฉํ๋ ํ์๋ ์ฌํ๊ณผ ๋จ์กฐ๋ก์์ ์ ๋ฐํ๋ฉฐ ๊ฒฐ๊ตญ ์ต์ข
์ฌ์ฉ์์๊ฒ ํผ๋์ ๋ถ์ด๋ฃ๋ ๊ฒฐ๊ณผ๋ฅผ ์ด๋ํ ์ ์์ต๋๋ค!
(Reusing the same state will cause sadness and monotony, depriving the end user of lifegiving chaos:)
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
๋์ , ์๋ก์ด ์์ฌ ๋์๊ฐ ํ์ํ ๋๋ง๋ค PRNG๋ฅผ ๋ถํ ํ์ฌ ์ฌ์ฉ ๊ฐ๋ฅํ ํ์ ํค๋ฅผ ์ป์ต๋๋ค.
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
์๋ก์ด ๋์๊ฐ ํ์ํ ๋๋ง๋ค ํค๋ฅผ ์ ํํ๊ณ ์ ํ์ ํค๋ฅผ ๋ง๋ญ๋๋ค.
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
ํ ๋ฒ์ ๋ ์ด์์ ํ์ํค๋ฅผ ๋ง๋ค ์ ์์ต๋๋ค.
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
print(random.normal(subkey, shape=(1,)))
๐ช ์ ์ด ํ๋ฆ (Control Flow)#
โ python control_flow + autodiff โ
Python ํจ์์ grad๋ฅผ ์ ์ฉํ๋ ค๋ ๊ฒฝ์ฐ Autograd(๋๋ Pytorch ๋๋ TF Eager)๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ฒ๋ผ ๋ฌธ์ ์์ด ์ผ๋ฐ์ ์ธ Python ์ ์ด ํ๋ฆ ๊ตฌ์ฑ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
Python control flow + JIT
jit์ ํจ๊ป ์ ์ด ํ๋ฆ์ ์ฌ์ฉํ๋ ๊ฒ์ ๋ ๋ณต์กํ๋ฉฐ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ ๋ง์ ์ ์ฝ์ด ์์ต๋๋ค.
์ด ์์๋ ๋์ํฉ๋๋ค.
@jit
def f(x):
for i in range(3):
x = 2 * x
return x
print(f(3))
์๋ ์์๋ ๋์ํฉ๋๋ค.
@jit
def g(x):
y = 0.
for i in range(x.shape[0]):
y = y + x[i]
return y
print(g(jnp.array([1., 2., 3.])))
ํ์ง๋ง ์ด ์์๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋์ํ์ง ์์ต๋๋ค.
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
# This will fail!
f(2)
์ ๊ทธ๋ด๊น!?
ํจ์๋ฅผ jit ์ปดํ์ผํ ๋ ์ผ๋ฐ์ ์ผ๋ก ์ปดํ์ผ๋ ์ฝ๋๋ฅผ ์บ์ํ์ฌ ๋ค์ํ ์ธ์ ๊ฐ์ ๋ํด ์๋ํ๋ ํจ์ ๋ฒ์ ์ ์ปดํ์ผํ์ฌ ์ฌ์ฌ์ฉํ ์ ์๋๋ก ํฉ๋๋ค. ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ๊ฐ ํจ์ ํ๊ฐ๋ง๋ค ๋ค์ ์ปดํ์ผํ ํ์๊ฐ ์์ต๋๋ค.
์๋ฅผ ๋ค์ด jnp.array([1., 2., 3.], jnp.float32) ๋ฐฐ์ด์์ @jit ํจ์๋ฅผ ํ๊ฐํ๊ธฐ ์ํด jnp.array([4., 5., 6.], jnp.float32)์์ ์ฌ์ฉํ๋ ์ปดํ์ผ๋ ์ฝ๋๋ฅผ ์ฌ์ฌ์ฉํ์ฌ ํ์ฌ ์ปดํ์ผ์ ์ํ๋๋ ์๊ฐ์ ์ ์ฝํ ์ ์์ต๋๋ค.
JAX๋ Python ์ฝ๋์ ๋ค์ํ ์ธ์ ๊ฐ์ ์ ํจํ ๋ทฐ๋ฅผ ์ป๊ธฐ ์ํด JAX๋ ๊ฐ๋ฅํ ์ ๋ ฅ ์งํฉ์ ๋ํ๋ด๋ ์ถ์ ๊ฐ์ผ๋ก ์ฝ๋๋ฅผ ์ถ์ ํฉ๋๋ค. ๋ค์ํ ์ถ์ํ ์์ค์ด ์์ผ๋ฉฐ ์๋ก ๋ค๋ฅธ ๋ณํ์ ์๋ก ๋ค๋ฅธ ์ถ์ํ ์์ค์ ์ฌ์ฉํฉ๋๋ค.
๊ธฐ๋ณธ์ ์ผ๋ก jit์ ShapedArray ์ถ์ํ ์์ค์์ ์ฝ๋๋ฅผ ์ถ์ ํฉ๋๋ค. ๊ฐ ์ถ์ ๊ฐ์ ๊ณ ์ ๋ ๋ชจ์๊ณผ dtype์ ๊ฐ์ง๋ ๋ชจ๋ ๋ฐฐ์ด ๊ฐ์ ์งํฉ์ ๋ํ๋
๋๋ค. ์๋ฅผ ๋ค์ด ์ถ์ ๊ฐ ShapedAray((3,), jnp.float32)๋ฅผ ์ฌ์ฉํ์ฌ ์ถ์ ํ๋ฉด ํด๋น ๋ฐฐ์ด ์ธํธ์ ๊ตฌ์ฒด์ ์ธ ๊ฐ์ ๋ํด ์ฌ์ฌ์ฉํ ์ ์๋ ํจ์์ ๋ทฐ๋ฅผ ์ป์ ์ ์์ต๋๋ค. ์ฆ, ์ปดํ์ผ ์๊ฐ์ ์ค์ผ ์ ์๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
๊ทธ๋ฌ๋ ์ฌ๊ธฐ์๋ ์ฅ๋จ์ ์ด ์์ต๋๋ค. ํน์ ๊ตฌ์ฒด์ ์ธ ๊ฐ์ด ๊ฒฐ์ ๋์ง ์์ ShapedArray((), jnp.float32)์์ Python ํจ์๋ฅผ ์ถ์ ํ๋ ๊ฒฝ์ฐ if x < 3๊ณผ ๊ฐ์ ์ค์ ๋๋ฌํ๋ฉด ํํ์ x < 3์ {True, False} ์งํฉ์ ๋ํ๋ด๋ ์ถ์ ShapedArray((), jnp.bool_)๋ก ํ๊ฐ๋ฉ๋๋ค. Python์ด ์ด๋ฅผ ๊ตฌ์ฒด์ ์ธ True ๋๋ False๋ก ๊ฐ์ ํ๋ ค๊ณ ํ๋ฉด ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด๋ค ๋ถ๊ธฐ๋ฅผ ์ ํํด์ผ ํ ์ง ๋ชจ๋ฅด๊ณ ์ถ์ ์ ๊ณ์ํ ์ ์์ต๋๋ค! ๋จ์ ์ ์ถ์ํ ์์ค์ด ๋์์๋ก Python ์ฝ๋์ ๋ํ ๋ณด๋ค ์ผ๋ฐ์ ์ธ ๋ทฐ๋ฅผ ์ป์ ์ ์์ง๋ง(๋ฐ๋ผ์ ์ฌ์ปดํ์ผ์ ์ค์ผ ์ ์์ต๋๋ค.) ์ถ์ ์ ์๋ฃํ๋ ค๋ฉด Python ์ฝ๋์ ๋ ๋ง์ ์ ์ฝ์ด ํ์ํ๋ค๋ ๊ฒ์
๋๋ค.
์ข์ ์์์, ์ด ํธ๋ ์ด๋์คํ๋ฅผ ์ง์ ์ ์ดํ ์ ์๋ค๋ ๊ฒ์
๋๋ค. ๋ณด๋ค ์ ๋ฐํ ์ถ์ ๊ฐ์ ๋ํ jit ์ถ์ ์ ํตํด ์ถ์ ๊ฐ๋ฅ์ฑ ์ ์ฝ์ ์ํํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด jit์ static_argnums ์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ์ผ๋ถ ์ธ์์ ๊ตฌ์ฒด์ ์ธ ๊ฐ์ ์ถ์ ํ๋๋ก ์ง์ ํ ์ ์์ต๋๋ค. ๋ค์์ ์ด์ ๋ํ ์์ ํจ์์
๋๋ค.
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
f = jit(f, static_argnums=(0,))
print(f(2.))
๋ฃจํ๋ฅผ ํฌํจํ ๋๋ค๋ฅธ ์์ ์ ๋๋ค.
def f(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f = jit(f, static_argnums=(1,))
f(jnp.array([2., 3., 4.]), 2)
static_argnums๋ฅผ ์ด์ฉํ ํจ๊ณผ๋ก ์ธํด ๋ฃจํ๊ฐ ์ ์ ์ผ๋ก ํผ์ณ์ง๋๋ค. JAX๋ ๋ํ Unshaped์ ๊ฐ์ ๋ ๋์ ์์ค์ ์ถ์ํ์์ ์ถ์ ํ ์ ์์ง๋ง ํ์ฌ ๋ณํ์ ๊ธฐ๋ณธ๊ฐ์ ์๋๋๋ค.
โ ๏ธย ์ธ์ ๊ฐ์ ๋ฐ๋ผ ํํ๊ฐ ๋ฐ๋๋ ํจ์
์ด๋ฌํ ์ ์ด ํ๋ฆ ๋ฌธ์ ๋ ๋ณด๋ค ๋ฏธ๋ฌํ ๋ฐฉ์์ผ๋ก๋ ๋ํ๋ฉ๋๋ค. jit ํ๋ ค๋ ์์น ํจ์๋ ๋ด๋ถ ๋ฐฐ์ด์ ๋ชจ์์ ์ธ์ ๊ฐ์ ๋ฐ๋ผ ํน์ ํ ์ ์์ต๋๋ค(์ธ์ ๋ชจ์์ ๋ฐ๋ผ ํน์ ํ๋ ๊ฒ์ ๊ด์ฐฎ์ต๋๋ค). ๊ฐ๋จํ ์๋ก ์
๋ ฅ ๋ณ์ ๊ธธ์ด์ ๋ฐ๋ผ ์ถ๋ ฅ์ด ๋ฌ๋ผ์ง๋ ํจ์๋ฅผ ๋ง๋ค์ด ๋ณด๊ฒ ์ต๋๋ค.
def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
static_argnums๋ ์์ ์์ ๊ธธ์ด์ ๋ณ๊ฒฝ์ด ์ฆ์ง ์์ ๊ฒฝ์ฐ์๋ ํธ๋ฆฌํ ์ ์์ง๋ง ๋ณ๊ฒฝ์ด ์ฆ์ ๊ฒฝ์ฐ ์ฌ์์ด ๋ ์ ์์ต๋๋ค!
๋ง์ง๋ง์ผ๋ก ํจ์์ ์ ์ญ์ ์ธ ๋ถ์ํจ๊ณผ๋ค์ด ์๋ ๊ฒฝ์ฐ JAX์ ์ถ์ ํ๋ก๊ทธ๋จ์ผ๋ก ์ธํด ์ด์ํ ์ผ์ด ๋ฐ์ํ ์ ์์ต๋๋ค. ์ผ๋ฐ์ ์ธ ๋ฌธ์ ๋ jitโd ํจ์ ๋ด์์ ๋ฐฐ์ด์ ์ถ๋ ฅํ๋ ค๊ณ ์๋ํ ๋ ๋ฐ์ํ ์ ์์ต๋๋ค.
@jit
def f(x):
print(x)
y = 2 * x
print(y)
return y
f(2)
๊ตฌ์กฐ์ ์ ์ด ํ๋ฆ ๊ธฐ๋ณธ ์์#
JAX์๋ ์ ์ด ํ๋ฆ์ ๋ํ ๋ค์ํ ์ต์ ์ด ๋ง์ด ์์ต๋๋ค. ์๋ฅผ ๋ค์ด ์ฌ์ปดํ์ผ์ ํผํ๊ณ ์ถ์ ๊ฐ๋ฅํ ์ ์ด ํ๋ฆ์ ์ฌ์ฉํ๋ฉด์ ํฐ ๋ฃจํ๋ฅผ ํ๊ณ ์ถ์ง ์๋ค๋ฉด ์๋์ 4๊ฐ์ง ๊ตฌ์กฐ์ ์ ์ด ํ๋ฆ ๊ธฐ๋ณธ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
lax.condย differentiablelax.while_loopย fwd-mode-differentiablelax.fori_loopย fwd-mode-differentiableย in general;ย fwd and rev-mode differentiableย if endpoints are static.lax.scanย differentiable
cond#
def cond(pred, true_fun, false_fun, operand):
if pred:
return true_fun(operand)
else:
return false_fun(operand)
from jax import lax
operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
jax.lax์๋ ๋์ ์กฐ๊ฑด์ ๋ฐ๋ผ ๋ถ๊ธฐํ ์ ์๋ ๋ค๋ฅธ ๋ ๊ฐ์ ํจ์๊ฐ ์ ๊ณต๋ฉ๋๋ค.
lax.select๋lax.cond์ ๋ฐฐ์น ๋ฒ์ ์ด์ง๋ง, ์ ํ์ง๋ ์ด์ ์ ๊ณ์ฐ๋ ๋ฐฐ์ด๋ก ํํ๋ฉ๋๋ค.lax.switch๋lax.cond์ ์ ์ฌํ์ง๋ง, ์ด๋ค ์์ ํธ์ถ ๊ฐ๋ฅํ ์ ํ์ง ์ฌ์ด์ ์ ํํ ์ ์์ต๋๋ค.
๋ํ, jax.numpy์์๋ ์ด๋ฌํ ํจ์์ ๋ํ ๋ค์์ Numpy ์คํ์ผ ์ธํฐํ์ด์ค๊ฐ ์ ๊ณต๋ฉ๋๋ค.
jnp.where๋ 3๊ฐ์ ์ธ์๊ฐ์๋ lax.select์ Numpy ์คํ์ผ ๋ํผ(wrapper)์ ๋๋ค.jnp.piecewise๋lax.switch์ Numpy ์คํ์ผ ๋ํผ(wrapper)์ด์ง๋ง, ๋จ์ผ ์ค์นผ๋ผ ์ธ๋ฑ์ค ๋์ ์ ๋ถ๋ฆฌ์ธ ์กฐ๊ฑด์ ๋ชฉ๋ก์ ๋ฐ๋ผ ์ ํํฉ๋๋ค.jnp.select๋jnp.piecewise์ ์ ์ฌํ API๋ฅผ ๊ฐ์ง์ง๋ง, ์ ํ์ง๋ ์ฌ์ ๊ณ์ฐ๋ ๋ฐฐ์ด๋ก ์ ๊ณต๋ฉ๋๋ค. ๊ฒฐ๊ณผ์ ์ผ๋กlax.select์ ์ฌ๋ฌ ํธ์ถ๋ก ๊ตฌํ๋ฉ๋๋ค.
while_loop#
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
fori_loop#
def fori_loop(start, stop, body_fun, init_val):
val = init_val
for i in range(start, stop):
val = body_fun(i, val)
return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Summary#
๐ช ๋์ ํํ (Dynamic Shapes)#
jax.jit, jax.vmap, jax.grad ๋ฑ๊ณผ ๊ฐ์ ๋ณํ ๋ด์์ ์ฌ์ฉ๋๋ JAX ์ฝ๋๋ ๋ชจ๋ ์ถ๋ ฅ ๋ฐฐ์ด๊ณผ ์ค๊ฐ ๋ฐฐ์ด์ด ์ ์ ๋ชจ์์ ๊ฐ์ ธ์ผ ํฉ๋๋ค. ์ฆ, ๋ชจ์์ ๋ค๋ฅธ ๋ฐฐ์ด ๋ด์ ๊ฐ์ ์์กดํ์ง ์์์ผ ํฉ๋๋ค.
์๋ฅผ ๋ค์ด, jnp.nansum์ ๋ฒ์ ์ ์ง์ ๊ตฌํํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ์์ํ ์ ์์ต๋๋ค
def nansum(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
x_without_nans = x[mask]
return x_without_nans.sum()
JIT ๋ฐ ๊ธฐํ ๋ณํ ์ธ๋ถ์์๋ ์์๋๋ก ์๋ํฉ๋๋ค.
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
jax.jit ๋๋ ๋ค๋ฅธ ๋ณํ์ ์ด ํจ์์ ์ ์ฉํ๋ ค๊ณ ํ๋ฉด ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค.
jax.jit(nansum)(x)
๋ฌธ์ ๋ x_without_nans์ ํฌ๊ธฐ๊ฐ x ๋ด์ ๊ฐ์ ๋ฐ๋ผ ๋์ ์ผ๋ก ๊ฒฐ์ ๋๋ค๋ ์ ์
๋๋ค. ์ข
์ข
JAX์์๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ํตํด ๋์ ์ผ๋ก ํฌ๊ธฐ ์กฐ์ ๋ ๋ฐฐ์ด์ ํ์์ฑ์ ํด๊ฒฐํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด jnp.where์ 3๊ฐ ์ธ์ ํ์์ ์ฌ์ฉํ์ฌ NaN ๊ฐ์ 0์ผ๋ก ๋์ฒดํจ์ผ๋ก์จ ๋์ ๋ชจ์์ ํผํ๋ฉด์๋ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ๊ณ์ฐํ ์ ์์ต๋๋ค.
@jax.jit
def nansum_2(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
return jnp.where(mask, x, 0).sum()
print(nansum_2(x))
๋์ ๋ชจ์์ ๋ฐฐ์ด์ด ๋ฐ์ํ๋ ๋ค๋ฅธ ์ํฉ์์๋ ์ ์ฌํ ํธ๋ฆญ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
๐ช NaNs#
NaNs ๋๋ฒ๊น #
ํจ์ ๋๋ ๊ทธ๋๋์ธํธ์์ NaN์ด ๋ฐ์ํ๋ ์์น๋ฅผ ์ถ์ ํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด NaN ๊ฒ์ฌ๊ธฐ๋ฅผ ์ผค ์ ์์ต๋๋ค.
JAX_DEBUG_NANS=Trueย ํ๊ฒฝ ๋ณ์ ์ค์ ๋ฉ์ธ ํ์ผ ์๋จ์ ย
fromย jax.configย importย configย ์ยconfig.update("jax_debug_nans",True)ย ๋ฅผ ์ถ๊ฐํ์ธ์.๋ฉ์ธ ํ์ผ์ย
fromย jax.configย importย configย ์ยconfig.parse_flags_with_absl()ย ๋ฅผ ์ถ๊ฐํ์ธ์. ๊ทธ๋ฐ ๋ค์ ๋ช ๋ น ์ค ํ๋๊ทธ์ย-jax_debug_nans=True์ ์ด์ฉํ์ฌ ์ต์ ์ ์ค์ ํ์ธ์.
์ด๋ก ์ธํด NaN ์์ฑ ์ฆ์ ๊ณ์ฐ ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด ์ต์
์ ์ผ๋ฉด XLA์์ ์์ฑ๋ ๋ชจ๋ ๋ถ๋ ์์์ ์ ํ ๊ฐ์ nan ๊ฒ์ฌ๊ฐ ์ถ๊ฐ๋ฉ๋๋ค. ์ฆ, @jit์ ์ํ์ง ์๋ ๋ชจ๋ ๊ธฐ๋ณธ ์์
์ ๋ํด ๊ฐ์ ๋ค์ ํธ์คํธ๋ก ๊ฐ์ ธ์ ndarry๋ก ๊ฒ์ฌํฉ๋๋ค. @jit ํ์์ ์๋ ์ฝ๋์ ๊ฒฝ์ฐ ๋ชจ๋ @jitํจ์์ ์ถ๋ ฅ์ ๊ฒ์ฌํ๊ณ NaN์ด ์๋ ๊ฒฝ์ฐ ์ต์ ํ๋์ง ์์ op-by-op ๋ชจ๋์์ ํจ์๋ฅผ ๋ค์ ์คํํ์ฌ ํ ๋ฒ์ ํ ๋ ๋ฒจ์ฉ @jit์ ์ ๊ฑฐํฉ๋๋ค.
@jit์์๋ง ๋ฐ์ํ์ง๋ง ์ต์ ํ๋์ง ์์ ๋ชจ๋์์๋ ์์ฑ๋์ง ์๋ nan๊ณผ ๊ฐ์ ๊น๋ค๋ก์ด ์ํฉ์ด ๋ฐ์ํ ์ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ ๊ฒฝ๊ณ ๋ฉ์์ง๊ฐ ์ถ๋ ฅ๋์ง๋ง ์ฝ๋๋ ๊ณ์ ์คํ๋ฉ๋๋ค.
๊ทธ๋๋์ธํธ ํ๊ฐ์ ์ญ๋ฐฉํฅ ํจ์ค์์ nans๊ฐ ์์ฑ๋๋ ๊ฒฝ์ฐ ์คํ ์ถ์ ์์ ๋ช ํ๋ ์ ์๋ก ์์ธ๊ฐ ๋ฐ์ํ๋ฉด backward_pass ํจ์ ๋ด๋ถ๋ก ์ง์
ํฉ๋๋ค. ์ด ํจ์๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๊ธฐ๋ณธ ์์
์ํ์ค๋ฅผ ์ญ์์ผ๋ก ์ํํ๋ ๊ฐ๋จํ jaxpr ์ธํฐํ๋ฆฌํฐ์
๋๋ค. ์๋ ์์์ env JAX_DEBUG_NANS=True ipython ๋ช
๋ น์ค์ ์ฌ์ฉํ์ฌ ipython repl์ ์์ํ ํ, ๋ค์์ ์คํํ์ต๋๋ค.
import jax.numpy as jnp
jnp.divide(0., 0.)
์์ฑ๋ NaN์ด ์กํ์ต๋๋ค. %debug๋ฅผ ์คํํ๋ฉด ์ฌํ ๋๋ฒ๊ฑฐ๋ฅผ ์ป์ ์ ์์ต๋๋ค. ์ด๊ฒ์ ์๋ ์์ ์ ๊ฐ์ด @jit ์ผ๋ก ๊ฐ์ธ์ง ํจ์์์๋ ์๋ํฉ๋๋ค.
In [1]: import jax.numpy as jnp
In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
103 py_val = device_buffer.to_py()
104 if np.any(np.isnan(py_val)):
--> 105 raise FloatingPointError("invalid value")
106 else:
107 return DeviceArray(device_buffer, *result_shape)
FloatingPointError: invalid value
@jit
def f(x, y):
a = x * y
b = (x + y) / (x - y)
c = a + 2
return a + b * c
In [4]: from jax import jit
In [5]: @jit
...: def f(x, y):
...: a = x * y
...: b = (x + y) / (x - y)
...: c = a + 2
...: return a + b * c
...:
In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
2 def f(x, y):
3 a = x * y
----> 4 b = (x + y) / (x - y)
5 c = a + 2
6 return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
์ด ์ฝ๋๋ @jit ํจ์์ ์ถ๋ ฅ์์ nan์ ๋ฐ๊ฒฌํ๋ฉด ์ต์ ํ๋์ง ์์ ์ฝ๋๋ฅผ ํธ์ถํ๋ฏ๋ก ์ฌ์ ํ ๋ช
ํํ ์คํ์ ์ถ์ ํ ์ ์์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ %debug๋ก ์ฌํ ๋๋ฒ๊ฑฐ๋ฅผ ์คํํ์ฌ ์ค๋ฅ๋ฅผ ํ์
ํ๊ธฐ ์ํด ๋ชจ๋ ๊ฐ์ ๊ฒ์ฌํ ์ ์์ต๋๋ค.
โ ๏ธ ๋๋ฒ๊น ํ์ง ์๋ ๊ฒฝ์ฐ NaN ๊ฒ์ฌ๊ธฐ๋ฅผ ์ผ์๋ ์ ๋ฉ๋๋ค. ๋ง์ ์ฅ์น-ํธ์คํธ ์๋ณต ๋ฐ ์ฑ๋ฅ ์ ํ๊ฐ ๋ฐ์ํ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค!
โ ๏ธ NaN ๊ฒ์ฌ๊ธฐ๋ pmap์์ ์๋ํ์ง ์์ต๋๋ค. pmap ์ฝ๋์์ nans๋ฅผ ๋๋ฒ๊น ํ๋ ค๋ฉด pmap์ vmap์ผ๋ก ๊ต์ฒดํด์ผ ํฉ๋๋ค.
๐ช Double (64bit) ์ ๋ฐ๋ (Double (64bit) precision)#
ํ์ฌ JAX๋ ๊ธฐ๋ณธ์ ์ผ๋ก NumPy API๊ฐ ํผ์ฐ์ฐ์๋ฅผ ๊ฐ์ ๋ก ๋๋ธํ(double)์ผ๋ก ๋ณํํ๋ ๊ฒฝํฅ์ ์ํํ๊ธฐ ์ํด ๋จ์ ๋ฐ๋(single-precision) ์ซ์๋ฅผ ๊ฐ์ ๋ก ์ ์ฉํ๊ณ ์์ต๋๋ค. ์ด๋ ๋ง์ ๋จธ์ ๋ฌ๋ ์ ํ๋ฆฌ์ผ์ด์ ์์ ์ํ๋ ๋์์ด์ง๋ง, ์ด๋ ์์์น ๋ชปํ ๊ฒฐ๊ณผ๋ฅผ ์ด๋ํ ์ ์์ต๋๋ค!
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype
Double ์ ๋ฐ๋์ ์ซ์๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด, ์์ ์ jax_enable_x64 ๊ตฌ์ฑ ๋ณ์๋ฅผ ์ค์ ํด์ผ ํฉ๋๋ค.
์ด๋ฅผ ์ํํ๊ธฐ ์ํ ๋ช๊ฐ์ง ๋ฐฉ๋ฒ์ด ์์ต๋๋ค.
JAX_ENABLE_X64=True๋ก ์ค์ ํ์ฌ 64๋นํธ ๋ชจ๋๋ฅผ ์ฌ์ฉ ๊ฐ๋ฅํ๊ฒ ํ ์ ์์ต๋๋ค.์์ ์์ย
jax_enable_x64ย ๊ตฌ์ฑ ํ๋๊ทธ๋ฅผ ์๋์ผ๋ก ์ค์ ํ ์ ์์ต๋๋ค:
# again, this only works on startup!
from jax.config import config
config.update("jax_enable_x64", True)
absl.app.run(main)์ ์ฌ์ฉํ์ฌ ๋ช ๋ น์ค ํ๋๊ทธ๋ฅผ ํ์ฑํ ์ ์์ต๋๋ค.
from jax.config import config
config.config_with_absl()
absl.app.run(main)๋ฅผ ์ฌ์ฉํ์ง ์๊ณ JAX๊ฐ absl ํ์ฑ์ ์ํํ๊ฒ ํ๋ ค๋ฉด ๋ค์๊ณผ ์ฌ์ฉํ๋ฉด ๋ฉ๋๋ค:
from jax.config import config
if __name__ == '__main__':
# calls config.config_with_absl() *and* runs absl parsing
config.parse_flags_with_absl()
#2-#4๋ JAX์ ๋ชจ๋ ๊ตฌ์ฑ ์ต์ ์์ ์๋ํฉ๋๋ค.
๊ทธ๋ฐ ๋ค์ x64 ๋ชจ๋๊ฐ ํ์ฑํ๋์๋์ง ํ์ธํ ์ ์์ต๋๋ค.
import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
์ฃผ์์ฌํญ#
โ ๏ธ XLA๋ ๋ชจ๋ ๋ฐฑ์๋์์ 64๋นํธ ์ปจ๋ณผ๋ฃจ์ ์ ์ง์ํ์ง ์์ต๋๋ค!
๐ช NumPy์์ ์ ๋๋ ์ฌ๋ฌ๊ฐ์ง ํ์๋ค (Miscellaneous Divergences from NumPy)#
jax.numpy๋ Numpy API ๋์์ ์ ์ฌํ๊ฒ ๋์ํ๋๋ก ์ค๊ณ๋์์ง๋ง, ๋์์ด ๋ค๋ฅธ ์ฝ๋์ผ์ด์ค๋ค์ด ์์ต๋๋ค. ์ด๋ฌํ ๋ง์ ๊ฒฝ์ฐ๋ค์ ์ ์น์
์์ ์์ธํ ์ค๋ช
ํฉ๋๋ค. ์ฌ๊ธฐ์๋ API๊ฐ ๋ค๋ฅด๊ฒ ๋์ํ๋ ๋ช ๊ฐ์ง ๋ค๋ฅธ ์ฌ๋ก๋ค์ ๋์ปํฉ๋๋ค.
๋ฐ์ด๋๋ฆฌ ์์ ์ ๊ฒฝ์ฐ JAX์ ์ ํ ์น๊ฒฉ ๊ท์น์ NumPy์์ ์ฌ์ฉํ๋ ๊ท์น๊ณผ ๋ค์ ๋ค๋ฆ ๋๋ค. ์์ธํ ๋ด์ฉ์ Type Promotion Semantics๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
์์ ํ์ง ์์ ์ ํ ์บ์คํ (์ฆ, ๋์ dtype์ด ์ ๋ ฅ ๊ฐ์ ๋ํ๋ผ ์ ์๋ ์บ์คํ )๋ฅผ ์ํํ ๋ JAX์ ๋์์ ๋ฐฑ์๋์ ๋ฐ๋ผ ๋ค๋ฅผ ์ ์์ผ๋ฉฐ ์ผ๋ฐ์ ์ผ๋ก NumPy์ ๋์๊ณผ ๋ค๋ฅผ ์ ์์ต๋๋ค. Numpy๋ ์บ์คํ ์ธ์๋ฅผ ํตํด ์ด๋ฌํ ์๋๋ฆฌ์ค์์ ๊ฒฐ๊ณผ๋ฅผ ์ ์ดํ ์ ์์ต๋๋ค(np.ndarray.astype ์ฐธ์กฐ). JAX๋ XLA:ConvertElementType์ ๋์์ ์ง์ ์์ํ๋ ๋์ ์ด๋ฌํ ๊ตฌ์ฑ์ ์ ๊ณตํ์ง ์์ต๋๋ค.
์ฌ๊ธฐ์ NumPy์ JAX์ ์์ ํ์ง ์์ ์บ์คํ ์ ๋ฐ๋ฅธ ๋ค๋ฅธ ๊ฒฐ๊ณผ์ ๋ํ ์์ ์ ๋๋ค.
np.arange(254.0, 258.0).astype('uint8')
array([254, 255, 0, 1], dtype=uint8)
jnp.arange(254.0, 258.0).astype('uint8')
DeviceArray([254, 255, 255, 255], dtype=uint8)
์ด๋ฌํ ์ข ๋ฅ์ ๋ถ์ผ์น๋ ์ผ๋ฐ์ ์ผ๋ก ๋ถ๋์์ ์ ์ ์ ํ์ผ๋ก ๋๋ ๊ทธ ๋ฐ๋๋ก ๊ทน๋จ์ ์ธ ๊ฐ์ ์บ์คํ ํ ๋ ๋ฐ์ํฉ๋๋ค.
Fin.#
์ฌ๊ธฐ์์ ๋ค๋ฃจ์ง ์์ ๋ช๋ช ๋น์ ์ ํ๋๊ฒ ํ๋ ์์ธ๋ค์ ์ ๋ณดํด์ฃผ์๋ฉด ํด๋น ํํ ๋ฆฌ์ผ ํ์ด์ง์ ๋ฐ์ํ๊ฒ ์ต๋๋ค.