๐ชJAX - ์ธ๋ถ์ ์ธ ํน์ง๋ค (JAX - The Sharp Bits)๐ช#
์ ์ : Anselm Levskaya & Mattew Johnson
์ญ์ : ์ฅ์ง์ฐ(wdfrty1234@gmail.com)
๊ฒ์ : ์ด์๋น, ๋ฐ์ ํ
์ดํ๋ฆฌ์ ๋ณ๋๋ฆฌ๋ฅผ ๊ฑท๋ค๋ณด๋ฉด, ์ฌ๋๋ค์ด
JAX
์ ๋ํด *โuna anima di pura programmazione funzionale(์์ํ ํจ์ํ ํ๋ก๊ทธ๋๋ฐ์ ์ํผ)โ*์ด๋ผ๋ ์ฉ์ด๋ก ๋ฌ์ฌํ๋ ์ฌ๋๋ค์ด ๋ง๋ค๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
JAX
๋ ์์นํด์ ํ๋ก๊ทธ๋จ์ ๋ณํ์ ์ํ ์ธ์ด๋ก, CPU ํน์ ๊ฐ์๊ธฐ(GPU/TPU)์์ ์์นํ ํ๋ก๊ทธ๋จ์ ์ปดํ์ผํ์ฌ ๋์์ํฌ ์ ์์ต๋๋ค. ํน์ ํ ์ ์ฝ์กฐ๊ฑด์ ์ถฉ์กฑํ๋ ๊ฒฝ์ฐ, JAX
๋ ์์น ๋ฐ ๊ณผํ ํ๋ก๊ทธ๋จ์์ ์ ๋์ํฉ๋๋ค!
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
๐ช์์ ํจ์ (Pure functions)#
JAX
๋ (๋ชจ๋ ์
๋ ฅ ๋ฐ์ดํฐ๊ฐ ํจ์์ ๋งค๊ฐ๋ณ์๋ฅผ ํตํด ์ ๋ฌ๋๊ณ , ๋ชจ๋ ์ถ๋ ฅ ๊ฒฐ๊ณผ๊ฐ ํจ์์ ๊ฒฐ๊ณผ๋ฅผ ํตํด์ ๋์ค๋) ์์ ํจ์์์ ์ ๋์ํ๋๋ก ์ค๊ณ๋์ด ์์ต๋๋ค.
(์ญ์ ์ฃผ) ์์ ํจ์๋ ๊ฐ์ ์ ๋ ฅ์ด ์ฃผ์ด์ง๋ค๋ฉด ํญ์ ๊ฐ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํํ๋ ํจ์๋ฅผ ์๋ฏธํฉ๋๋ค.
ํ์ง๋ง, ํ์ง๋ง ์์ํจ์๊ฐ ์๋ ๊ฒฝ์ฐ JAX
๋ Python ์ธํฐํ๋ฆฌํฐ์ ๋ค๋ฅด๊ฒ ๋์ํ ์ ์์ต๋๋ค. ์๋๋ ๊ทธ ์์ ์
๋๋ค. ์ด์ ๊ฐ์ ์๋ JAX
์์คํ
์์์ ๋์์ด ๋ณด์ฅ๋์ง ์์ต๋๋ค.
๋ฐ๋ผ์, JAX
๋ฅผ ์ฌ์ฉํ๋ ๊ฐ์ฅ ์ ์ ํ ๋ฐฉ๋ฒ์ ํจ์์ ์ผ๋ก ์์ํ Pythonํจ์์ ๋ํด์๋ง ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค.
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
g = 0.
def impure_uses_globals(x):
return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
g = 0.
def impure_saves_global(x):
global g
g = x
return x
# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g) # Saved global has an internal JAX value
๋ง์ฝ Python ํจ์๊ฐ ์ค์ ๋ก ์คํ ์ดํธํ ๊ฐ์ฒด๋ฅผ ๋ด๋ถ์ ์ผ๋ก ์ฌ์ฉํ๋๋ผ๋, ์ด๋ฅผ ์ธ๋ถ์์ ์ฝ๊ฑฐ๋ ์ฐ์ง ์๋๋ค๋ฉด ํจ์์ ์ผ๋ก ์์ํ๋ค๊ณ ๋ณผ ์ ์์ต๋๋ค.
def pure_uses_internal_state(x):
state = dict(even=0, odd=0)
for i in range(10):
state['even' if i % 2 == 0 else 'odd'] += x
return state['even'] + state['odd']
print(jit(pure_uses_internal_state)(5.))
jit
์ ์ฌ์ฉํ๋ ค๋ JAX
ํจ์๋ ์ด๋ค ์ ์ด ํ๋ฆ ๊ตฌ์ฑ์์์์ iterators๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๊ถ์ฅ๋์ง ์์ต๋๋ค. ๊ทธ ์ด์ ๋ iterator(๋ฐ๋ณต์)๊ฐ ๋ค์ ์์๋ฅผ ๊ฐ์ ธ์ค๊ธฐ ์ํ ์ํ(state)๋ฅผ ์ฐพ๊ธฐ ์ํ ํ์ด์ฌ ๊ฐ์ฒด์ด๊ธฐ ๋๋ฌธ์
๋๋ค. ์ด๋ฌํ ์ด์ ๋ก, iterators๋ JAX์ ํจ์์ ํ๋ก๊ทธ๋๋ฐ ๋ชจ๋ธ๊ณผ ํธํ๋์ง ์์ต๋๋ค. ์๋ ์ฝ๋๋ JAX์์ iterators๋ฅผ ์ฌ์ฉํ๋ ค๋ ๋ถ์ ์ ํ ์๋๋ค์ ๋ํ ์์๋ฅผ ๋ณด์ฌ์ค๋๋ค. ์ด๋ฌํ ์์ ๋ค ๋๋ถ๋ถ์ ์ค๋ฅ๋ฅผ ๋ฐํํ์ง๋ง, ์ด๋ค ๊ฒฝ์ฐ์๋ ์์์น ๋ชปํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค ์๋ ์์ต๋๋ค.
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0
# lax.scan
def func11(arr, extra):
ones = jnp.ones(arr.shape)
def body(carry, aelems):
ae1, ae2 = aelems
return (carry + ae1 * ae2 + extra, carry)
return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
๐ชIn-Place ์ ๋ฐ์ดํธ (In-Place Updates)#
Numpy๋ฅผ ์ฌ์ฉํ ๋, ์ฌ๋ฌ๋ถ๋ค์ ์ข ์ข ์ด๋ ๊ฒ ์ฌ์ฉํ ๊ฒ์ ๋๋ค.
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
ํ์ง๋ง, ๋ง์ฝ JAX device array์ in-place๋ก ์ ๋ฐ์ดํธ๋ฅผ ์๋ํ๋ฉด, ์ค๋ฅ๊ฐ ๋ฐ์ํ๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. (โ_โ)
%xmode Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
๋ณ์์ in-place ๋ณํ์ ํ์ฉํ๋ ๊ฒ์ ํ๋ก๊ทธ๋จ์ ๋ถ์๊ณผ ๋ณํ์ ์ด๋ ต๊ฒ ๋ง๋๋ ์์ธ์ ๋๋ค. JAX์์๋ ํ๋ก๊ทธ๋จ์ด ์์ ํจ์์ฌ์ผ ํ๋ค๋ ๊ฒ์ ๊ธฐ์ตํฉ์๋ค.
JAX
์์๋ in-place ๊ธฐ๋ฒ ๋์ ์, JAX์์๋ JAX array์ .at
์์ฑ์ ์ด์ฉํ์ฌ ํจ์์ ๋ฐฐ์ด ์
๋ฐ์ดํธ๋ฅผ ์ํํฉ๋๋ค. (.at
property on JAX arrays.)
โ ๏ธ jit
๋ ์ฝ๋ ์ lax.while_loop
๋๋ lax.fori_loop
๋ด๋ถ์์ ์ฌ๋ผ์ด์ค์ ํฌ๊ธฐ๋ ์ธ์ ๊ฐ์ ํจ์๊ฐ ์๋๋ผ ์ธ์ ํํ์ ํจ์์ฌ์ผ ๊ฐ๋ฅํฉ๋๋ค. - ์ฌ๋ผ์ด์ค ์์ ์ธ๋ฑ์ค์๋ ๊ทธ๋ฌํ ์ ํ์ด ์์ต๋๋ค. ์๋์ ์ ์ด ํ๋ฆ ๋ถ๋ถ์์ ์ด๋ฌํ ์ ์ฝ์ ๋ํ ์ ๋ณด๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
๋ฐฐ์ด ์
๋ฐ์ดํธ : x.at[idx].set(y)
์๋ฅผ ๋ค์ด, ์ ๋ฐ์ดํธ๋ ์๋์ ๊ฐ์ด ์์ฑํ ์ ์์ต๋๋ค.
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
JAX์ ์ ๋ฐ์ดํธ ํจ์๋ NumPy์๋ ๋ค๋ฅด๊ฒ out-of-place๋ก ๋์ํฉ๋๋ค. ์ฆ, ์ ๋ฐ์ดํธ๋ ๋ฐฐ์ด์ ์ ๋ฐฐ์ด๋ก ๋ฐํ๋๋ฉฐ ์๋ ๋ฐฐ์ด์ ์ ๋ฐ์ดํธ๋ก ์์ ๋์ง ์์ต๋๋ค.
print("original array unchanged:\n", jax_array)
ํ์ง๋ง, jit
์ผ๋ก ์ปดํ์ผ ๋ ์ฝ๋ ๋ด์์ x.at[idx].set(y)
์ ์
๋ ฅ ๊ฐ x๊ฐ ์ฌ์ฌ์ฉ๋์ง ์๋ ๊ฒฝ์ฐ, ์ปดํ์ผ๋ฌ๋ in-place๋ก ๋ฐฐ์ด์ด ์
๋ฐ์ดํธ ๋๋๋ก ์ต์ ํํ ์ ์์ต๋๋ค.
๋ค๋ฅธ ์ฐ์ฐ๊ณผ ํจ๊ป ๋ฐฐ์ด ์ ๋ฐ์ดํธ#
์ธ๋ฑ์ค๊ฐ ์ง์ ๋ ๋ฐฐ์ด์ ์ ๋ฐ์ดํธ๋ ๋จ์ํ ๊ฐ์ ๋ฎ์ด์ฐ๋ ๊ฒ์๋ง ์ ํ๋์ง๋ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ์๋์ ์์์ ๊ฐ์ด ์ธ๋ฑ์ค์ ๋ง์ ์ ํ๋ ์ฐ์ฐ๋ ์ํํ ์ ์์ต๋๋ค.
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
๋ณด๋ค ๋ ์์ธํ ์ธ๋ฑ์ค๋ ๋ฐฐ์ด์ ์
๋ฐ์ดํธ์ ๊ด๋ จํ์ฌ์๋, ํด๋น ๋ฌธ์๋ฅผ ์ฐธ๊ณ ํด์ฃผ์ธ์. (documentatiaon for the .at
property)
๐ช ๋ฒ์๋ฅผ ๋ฒ์ด๋ ์ธ๋ฑ์ฑ (Out-of-Bound Indexing)#
NumPy์์๋ ์ฌ๋ฌ๋ถ์ด ์ธ๋ฑ์ค ๋ฐฐ์ด์ ์ธ๋ฑ์ค ๋ฒ์๋ฅผ ๋ฒ์ด๋๋ ๋์์ ์ํํ๋ฉด ์๋์ ๊ฐ์ ์๋ฌ๋ฅผ ๋ณผ ์ ์์ต๋๋ค.
np.arange(10)[11]
ํ์ง๋ง, ๊ฐ์๊ธฐ์์ ๋์ํ๋ ์ฝ๋๋ก๋ถํฐ ์๋ฌ๋ฅผ ๋ฐ์์ํค๋ ๊ฒ์ ์ด๋ ต๊ฑฐ๋ ์ฌ์ง์ด๋ ๋ถ๊ฐ๋ฅํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ฏ๋ก, JAX๋ ๋ฐฐ์ด์ ๋ฒ์๋ฅผ ๋ฒ์ด๋๋ ์ธ๋ฑ์ฑ์ ๋ํด์ ์ค๋ฅ๊ฐ ์๋ ๋์์ ์ ํํด์ผ ํฉ๋๋ค. (์ ํจํ์ง ์์ ๋ถ๋ ์์์ ์ ์ฐ์ ์ ๊ฒฐ๊ณผ๊ฐ NaN์ด ๋๋ ๊ฒ๊ณผ ์ ์ฌํฉ๋๋ค.). ๋ง์ฝ ์ธ๋ฑ์ฑ ์์
์ด ๋ฐฐ์ด ์ธ๋ฑ์ค ์
๋ฐ์ดํธ(์: index_add
๋๋ scatter
-์ ์ฌํ ๊ธฐ๋ณธ ์์)์ธ ๊ฒฝ์ฐ, ๋ฒ์๋ฅผ ๋ฒ์ด๋ ์ธ๋ฑ์ค์ ์
๋ฐ์ดํธ๋ ๊ฑด๋๋๋๋ค. ์์
์ด ๋ฐฐ์ด ์ธ๋ฑ์ค ๊ฒ์(์: NumPy ์ธ๋ฑ์ฑ ๋๋ gather
-์ ์ฌ ๊ธฐ๋ณธ ์์)์ธ ๊ฒฝ์ฐ, ๋ฌด์ธ๊ฐ๋ฅผ ๋ฐํํด์ผ ํ๋ฏ๋ก ์ธ๋ฑ์ค๊ฐ ๋ฐฐ์ด์ ๋ฒ์์ ๊ณ ์ ๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์๋์ ์ธ๋ฑ์ฑ ๋์์์๋ ๋ฐฐ์ด์ ๋ง์ง๋ง ๊ฐ์ด ๋ฐํ๋ ๊ฒ์
๋๋ค.
jnp.arange(10)[11]
์ธ๋ฑ์ค ๊ฒ์์ ๋ํ ์ด๋ฌํ ๋์์ผ๋ก ์ธํด jnp.nanargmin
๋ฐ jnp.nanargmax
์ ๊ฐ์ ํจ์๋ NaN์ผ๋ก ๊ตฌ์ฑ๋ ์ฌ๋ผ์ด์ค์ ๋ํด -1์ ๋ฐํํ์ง๋ง Numpy๋ ์ค๋ฅ๋ฅผ ๋ฐ์์ํต๋๋ค.
์์์ ์ค๋ช ํ ๋ ๊ฐ์ง ๋์์ด ์๋ก ์ญ์ ๊ด๊ณ๊ฐ ์๋๊ธฐ ๋๋ฌธ์, ์ญ๋ฐฉํฅ ์๋ ๋ฏธ๋ถ(์ธ๋ฑ์ค ์ ๋ฐ์ดํธ๋ฅผ ์ธ๋ฑ์ค ๊ฒ์์ผ๋ก ๋ณํํ๊ณ ๊ทธ ๋ฐ๋๋ก ์ ํ)์ ๋ฒ์๋ฅผ ๋ฒ์ด๋ ์ธ๋ฑ์ฑ์ ์๋ฏธ๋ฅผ ๋ณด์กดํ์ง ์์ต๋๋ค. ๋ฐ๋ผ์ JAX์ ๋ฒ์๋ฅผ ๋ฒ์ด๋ ์ธ๋ฑ์ฑ์ ์ ์๋์ง ์์ ๋์์ผ๋ก ์๊ฐํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
๐ช ๋น๋ฐฐ์ด ์ ๋ ฅ: NumPy vs. Jax (Non-array inputs: NumPy vs. JAX)#
NumPy๋ ์ผ๋ฐ์ ์ผ๋ก Python์ ๋ฆฌ์คํธ ๋๋ ํํ์ API ํจ์์ ๋ํ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํฉ๋๋ค.
np.sum([1, 2, 3])
JAX๋ ์ด์ ๋ค๋ฅด๊ฒ ์ผ๋ฐ์ ์ผ๋ก ์ ์ฉํ ์ค๋ฅ๋ฅผ ๋ฐํํฉ๋๋ค.
jnp.sum([1, 2, 3])
์ด๋ ์๋์ ์ผ๋ก ์ค๊ณ๋ ๊ฒฐ๊ณผ์ ๋๋ค. ๊ทธ ์ด์ ๋ ์ถ์ ๋ ํจ์์ ๋ฆฌ์คํธ๋ ํํ์ ์ ๋ฌํ๊ฒ ๋๋ฉด ๊ฐ์งํ๊ธฐ ์ด๋ ค์ด ์กฐ์ฉํ ์ฑ๋ฅ์ ์ ํ๊ฐ ๋ฐ์ํ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
์๋ฅผ ๋ค์ด, ๋ฆฌ์คํธ์ ์
๋ ฅ์ ํ์ฉํ๋ ๋ค์ ๋ฒ์ ์ jnp.sum
์ ๊ณ ๋ คํด๋ด
์๋ค.
def permissive_sum(x):
return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x)
๊ฒฐ๊ณผ๋ ์์ํ ๋๋ก์ด์ง๋ง ์ฌ๊ธฐ์๋ ์ ์ฌ์ ์ธ ์ฑ๋ฅ ๋ฌธ์ ๊ฐ ์จ๊ฒจ์ ธ ์์ต๋๋ค. JAX์ ์ถ์ ๋ฐ JIT ์ปดํ์ผ ๋ชจ๋ธ์์ Python์ ๋ฆฌ์คํธ ํน์ ํํ์ ๊ฐ ์์๊ฐ ๋ณ๋์ JAX ๋ณ์๋ก ์ทจ๊ธ๋๋ฉฐ, ์ด๋ฌํ ๋ณ์๋ค์ ๊ฐ๋ณ์ ์ผ๋ก ์ฒ๋ฆฌ๋์ด ๋๋ฐ์ด์ค๋ก ์ ์ก๋ฉ๋๋ค. ์ด๋ ์์ permissive_sum
ํจ์์ ๋ํ jaxpr์์ ๋ณผ ์ ์์ต๋๋ค.
make_jaxpr(permissive_sum)(x)
๋ฆฌ์คํธ์ ๊ฐ ํญ๋ชฉ์ ๋ณ๋์ ์ ๋ ฅ์ผ๋ก ์ฒ๋ฆฌ๋๋ฏ๋ก, ๋ฆฌ์คํธ์ ํฌ๊ธฐ์ ๋ฐ๋ผ ์ ํ์ ์ผ๋ก ์ฆ๊ฐํ๋ ์ถ์ ๋ฐ ์ปดํ์ผ ์ค๋ฒํค๋๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ๋ฐฉ์งํ๊ธฐ ์ํด, JAX๋ ๋ฆฌ์คํธ ๋ฐ ํํ์ ๋ฐฐ์ด๋ก ์์์ ์ผ๋ก ๋ณํํ๋ ๊ฒ์ ํผํฉ๋๋ค.
๋ฐ๋ผ์, ํํ ๋๋ ๋ฆฌ์คํธ๋ฅผ JAX ํจ์์ ์ ๋ฌํ๋ ค๋ฉด ๋จผ์ ๋ช ์์ ์ผ๋ก ๋ฐฐ์ด๋ก ๋ณํํ ํ ์ ๋ฌํด์ผ ํฉ๋๋ค.
jnp.sum(jnp.array(x))
๐ช ๋์ (Random Numbers)#
rand()๋ก ์ธํด ๊ฒฐ๊ณผ๊ฐ ์์ฌ์ค๋ฌ์ด ๋ชจ๋ ๊ณผํ ๋ ผ๋ฌธ๋ค์ด ๋์๊ด ์ฑ ์ฅ์์ ์ฌ๋ผ์ง๋ค๋ฉด ๊ฐ ์ฑ ์ฅ์๋ ์ฃผ๋จน๋งํ ๊ฐ๊ฒฉ์ด ์๊ธธ ๊ฒ๋๋ค. - Numerical Recipes
RNGs์ State
์ฌ๋ฌ๋ถ๋ค์ NumPy ๋ฐ ๊ธฐํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์คํ ์ดํธํ ์์ฌ ๋์ ์์ฑ๊ธฐ(PRNG)์ ์ต์ํ ๊ฒ์ ๋๋ค. ์ด๋ฌํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ค์ ์์ฌ ๋์์ ์์ค๋ฅผ ์ ๊ณตํ๊ธฐ ์ํด ๋ง์ ์ธ๋ถ ์ ๋ณด๋ค์ ๋ฐฑ๊ทธ๋ผ์ด๋์์ ์ ์ฉํ๊ฒ ์จ๊น๋๋ค.
print(np.random.random())
print(np.random.random())
print(np.random.random())
๋ฐฑ๊ทธ๋ผ์ด๋์์ numpy๋ Mersenne Twister PRNG๋ฅผ ์ฌ์ฉํ์ฌ ์์ฌ ๋์ ๊ธฐ๋ฅ์ ๊ฐํํฉ๋๋ค. PRNG์ ์ฃผ๊ธฐ๋ \(2^{19937} - 1\)์ด๊ณ ์ด๋ ์์ ์์๋ 624๊ฐ์ 32๋นํธ ๋ถํธ ์๋ ์ ์์ ์ด โ์ํธ๋กํผโ๊ฐ ์ผ๋ง๋ ๋ง์ด ์ฌ์ฉ๋์๋์ง์ ๋ํ ์์น๋ก ์ค๋ช ํ ์ ์์ต๋๋ค.
np.random.seed(0)
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
์ด ์์ฌ ๋์ ์ํ ๋ฒกํฐ๋ ๋์๊ฐ ํ์ํ ๋๋ง๋ค ๋ฐฑ๊ทธ๋ผ์ด๋์์ ์๋์ ์ผ๋ก ์ ๋ฐ์ดํธ๋์ด Mersenne twister ์ํ ๋ฒกํฐ์ uint32 ์ค 2๊ฐ๋ฅผ โ์๋นโํฉ๋๋ค.
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
Magic PRNG ์ํ์ ๋ฌธ์ ๋ ์๋ก ๋ค๋ฅธ ์ค๋ ๋, ํ๋ก์ธ์ค ๋ฐ ์ฅ์น์์ ์ฌ์ฉ ๋ฐ ์ ๋ฐ์ดํธ ๋๋ ๋ฐฉ์์ ๋ํด ์ถ๋ก ํ๊ธฐ ์ด๋ ต๊ณ ์ํธ๋กํผ ์์ฑ ๋ฐ ์๋น์ ๋ํ ์ธ๋ถ ์ ๋ณด๊ฐ ์ต์ข ์ฌ์ฉ์์๊ฒ ์จ๊ฒจ์ ธ ์์ ๋ ๋ฌธ์ ๋ฅผ ์ผ์ผํค๊ธฐ ๋งค์ฐ ์ฝ๋ค๋ ๊ฒ์ ๋๋ค.
๋ํ, Mersenne Twister PRNG๋ ๋ง์ ๋ฌธ์ ๊ฐ ์๋ ๊ฒ์ผ๋ก ์๋ ค์ ธ ์์ผ๋ฉฐ, 2.5Kb์ ํฐ ์ํ ํฌ๊ธฐ๋ฅผ ๊ฐ์ง๊ณ ์์ด ์ด๊ธฐํ ๋ฌธ์ ๋ฅผ ์ผ๊ธฐํ ์ ์์ต๋๋ค. ๋ํ, ์ต์ BigCrush ํ ์คํธ๋ฅผ ๋ง์กฑํ์ง ๋ชปํ๊ณ ์ผ๋ฐ์ ์ผ๋ก ๋๋ฆฌ๋ค๋ ๋จ์ ์ด ์์ต๋๋ค.
JAX PRNG
JAX๋ ๋์ PRNG ์ํ๋ฅผ ๋ช ์์ ์ผ๋ก ์ ๋ฌํ๊ณ ๋ฐ๋ณตํ์ฌ ์ํธ๋กํผ ์์ฑ ๋ฐ ์๋น๋ฅผ ์ฒ๋ฆฌํ๋ ๋ช ์์ PRNG๋ฅผ ๊ตฌํํ์ต๋๋ค. JAX๋ ๋ถํ ๊ฐ๋ฅํ ์ต์ Threefry counter ๊ธฐ๋ฐ PRNG๋ฅผ ์ฌ์ฉํฉ๋๋ค(Threefry counter-based PRNG). ์ฆ, ์ด๋ฌํ ์ค๊ณ๋ฅผ ํตํด PRNG ์ํ๋ฅผ ๋ณ๋ ฌ ํ๋ฅ ์ ์์ฑ์ ์ํด ์ฌ์ฉํ๊ธฐ ์ํด ์๋ก์ด PRNG๋ก ๋ถ๊ธฐํ ์ ์์ต๋๋ค.
๋ฌด์์ ์ํ๋ ํค(key
)๋ผ๊ณ ๋ถ๋ฅด๋ ๋ ๊ฐ์ unsigned-int32๋ก ์ค๋ช
๋ฉ๋๋ค.
from jax import random
key = random.PRNGKey(0)
key
JAX์ ์์ ํจ์๋ PRNG ์ํ์์ ์์ฌ ๋์๋ฅผ ์์ฑํ์ง๋ง ์ํ๋ฅผ ๋ณ๊ฒฝํ์ง๋ ์์ต๋๋ค!
๋์ผํ ์ํ๋ฅผ ์ฌ์ฌ์ฉํ๋ ํ์๋ ์ฌํ๊ณผ ๋จ์กฐ๋ก์์ ์ ๋ฐํ๋ฉฐ ๊ฒฐ๊ตญ ์ต์ข
์ฌ์ฉ์์๊ฒ ํผ๋์ ๋ถ์ด๋ฃ๋ ๊ฒฐ๊ณผ๋ฅผ ์ด๋ํ ์ ์์ต๋๋ค!
(Reusing the same state will cause sadness and monotony, depriving the end user of lifegiving chaos:)
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
๋์ , ์๋ก์ด ์์ฌ ๋์๊ฐ ํ์ํ ๋๋ง๋ค PRNG๋ฅผ ๋ถํ ํ์ฌ ์ฌ์ฉ ๊ฐ๋ฅํ ํ์ ํค๋ฅผ ์ป์ต๋๋ค.
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
์๋ก์ด ๋์๊ฐ ํ์ํ ๋๋ง๋ค ํค๋ฅผ ์ ํํ๊ณ ์ ํ์ ํค๋ฅผ ๋ง๋ญ๋๋ค.
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
ํ ๋ฒ์ ๋ ์ด์์ ํ์ํค๋ฅผ ๋ง๋ค ์ ์์ต๋๋ค.
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
print(random.normal(subkey, shape=(1,)))
๐ช ์ ์ด ํ๋ฆ (Control Flow)#
โ python control_flow + autodiff โ
Python ํจ์์ grad
๋ฅผ ์ ์ฉํ๋ ค๋ ๊ฒฝ์ฐ Autograd(๋๋ Pytorch ๋๋ TF Eager)๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ฒ๋ผ ๋ฌธ์ ์์ด ์ผ๋ฐ์ ์ธ Python ์ ์ด ํ๋ฆ ๊ตฌ์ฑ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
Python control flow + JIT
jit
์ ํจ๊ป ์ ์ด ํ๋ฆ์ ์ฌ์ฉํ๋ ๊ฒ์ ๋ ๋ณต์กํ๋ฉฐ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ ๋ง์ ์ ์ฝ์ด ์์ต๋๋ค.
์ด ์์๋ ๋์ํฉ๋๋ค.
@jit
def f(x):
for i in range(3):
x = 2 * x
return x
print(f(3))
์๋ ์์๋ ๋์ํฉ๋๋ค.
@jit
def g(x):
y = 0.
for i in range(x.shape[0]):
y = y + x[i]
return y
print(g(jnp.array([1., 2., 3.])))
ํ์ง๋ง ์ด ์์๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋์ํ์ง ์์ต๋๋ค.
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
# This will fail!
f(2)
์ ๊ทธ๋ด๊น!?
ํจ์๋ฅผ jit
์ปดํ์ผํ ๋ ์ผ๋ฐ์ ์ผ๋ก ์ปดํ์ผ๋ ์ฝ๋๋ฅผ ์บ์ํ์ฌ ๋ค์ํ ์ธ์ ๊ฐ์ ๋ํด ์๋ํ๋ ํจ์ ๋ฒ์ ์ ์ปดํ์ผํ์ฌ ์ฌ์ฌ์ฉํ ์ ์๋๋ก ํฉ๋๋ค. ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ๊ฐ ํจ์ ํ๊ฐ๋ง๋ค ๋ค์ ์ปดํ์ผํ ํ์๊ฐ ์์ต๋๋ค.
์๋ฅผ ๋ค์ด jnp.array([1., 2., 3.], jnp.float32)
๋ฐฐ์ด์์ @jit
ํจ์๋ฅผ ํ๊ฐํ๊ธฐ ์ํด jnp.array([4., 5., 6.], jnp.float32)
์์ ์ฌ์ฉํ๋ ์ปดํ์ผ๋ ์ฝ๋๋ฅผ ์ฌ์ฌ์ฉํ์ฌ ํ์ฌ ์ปดํ์ผ์ ์ํ๋๋ ์๊ฐ์ ์ ์ฝํ ์ ์์ต๋๋ค.
JAX๋ Python ์ฝ๋์ ๋ค์ํ ์ธ์ ๊ฐ์ ์ ํจํ ๋ทฐ๋ฅผ ์ป๊ธฐ ์ํด JAX๋ ๊ฐ๋ฅํ ์ ๋ ฅ ์งํฉ์ ๋ํ๋ด๋ ์ถ์ ๊ฐ์ผ๋ก ์ฝ๋๋ฅผ ์ถ์ ํฉ๋๋ค. ๋ค์ํ ์ถ์ํ ์์ค์ด ์์ผ๋ฉฐ ์๋ก ๋ค๋ฅธ ๋ณํ์ ์๋ก ๋ค๋ฅธ ์ถ์ํ ์์ค์ ์ฌ์ฉํฉ๋๋ค.
๊ธฐ๋ณธ์ ์ผ๋ก jit
์ ShapedArray
์ถ์ํ ์์ค์์ ์ฝ๋๋ฅผ ์ถ์ ํฉ๋๋ค. ๊ฐ ์ถ์ ๊ฐ์ ๊ณ ์ ๋ ๋ชจ์๊ณผ dtype์ ๊ฐ์ง๋ ๋ชจ๋ ๋ฐฐ์ด ๊ฐ์ ์งํฉ์ ๋ํ๋
๋๋ค. ์๋ฅผ ๋ค์ด ์ถ์ ๊ฐ ShapedAray((3,), jnp.float32)
๋ฅผ ์ฌ์ฉํ์ฌ ์ถ์ ํ๋ฉด ํด๋น ๋ฐฐ์ด ์ธํธ์ ๊ตฌ์ฒด์ ์ธ ๊ฐ์ ๋ํด ์ฌ์ฌ์ฉํ ์ ์๋ ํจ์์ ๋ทฐ๋ฅผ ์ป์ ์ ์์ต๋๋ค. ์ฆ, ์ปดํ์ผ ์๊ฐ์ ์ค์ผ ์ ์๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
๊ทธ๋ฌ๋ ์ฌ๊ธฐ์๋ ์ฅ๋จ์ ์ด ์์ต๋๋ค. ํน์ ๊ตฌ์ฒด์ ์ธ ๊ฐ์ด ๊ฒฐ์ ๋์ง ์์ ShapedArray((), jnp.float32)
์์ Python ํจ์๋ฅผ ์ถ์ ํ๋ ๊ฒฝ์ฐ if x < 3
๊ณผ ๊ฐ์ ์ค์ ๋๋ฌํ๋ฉด ํํ์ x < 3
์ {True, False}
์งํฉ์ ๋ํ๋ด๋ ์ถ์ ShapedArray((), jnp.bool_)
๋ก ํ๊ฐ๋ฉ๋๋ค. Python์ด ์ด๋ฅผ ๊ตฌ์ฒด์ ์ธ True
๋๋ False
๋ก ๊ฐ์ ํ๋ ค๊ณ ํ๋ฉด ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด๋ค ๋ถ๊ธฐ๋ฅผ ์ ํํด์ผ ํ ์ง ๋ชจ๋ฅด๊ณ ์ถ์ ์ ๊ณ์ํ ์ ์์ต๋๋ค! ๋จ์ ์ ์ถ์ํ ์์ค์ด ๋์์๋ก Python ์ฝ๋์ ๋ํ ๋ณด๋ค ์ผ๋ฐ์ ์ธ ๋ทฐ๋ฅผ ์ป์ ์ ์์ง๋ง(๋ฐ๋ผ์ ์ฌ์ปดํ์ผ์ ์ค์ผ ์ ์์ต๋๋ค.) ์ถ์ ์ ์๋ฃํ๋ ค๋ฉด Python ์ฝ๋์ ๋ ๋ง์ ์ ์ฝ์ด ํ์ํ๋ค๋ ๊ฒ์
๋๋ค.
์ข์ ์์์, ์ด ํธ๋ ์ด๋์คํ๋ฅผ ์ง์ ์ ์ดํ ์ ์๋ค๋ ๊ฒ์
๋๋ค. ๋ณด๋ค ์ ๋ฐํ ์ถ์ ๊ฐ์ ๋ํ jit
์ถ์ ์ ํตํด ์ถ์ ๊ฐ๋ฅ์ฑ ์ ์ฝ์ ์ํํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด jit
์ static_argnums
์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ์ผ๋ถ ์ธ์์ ๊ตฌ์ฒด์ ์ธ ๊ฐ์ ์ถ์ ํ๋๋ก ์ง์ ํ ์ ์์ต๋๋ค. ๋ค์์ ์ด์ ๋ํ ์์ ํจ์์
๋๋ค.
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
f = jit(f, static_argnums=(0,))
print(f(2.))
๋ฃจํ๋ฅผ ํฌํจํ ๋๋ค๋ฅธ ์์ ์ ๋๋ค.
def f(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f = jit(f, static_argnums=(1,))
f(jnp.array([2., 3., 4.]), 2)
static_argnums๋ฅผ ์ด์ฉํ ํจ๊ณผ๋ก ์ธํด ๋ฃจํ๊ฐ ์ ์ ์ผ๋ก ํผ์ณ์ง๋๋ค. JAX๋ ๋ํ Unshaped์ ๊ฐ์ ๋ ๋์ ์์ค์ ์ถ์ํ์์ ์ถ์ ํ ์ ์์ง๋ง ํ์ฌ ๋ณํ์ ๊ธฐ๋ณธ๊ฐ์ ์๋๋๋ค.
โ ๏ธย ์ธ์ ๊ฐ์ ๋ฐ๋ผ ํํ๊ฐ ๋ฐ๋๋ ํจ์
์ด๋ฌํ ์ ์ด ํ๋ฆ ๋ฌธ์ ๋ ๋ณด๋ค ๋ฏธ๋ฌํ ๋ฐฉ์์ผ๋ก๋ ๋ํ๋ฉ๋๋ค. jit ํ๋ ค๋ ์์น ํจ์๋ ๋ด๋ถ ๋ฐฐ์ด์ ๋ชจ์์ ์ธ์ ๊ฐ์ ๋ฐ๋ผ ํน์ ํ ์ ์์ต๋๋ค(์ธ์ ๋ชจ์์ ๋ฐ๋ผ ํน์ ํ๋ ๊ฒ์ ๊ด์ฐฎ์ต๋๋ค). ๊ฐ๋จํ ์๋ก ์
๋ ฅ ๋ณ์ ๊ธธ์ด
์ ๋ฐ๋ผ ์ถ๋ ฅ์ด ๋ฌ๋ผ์ง๋ ํจ์๋ฅผ ๋ง๋ค์ด ๋ณด๊ฒ ์ต๋๋ค.
def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
static_argnums
๋ ์์ ์์ ๊ธธ์ด
์ ๋ณ๊ฒฝ์ด ์ฆ์ง ์์ ๊ฒฝ์ฐ์๋ ํธ๋ฆฌํ ์ ์์ง๋ง ๋ณ๊ฒฝ์ด ์ฆ์ ๊ฒฝ์ฐ ์ฌ์์ด ๋ ์ ์์ต๋๋ค!
๋ง์ง๋ง์ผ๋ก ํจ์์ ์ ์ญ์ ์ธ ๋ถ์ํจ๊ณผ๋ค์ด ์๋ ๊ฒฝ์ฐ JAX์ ์ถ์ ํ๋ก๊ทธ๋จ์ผ๋ก ์ธํด ์ด์ํ ์ผ์ด ๋ฐ์ํ ์ ์์ต๋๋ค. ์ผ๋ฐ์ ์ธ ๋ฌธ์ ๋ jitโd ํจ์ ๋ด์์ ๋ฐฐ์ด์ ์ถ๋ ฅํ๋ ค๊ณ ์๋ํ ๋ ๋ฐ์ํ ์ ์์ต๋๋ค.
@jit
def f(x):
print(x)
y = 2 * x
print(y)
return y
f(2)
๊ตฌ์กฐ์ ์ ์ด ํ๋ฆ ๊ธฐ๋ณธ ์์#
JAX์๋ ์ ์ด ํ๋ฆ์ ๋ํ ๋ค์ํ ์ต์ ์ด ๋ง์ด ์์ต๋๋ค. ์๋ฅผ ๋ค์ด ์ฌ์ปดํ์ผ์ ํผํ๊ณ ์ถ์ ๊ฐ๋ฅํ ์ ์ด ํ๋ฆ์ ์ฌ์ฉํ๋ฉด์ ํฐ ๋ฃจํ๋ฅผ ํ๊ณ ์ถ์ง ์๋ค๋ฉด ์๋์ 4๊ฐ์ง ๊ตฌ์กฐ์ ์ ์ด ํ๋ฆ ๊ธฐ๋ณธ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
lax.cond
ย differentiablelax.while_loop
ย fwd-mode-differentiablelax.fori_loop
ย fwd-mode-differentiableย in general;ย fwd and rev-mode differentiableย if endpoints are static.lax.scan
ย differentiable
cond
#
def cond(pred, true_fun, false_fun, operand):
if pred:
return true_fun(operand)
else:
return false_fun(operand)
from jax import lax
operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
jax.lax
์๋ ๋์ ์กฐ๊ฑด์ ๋ฐ๋ผ ๋ถ๊ธฐํ ์ ์๋ ๋ค๋ฅธ ๋ ๊ฐ์ ํจ์๊ฐ ์ ๊ณต๋ฉ๋๋ค.
lax.select
๋lax.cond
์ ๋ฐฐ์น ๋ฒ์ ์ด์ง๋ง, ์ ํ์ง๋ ์ด์ ์ ๊ณ์ฐ๋ ๋ฐฐ์ด๋ก ํํ๋ฉ๋๋ค.lax.switch
๋lax.cond
์ ์ ์ฌํ์ง๋ง, ์ด๋ค ์์ ํธ์ถ ๊ฐ๋ฅํ ์ ํ์ง ์ฌ์ด์ ์ ํํ ์ ์์ต๋๋ค.
๋ํ, jax.numpy
์์๋ ์ด๋ฌํ ํจ์์ ๋ํ ๋ค์์ Numpy ์คํ์ผ ์ธํฐํ์ด์ค๊ฐ ์ ๊ณต๋ฉ๋๋ค.
jnp.where
๋ 3๊ฐ์ ์ธ์๊ฐ์๋ lax.select์ Numpy ์คํ์ผ ๋ํผ(wrapper)์ ๋๋ค.jnp.piecewise
๋lax.switch
์ Numpy ์คํ์ผ ๋ํผ(wrapper)์ด์ง๋ง, ๋จ์ผ ์ค์นผ๋ผ ์ธ๋ฑ์ค ๋์ ์ ๋ถ๋ฆฌ์ธ ์กฐ๊ฑด์ ๋ชฉ๋ก์ ๋ฐ๋ผ ์ ํํฉ๋๋ค.jnp.select
๋jnp.piecewise
์ ์ ์ฌํ API๋ฅผ ๊ฐ์ง์ง๋ง, ์ ํ์ง๋ ์ฌ์ ๊ณ์ฐ๋ ๋ฐฐ์ด๋ก ์ ๊ณต๋ฉ๋๋ค. ๊ฒฐ๊ณผ์ ์ผ๋กlax.select
์ ์ฌ๋ฌ ํธ์ถ๋ก ๊ตฌํ๋ฉ๋๋ค.
while_loop
#
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
fori_loop
#
def fori_loop(start, stop, body_fun, init_val):
val = init_val
for i in range(start, stop):
val = body_fun(i, val)
return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Summary#
๐ช ๋์ ํํ (Dynamic Shapes)#
jax.jit
, jax.vmap
, jax.grad
๋ฑ๊ณผ ๊ฐ์ ๋ณํ ๋ด์์ ์ฌ์ฉ๋๋ JAX ์ฝ๋๋ ๋ชจ๋ ์ถ๋ ฅ ๋ฐฐ์ด๊ณผ ์ค๊ฐ ๋ฐฐ์ด์ด ์ ์ ๋ชจ์์ ๊ฐ์ ธ์ผ ํฉ๋๋ค. ์ฆ, ๋ชจ์์ ๋ค๋ฅธ ๋ฐฐ์ด ๋ด์ ๊ฐ์ ์์กดํ์ง ์์์ผ ํฉ๋๋ค.
์๋ฅผ ๋ค์ด, jnp.nansum
์ ๋ฒ์ ์ ์ง์ ๊ตฌํํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ์์ํ ์ ์์ต๋๋ค
def nansum(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
x_without_nans = x[mask]
return x_without_nans.sum()
JIT ๋ฐ ๊ธฐํ ๋ณํ ์ธ๋ถ์์๋ ์์๋๋ก ์๋ํฉ๋๋ค.
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
jax.jit
๋๋ ๋ค๋ฅธ ๋ณํ์ ์ด ํจ์์ ์ ์ฉํ๋ ค๊ณ ํ๋ฉด ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค.
jax.jit(nansum)(x)
๋ฌธ์ ๋ x_without_nans
์ ํฌ๊ธฐ๊ฐ x
๋ด์ ๊ฐ์ ๋ฐ๋ผ ๋์ ์ผ๋ก ๊ฒฐ์ ๋๋ค๋ ์ ์
๋๋ค. ์ข
์ข
JAX์์๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ํตํด ๋์ ์ผ๋ก ํฌ๊ธฐ ์กฐ์ ๋ ๋ฐฐ์ด์ ํ์์ฑ์ ํด๊ฒฐํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด jnp.where
์ 3๊ฐ ์ธ์ ํ์์ ์ฌ์ฉํ์ฌ NaN ๊ฐ์ 0์ผ๋ก ๋์ฒดํจ์ผ๋ก์จ ๋์ ๋ชจ์์ ํผํ๋ฉด์๋ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ๊ณ์ฐํ ์ ์์ต๋๋ค.
@jax.jit
def nansum_2(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
return jnp.where(mask, x, 0).sum()
print(nansum_2(x))
๋์ ๋ชจ์์ ๋ฐฐ์ด์ด ๋ฐ์ํ๋ ๋ค๋ฅธ ์ํฉ์์๋ ์ ์ฌํ ํธ๋ฆญ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
๐ช NaNs#
NaNs ๋๋ฒ๊น #
ํจ์ ๋๋ ๊ทธ๋๋์ธํธ์์ NaN์ด ๋ฐ์ํ๋ ์์น๋ฅผ ์ถ์ ํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด NaN ๊ฒ์ฌ๊ธฐ๋ฅผ ์ผค ์ ์์ต๋๋ค.
JAX_DEBUG_NANS=True
ย ํ๊ฒฝ ๋ณ์ ์ค์ ๋ฉ์ธ ํ์ผ ์๋จ์ ย
fromย jax.configย importย config
ย ์ยconfig.update("jax_debug_nans",True)
ย ๋ฅผ ์ถ๊ฐํ์ธ์.๋ฉ์ธ ํ์ผ์ย
fromย jax.configย importย config
ย ์ยconfig.parse_flags_with_absl()
ย ๋ฅผ ์ถ๊ฐํ์ธ์. ๊ทธ๋ฐ ๋ค์ ๋ช ๋ น ์ค ํ๋๊ทธ์ย-jax_debug_nans=True
์ ์ด์ฉํ์ฌ ์ต์ ์ ์ค์ ํ์ธ์.
์ด๋ก ์ธํด NaN ์์ฑ ์ฆ์ ๊ณ์ฐ ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด ์ต์
์ ์ผ๋ฉด XLA์์ ์์ฑ๋ ๋ชจ๋ ๋ถ๋ ์์์ ์ ํ ๊ฐ์ nan ๊ฒ์ฌ๊ฐ ์ถ๊ฐ๋ฉ๋๋ค. ์ฆ, @jit
์ ์ํ์ง ์๋ ๋ชจ๋ ๊ธฐ๋ณธ ์์
์ ๋ํด ๊ฐ์ ๋ค์ ํธ์คํธ๋ก ๊ฐ์ ธ์ ndarry๋ก ๊ฒ์ฌํฉ๋๋ค. @jit
ํ์์ ์๋ ์ฝ๋์ ๊ฒฝ์ฐ ๋ชจ๋ @jit
ํจ์์ ์ถ๋ ฅ์ ๊ฒ์ฌํ๊ณ NaN์ด ์๋ ๊ฒฝ์ฐ ์ต์ ํ๋์ง ์์ op-by-op ๋ชจ๋์์ ํจ์๋ฅผ ๋ค์ ์คํํ์ฌ ํ ๋ฒ์ ํ ๋ ๋ฒจ์ฉ @jit
์ ์ ๊ฑฐํฉ๋๋ค.
@jit
์์๋ง ๋ฐ์ํ์ง๋ง ์ต์ ํ๋์ง ์์ ๋ชจ๋์์๋ ์์ฑ๋์ง ์๋ nan๊ณผ ๊ฐ์ ๊น๋ค๋ก์ด ์ํฉ์ด ๋ฐ์ํ ์ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ ๊ฒฝ๊ณ ๋ฉ์์ง๊ฐ ์ถ๋ ฅ๋์ง๋ง ์ฝ๋๋ ๊ณ์ ์คํ๋ฉ๋๋ค.
๊ทธ๋๋์ธํธ ํ๊ฐ์ ์ญ๋ฐฉํฅ ํจ์ค์์ nans๊ฐ ์์ฑ๋๋ ๊ฒฝ์ฐ ์คํ ์ถ์ ์์ ๋ช ํ๋ ์ ์๋ก ์์ธ๊ฐ ๋ฐ์ํ๋ฉด backward_pass ํจ์ ๋ด๋ถ๋ก ์ง์
ํฉ๋๋ค. ์ด ํจ์๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๊ธฐ๋ณธ ์์
์ํ์ค๋ฅผ ์ญ์์ผ๋ก ์ํํ๋ ๊ฐ๋จํ jaxpr ์ธํฐํ๋ฆฌํฐ์
๋๋ค. ์๋ ์์์ env JAX_DEBUG_NANS=True ipython
๋ช
๋ น์ค์ ์ฌ์ฉํ์ฌ ipython repl์ ์์ํ ํ, ๋ค์์ ์คํํ์ต๋๋ค.
import jax.numpy as jnp
jnp.divide(0., 0.)
์์ฑ๋ NaN์ด ์กํ์ต๋๋ค. %debug
๋ฅผ ์คํํ๋ฉด ์ฌํ ๋๋ฒ๊ฑฐ๋ฅผ ์ป์ ์ ์์ต๋๋ค. ์ด๊ฒ์ ์๋ ์์ ์ ๊ฐ์ด @jit
์ผ๋ก ๊ฐ์ธ์ง ํจ์์์๋ ์๋ํฉ๋๋ค.
In [1]: import jax.numpy as jnp
In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
103 py_val = device_buffer.to_py()
104 if np.any(np.isnan(py_val)):
--> 105 raise FloatingPointError("invalid value")
106 else:
107 return DeviceArray(device_buffer, *result_shape)
FloatingPointError: invalid value
@jit
def f(x, y):
a = x * y
b = (x + y) / (x - y)
c = a + 2
return a + b * c
In [4]: from jax import jit
In [5]: @jit
...: def f(x, y):
...: a = x * y
...: b = (x + y) / (x - y)
...: c = a + 2
...: return a + b * c
...:
In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
2 def f(x, y):
3 a = x * y
----> 4 b = (x + y) / (x - y)
5 c = a + 2
6 return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
์ด ์ฝ๋๋ @jit
ํจ์์ ์ถ๋ ฅ์์ nan์ ๋ฐ๊ฒฌํ๋ฉด ์ต์ ํ๋์ง ์์ ์ฝ๋๋ฅผ ํธ์ถํ๋ฏ๋ก ์ฌ์ ํ ๋ช
ํํ ์คํ์ ์ถ์ ํ ์ ์์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ %debug
๋ก ์ฌํ ๋๋ฒ๊ฑฐ๋ฅผ ์คํํ์ฌ ์ค๋ฅ๋ฅผ ํ์
ํ๊ธฐ ์ํด ๋ชจ๋ ๊ฐ์ ๊ฒ์ฌํ ์ ์์ต๋๋ค.
โ ๏ธ ๋๋ฒ๊น ํ์ง ์๋ ๊ฒฝ์ฐ NaN ๊ฒ์ฌ๊ธฐ๋ฅผ ์ผ์๋ ์ ๋ฉ๋๋ค. ๋ง์ ์ฅ์น-ํธ์คํธ ์๋ณต ๋ฐ ์ฑ๋ฅ ์ ํ๊ฐ ๋ฐ์ํ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค!
โ ๏ธ NaN ๊ฒ์ฌ๊ธฐ๋ pmap์์ ์๋ํ์ง ์์ต๋๋ค. pmap ์ฝ๋์์ nans๋ฅผ ๋๋ฒ๊น ํ๋ ค๋ฉด pmap์ vmap์ผ๋ก ๊ต์ฒดํด์ผ ํฉ๋๋ค.
๐ช Double (64bit) ์ ๋ฐ๋ (Double (64bit) precision)#
ํ์ฌ JAX๋ ๊ธฐ๋ณธ์ ์ผ๋ก NumPy API๊ฐ ํผ์ฐ์ฐ์๋ฅผ ๊ฐ์ ๋ก ๋๋ธํ(double)์ผ๋ก ๋ณํํ๋ ๊ฒฝํฅ์ ์ํํ๊ธฐ ์ํด ๋จ์ ๋ฐ๋(single-precision) ์ซ์๋ฅผ ๊ฐ์ ๋ก ์ ์ฉํ๊ณ ์์ต๋๋ค. ์ด๋ ๋ง์ ๋จธ์ ๋ฌ๋ ์ ํ๋ฆฌ์ผ์ด์ ์์ ์ํ๋ ๋์์ด์ง๋ง, ์ด๋ ์์์น ๋ชปํ ๊ฒฐ๊ณผ๋ฅผ ์ด๋ํ ์ ์์ต๋๋ค!
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype
Double ์ ๋ฐ๋์ ์ซ์๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด, ์์ ์ jax_enable_x64
๊ตฌ์ฑ ๋ณ์๋ฅผ ์ค์ ํด์ผ ํฉ๋๋ค.
์ด๋ฅผ ์ํํ๊ธฐ ์ํ ๋ช๊ฐ์ง ๋ฐฉ๋ฒ์ด ์์ต๋๋ค.
JAX_ENABLE_X64=True
๋ก ์ค์ ํ์ฌ 64๋นํธ ๋ชจ๋๋ฅผ ์ฌ์ฉ ๊ฐ๋ฅํ๊ฒ ํ ์ ์์ต๋๋ค.์์ ์์ย
jax_enable_x64
ย ๊ตฌ์ฑ ํ๋๊ทธ๋ฅผ ์๋์ผ๋ก ์ค์ ํ ์ ์์ต๋๋ค:
# again, this only works on startup!
from jax.config import config
config.update("jax_enable_x64", True)
absl.app.run(main)
์ ์ฌ์ฉํ์ฌ ๋ช ๋ น์ค ํ๋๊ทธ๋ฅผ ํ์ฑํ ์ ์์ต๋๋ค.
from jax.config import config
config.config_with_absl()
absl.app.run(main)
๋ฅผ ์ฌ์ฉํ์ง ์๊ณ JAX๊ฐ absl ํ์ฑ์ ์ํํ๊ฒ ํ๋ ค๋ฉด ๋ค์๊ณผ ์ฌ์ฉํ๋ฉด ๋ฉ๋๋ค:
from jax.config import config
if __name__ == '__main__':
# calls config.config_with_absl() *and* runs absl parsing
config.parse_flags_with_absl()
#2-#4๋ JAX์ ๋ชจ๋ ๊ตฌ์ฑ ์ต์ ์์ ์๋ํฉ๋๋ค.
๊ทธ๋ฐ ๋ค์ x64 ๋ชจ๋๊ฐ ํ์ฑํ๋์๋์ง ํ์ธํ ์ ์์ต๋๋ค.
import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
์ฃผ์์ฌํญ#
โ ๏ธ XLA๋ ๋ชจ๋ ๋ฐฑ์๋์์ 64๋นํธ ์ปจ๋ณผ๋ฃจ์ ์ ์ง์ํ์ง ์์ต๋๋ค!
๐ช NumPy์์ ์ ๋๋ ์ฌ๋ฌ๊ฐ์ง ํ์๋ค (Miscellaneous Divergences from NumPy)#
jax.numpy
๋ Numpy API ๋์์ ์ ์ฌํ๊ฒ ๋์ํ๋๋ก ์ค๊ณ๋์์ง๋ง, ๋์์ด ๋ค๋ฅธ ์ฝ๋์ผ์ด์ค๋ค์ด ์์ต๋๋ค. ์ด๋ฌํ ๋ง์ ๊ฒฝ์ฐ๋ค์ ์ ์น์
์์ ์์ธํ ์ค๋ช
ํฉ๋๋ค. ์ฌ๊ธฐ์๋ API๊ฐ ๋ค๋ฅด๊ฒ ๋์ํ๋ ๋ช ๊ฐ์ง ๋ค๋ฅธ ์ฌ๋ก๋ค์ ๋์ปํฉ๋๋ค.
๋ฐ์ด๋๋ฆฌ ์์ ์ ๊ฒฝ์ฐ JAX์ ์ ํ ์น๊ฒฉ ๊ท์น์ NumPy์์ ์ฌ์ฉํ๋ ๊ท์น๊ณผ ๋ค์ ๋ค๋ฆ ๋๋ค. ์์ธํ ๋ด์ฉ์ Type Promotion Semantics๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
์์ ํ์ง ์์ ์ ํ ์บ์คํ (์ฆ, ๋์ dtype์ด ์ ๋ ฅ ๊ฐ์ ๋ํ๋ผ ์ ์๋ ์บ์คํ )๋ฅผ ์ํํ ๋ JAX์ ๋์์ ๋ฐฑ์๋์ ๋ฐ๋ผ ๋ค๋ฅผ ์ ์์ผ๋ฉฐ ์ผ๋ฐ์ ์ผ๋ก NumPy์ ๋์๊ณผ ๋ค๋ฅผ ์ ์์ต๋๋ค. Numpy๋ ์บ์คํ ์ธ์๋ฅผ ํตํด ์ด๋ฌํ ์๋๋ฆฌ์ค์์ ๊ฒฐ๊ณผ๋ฅผ ์ ์ดํ ์ ์์ต๋๋ค(np.ndarray.astype ์ฐธ์กฐ). JAX๋ XLA:ConvertElementType์ ๋์์ ์ง์ ์์ํ๋ ๋์ ์ด๋ฌํ ๊ตฌ์ฑ์ ์ ๊ณตํ์ง ์์ต๋๋ค.
์ฌ๊ธฐ์ NumPy์ JAX์ ์์ ํ์ง ์์ ์บ์คํ ์ ๋ฐ๋ฅธ ๋ค๋ฅธ ๊ฒฐ๊ณผ์ ๋ํ ์์ ์ ๋๋ค.
np.arange(254.0, 258.0).astype('uint8')
array([254, 255, 0, 1], dtype=uint8)
jnp.arange(254.0, 258.0).astype('uint8')
DeviceArray([254, 255, 255, 255], dtype=uint8)
์ด๋ฌํ ์ข ๋ฅ์ ๋ถ์ผ์น๋ ์ผ๋ฐ์ ์ผ๋ก ๋ถ๋์์ ์ ์ ์ ํ์ผ๋ก ๋๋ ๊ทธ ๋ฐ๋๋ก ๊ทน๋จ์ ์ธ ๊ฐ์ ์บ์คํ ํ ๋ ๋ฐ์ํฉ๋๋ค.
Fin.#
์ฌ๊ธฐ์์ ๋ค๋ฃจ์ง ์์ ๋ช๋ช ๋น์ ์ ํ๋๊ฒ ํ๋ ์์ธ๋ค์ ์ ๋ณดํด์ฃผ์๋ฉด ํด๋น ํํ ๋ฆฌ์ผ ํ์ด์ง์ ๋ฐ์ํ๊ฒ ์ต๋๋ค.