🔪 JAX - The Sharp Bits 🔪


Authors: Anselm Levskaya & Matthew Johnson

Translator: 장진우 (wdfrty1234@gmail.com)

Reviewers: 이영빈, 박정현

When walking about the countryside of Italy, you will find that many people describe JAX as having *“una anima di pura programmazione funzionale”* (the soul of pure functional programming).

JAX is a language for expressing and composing transformations of numerical programs, and it can compile numerical programs to run on CPUs or accelerators (GPU/TPU). JAX works great for many numerical and scientific programs, but only if they satisfy certain constraints, which are described below.

import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

🔪 Pure functions

JAX transformation and compilation are designed to work on functionally pure Python functions: all the input data is passed through the function parameters, and all the results are output through the function results.

(Translator's note) A pure function always returns the same result for the same inputs and has no side effects, such as printing or modifying global state.

If a function is not pure, however, JAX may behave differently from the Python interpreter, as the examples below show. The behavior of such code is not guaranteed under JAX, so the right way to use JAX is to apply it only to functionally pure Python functions.

def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x

# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
g = 0.
def impure_uses_globals(x):
  return x + g

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value
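To make the global-variable example behave predictably, one option is to pass the value in as an explicit argument, so that jit treats it as an input rather than as a captured constant. Here is a minimal sketch. The function name pure_uses_arg is hypothetical:

```python
import jax

def pure_uses_arg(x, g):
  # g is now an explicit input, so every call sees the value passed in
  return x + g

f = jax.jit(pure_uses_arg)
print(f(4., 0.))   # 4.0
print(f(5., 10.))  # 15.0: a new value of g with the same shape and dtype, so no retrace
```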

A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:

def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))

jit์„ ์‚ฌ์šฉํ•˜๋ ค๋Š” JAX ํ•จ์ˆ˜๋‚˜ ์–ด๋–ค ์ œ์–ด ํ๋ฆ„ ๊ตฌ์„ฑ์š”์†Œ์—์„œ iterators๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์€ ๊ถŒ์žฅ๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๊ทธ ์ด์œ ๋Š” iterator(๋ฐ˜๋ณต์ž)๊ฐ€ ๋‹ค์Œ ์š”์†Œ๋ฅผ ๊ฐ€์ ธ์˜ค๊ธฐ ์œ„ํ•œ ์ƒํƒœ(state)๋ฅผ ์ฐพ๊ธฐ ์œ„ํ•œ ํŒŒ์ด์ฌ ๊ฐ์ฒด์ด๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์ด์œ ๋กœ, iterators๋Š” JAX์˜ ํ•จ์ˆ˜์  ํ”„๋กœ๊ทธ๋ž˜๋ฐ ๋ชจ๋ธ๊ณผ ํ˜ธํ™˜๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์•„๋ž˜ ์ฝ”๋“œ๋Š” JAX์—์„œ iterators๋ฅผ ์‚ฌ์šฉํ•˜๋ ค๋Š” ๋ถ€์ ์ ˆํ•œ ์‹œ๋„๋“ค์— ๋Œ€ํ•œ ์˜ˆ์‹œ๋ฅผ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์˜ˆ์ œ๋“ค ๋Œ€๋ถ€๋ถ„์€ ์˜ค๋ฅ˜๋ฅผ ๋ฐ˜ํ™˜ํ•˜์ง€๋งŒ, ์–ด๋–ค ๊ฒฝ์šฐ์—๋Š” ์˜ˆ์ƒ์น˜ ๋ชปํ•œ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์—ฌ์ค„ ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error

🔪 In-Place Updates

In NumPy, you are used to doing this:

numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)

If you try to update a JAX device array in place, however, you get an error! (☉_☉)

%xmode Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0

๋ณ€์ˆ˜์˜ in-place ๋ณ€ํ˜•์„ ํ—ˆ์šฉํ•˜๋Š” ๊ฒƒ์€ ํ”„๋กœ๊ทธ๋žจ์˜ ๋ถ„์„๊ณผ ๋ณ€ํ™˜์„ ์–ด๋ ต๊ฒŒ ๋งŒ๋“œ๋Š” ์š”์ธ์ž…๋‹ˆ๋‹ค. JAX์—์„œ๋Š” ํ”„๋กœ๊ทธ๋žจ์ด ์ˆœ์ˆ˜ ํ•จ์ˆ˜์—ฌ์•ผ ํ•œ๋‹ค๋Š” ๊ฒƒ์„ ๊ธฐ์–ตํ•ฉ์‹œ๋‹ค.

JAX์—์„œ๋Š” in-place ๊ธฐ๋ฒ• ๋Œ€์‹ ์—, JAX์—์„œ๋Š” JAX array์— .at ์†์„ฑ์„ ์ด์šฉํ•˜์—ฌ ํ•จ์ˆ˜์  ๋ฐฐ์—ด ์—…๋ฐ์ดํŠธ๋ฅผ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. (.at property on JAX arrays.)

โš ๏ธ jit๋œ ์ฝ”๋“œ ์™€ lax.while_loop ๋˜๋Š” lax.fori_loop ๋‚ด๋ถ€์—์„œ ์Šฌ๋ผ์ด์Šค์˜ ํฌ๊ธฐ๋Š” ์ธ์ˆ˜ ๊ฐ’์˜ ํ•จ์ˆ˜๊ฐ€ ์•„๋‹ˆ๋ผ ์ธ์ˆ˜ ํ˜•ํƒœ์˜ ํ•จ์ˆ˜์—ฌ์•ผ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. - ์Šฌ๋ผ์ด์Šค ์‹œ์ž‘ ์ธ๋ฑ์Šค์—๋Š” ๊ทธ๋Ÿฌํ•œ ์ œํ•œ์ด ์—†์Šต๋‹ˆ๋‹ค. ์•„๋ž˜์˜ ์ œ์–ด ํ๋ฆ„ ๋ถ€๋ถ„์—์„œ ์ด๋Ÿฌํ•œ ์ œ์•ฝ์— ๋Œ€ํ•œ ์ •๋ณด๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋ฐฐ์—ด ์—…๋ฐ์ดํŠธ : x.at[idx].set(y)

For example, the update above can be written as:

updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)

JAX์˜ ์—…๋ฐ์ดํŠธ ํ•จ์ˆ˜๋Š” NumPy์™€๋Š” ๋‹ค๋ฅด๊ฒŒ out-of-place๋กœ ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค. ์ฆ‰, ์—…๋ฐ์ดํŠธ๋œ ๋ฐฐ์—ด์€ ์ƒˆ ๋ฐฐ์—ด๋กœ ๋ฐ˜ํ™˜๋˜๋ฉฐ ์›๋ž˜ ๋ฐฐ์—ด์€ ์—…๋ฐ์ดํŠธ๋กœ ์ˆ˜์ •๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

print("original array unchanged:\n", jax_array)

Inside jit-compiled code, however, if the input value x of x.at[idx].set(y) is not reused, the compiler can optimize the array update to happen in place.

Array updates with other operations

์ธ๋ฑ์Šค๊ฐ€ ์ง€์ •๋œ ๋ฐฐ์—ด์˜ ์—…๋ฐ์ดํŠธ๋Š” ๋‹จ์ˆœํžˆ ๊ฐ’์„ ๋ฎ์–ด์“ฐ๋Š” ๊ฒƒ์—๋งŒ ์ œํ•œ๋˜์ง€๋Š” ์•Š์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ์•„๋ž˜์˜ ์˜ˆ์‹œ์™€ ๊ฐ™์ด ์ธ๋ฑ์Šค์— ๋ง์…ˆ์„ ํ•˜๋Š” ์—ฐ์‚ฐ๋„ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)

๋ณด๋‹ค ๋” ์ž์„ธํ•œ ์ธ๋ฑ์Šค๋œ ๋ฐฐ์—ด์˜ ์—…๋ฐ์ดํŠธ์— ๊ด€๋ จํ•˜์—ฌ์„œ๋Š”, ํ•ด๋‹น ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•ด์ฃผ์„ธ์š”. (documentatiaon for the .at property)

🔪 Out-of-Bounds Indexing

NumPy์—์„œ๋Š” ์—ฌ๋Ÿฌ๋ถ„์ด ์ธ๋ฑ์Šค ๋ฐฐ์—ด์˜ ์ธ๋ฑ์Šค ๋ฒ”์œ„๋ฅผ ๋ฒ—์–ด๋‚˜๋Š” ๋™์ž‘์„ ์ˆ˜ํ–‰ํ•˜๋ฉด ์•„๋ž˜์™€ ๊ฐ™์€ ์—๋Ÿฌ๋ฅผ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

np.arange(10)[11]

However, raising an error from code running on an accelerator can be difficult or even impossible. Therefore, JAX must choose some non-error behavior for out-of-bounds indexing (akin to how invalid floating-point arithmetic results in NaN). When the indexing operation is an array index update (e.g. index_add or scatter-like primitives), updates at out-of-bounds indices are skipped. When the operation is an array index retrieval (e.g. NumPy indexing or gather-like primitives), something must be returned, so the index is clamped to the bounds of the array. For example, this indexing operation returns the last value of the array:

jnp.arange(10)[11]
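The sketch below contrasts the two behaviors on a small array. It also uses the mode argument of .at[...].get to request a fill value instead of clamping (the fill value -1 is an arbitrary choice):

```python
import jax.numpy as jnp

x = jnp.arange(5.0)

retrieved = x[10]                  # retrieval: index clamped to the last element
updated = x.at[10].set(100.)       # update: out-of-bounds write is dropped
filled = x.at[10].get(mode='fill', fill_value=-1.)  # explicit fill instead of clamping

print(retrieved)  # 4.0
print(updated)    # [0. 1. 2. 3. 4.]
print(filled)     # -1.0
```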

์ธ๋ฑ์Šค ๊ฒ€์ƒ‰์— ๋Œ€ํ•œ ์ด๋Ÿฌํ•œ ๋™์ž‘์œผ๋กœ ์ธํ•ด jnp.nanargmin ๋ฐ jnp.nanargmax์™€ ๊ฐ™์€ ํ•จ์ˆ˜๋Š” NaN์œผ๋กœ ๊ตฌ์„ฑ๋œ ์Šฌ๋ผ์ด์Šค์— ๋Œ€ํ•ด -1์„ ๋ฐ˜ํ™˜ํ•˜์ง€๋งŒ Numpy๋Š” ์˜ค๋ฅ˜๋ฅผ ๋ฐœ์ƒ์‹œํ‚ต๋‹ˆ๋‹ค.

์œ„์—์„œ ์„ค๋ช…ํ•œ ๋‘ ๊ฐ€์ง€ ๋™์ž‘์ด ์„œ๋กœ ์—ญ์˜ ๊ด€๊ณ„๊ฐ€ ์•„๋‹ˆ๊ธฐ ๋•Œ๋ฌธ์—, ์—ญ๋ฐฉํ–ฅ ์ž๋™ ๋ฏธ๋ถ„(์ธ๋ฑ์Šค ์—…๋ฐ์ดํŠธ๋ฅผ ์ธ๋ฑ์Šค ๊ฒ€์ƒ‰์œผ๋กœ ๋ณ€ํ™˜ํ•˜๊ณ  ๊ทธ ๋ฐ˜๋Œ€๋กœ ์ „ํ™˜)์€ ๋ฒ”์œ„๋ฅผ ๋ฒ—์–ด๋‚œ ์ธ๋ฑ์‹ฑ์˜ ์˜๋ฏธ๋ฅผ ๋ณด์กดํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ JAX์˜ ๋ฒ”์œ„๋ฅผ ๋ฒ—์–ด๋‚œ ์ธ๋ฑ์‹ฑ์„ ์ •์˜๋˜์ง€ ์•Š์€ ๋™์ž‘์œผ๋กœ ์ƒ๊ฐํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค.

🔪 Non-array inputs: NumPy vs. JAX

NumPy is generally happy to accept Python lists or tuples as inputs to its API functions:

np.sum([1, 2, 3])

JAX๋Š” ์ด์™€ ๋‹ค๋ฅด๊ฒŒ ์ผ๋ฐ˜์ ์œผ๋กœ ์œ ์šฉํ•œ ์˜ค๋ฅ˜๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

jnp.sum([1, 2, 3])

์ด๋Š” ์˜๋„์ ์œผ๋กœ ์„ค๊ณ„๋œ ๊ฒฐ๊ณผ์ž…๋‹ˆ๋‹ค. ๊ทธ ์ด์œ ๋Š” ์ถ”์ ๋œ ํ•จ์ˆ˜์— ๋ฆฌ์ŠคํŠธ๋‚˜ ํŠœํ”Œ์„ ์ „๋‹ฌํ•˜๊ฒŒ ๋˜๋ฉด ๊ฐ์ง€ํ•˜๊ธฐ ์–ด๋ ค์šด ์กฐ์šฉํ•œ ์„ฑ๋Šฅ์˜ ์ €ํ•˜๊ฐ€ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

For example, consider the following permissive version of jnp.sum that allows list inputs:

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
permissive_sum(x)

The output is what we would expect, but it hides a potential performance problem. In JAX's tracing and JIT compilation model, each element of a Python list or tuple is treated as a separate JAX variable, and each is individually processed and pushed to the device. This can be seen in the jaxpr for the permissive_sum function above:

make_jaxpr(permissive_sum)(x)

๋ฆฌ์ŠคํŠธ์˜ ๊ฐ ํ•ญ๋ชฉ์€ ๋ณ„๋„์˜ ์ž…๋ ฅ์œผ๋กœ ์ฒ˜๋ฆฌ๋˜๋ฏ€๋กœ, ๋ฆฌ์ŠคํŠธ์˜ ํฌ๊ธฐ์— ๋”ฐ๋ผ ์„ ํ˜•์ ์œผ๋กœ ์ฆ๊ฐ€ํ•˜๋Š” ์ถ”์  ๋ฐ ์ปดํŒŒ์ผ ์˜ค๋ฒ„ํ—ค๋“œ๊ฐ€ ๋ฐœ์ƒํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ฌธ์ œ๋ฅผ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•ด, JAX๋Š” ๋ฆฌ์ŠคํŠธ ๋ฐ ํŠœํ”Œ์„ ๋ฐฐ์—ด๋กœ ์ž„์‹œ์ ์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ๊ฒƒ์„ ํ”ผํ•ฉ๋‹ˆ๋‹ค.

๋”ฐ๋ผ์„œ, ํŠœํ”Œ ๋˜๋Š” ๋ฆฌ์ŠคํŠธ๋ฅผ JAX ํ•จ์ˆ˜์— ์ „๋‹ฌํ•˜๋ ค๋ฉด ๋จผ์ € ๋ช…์‹œ์ ์œผ๋กœ ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜ํ•œ ํ›„ ์ „๋‹ฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

jnp.sum(jnp.array(x))
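One way to see the difference is to count the inputs of each traced jaxpr through its invars attribute. A small sketch:

```python
import jax
import jax.numpy as jnp

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))

# A Python list is a pytree: every element becomes its own jaxpr input
list_jaxpr = jax.make_jaxpr(permissive_sum)(x)
print(len(list_jaxpr.jaxpr.invars))   # 10

# Converting up front gives a single array input
array_jaxpr = jax.make_jaxpr(jnp.sum)(jnp.array(x))
print(len(array_jaxpr.jaxpr.invars))  # 1
```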

🔪 Random Numbers

If all scientific papers whose results are in doubt because of bad rand()s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist. - Numerical Recipes

RNGs and State

์—ฌ๋Ÿฌ๋ถ„๋“ค์€ NumPy ๋ฐ ๊ธฐํƒ€ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์˜ ์Šคํ…Œ์ดํŠธํ’€ ์˜์‚ฌ ๋‚œ์ˆ˜ ์ƒ์„ฑ๊ธฐ(PRNG)์— ์ต์ˆ™ํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋“ค์€ ์˜์‚ฌ ๋‚œ์ˆ˜์˜ ์†Œ์Šค๋ฅผ ์ œ๊ณตํ•˜๊ธฐ ์œ„ํ•ด ๋งŽ์€ ์„ธ๋ถ€ ์ •๋ณด๋“ค์„ ๋ฐฑ๊ทธ๋ผ์šด๋“œ์—์„œ ์œ ์šฉํ•˜๊ฒŒ ์ˆจ๊น๋‹ˆ๋‹ค.

print(np.random.random())
print(np.random.random())
print(np.random.random())

๋ฐฑ๊ทธ๋ผ์šด๋“œ์—์„œ numpy๋Š” Mersenne Twister PRNG๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์˜์‚ฌ ๋‚œ์ˆ˜ ๊ธฐ๋Šฅ์„ ๊ฐ•ํ™”ํ•ฉ๋‹ˆ๋‹ค. PRNG์˜ ์ฃผ๊ธฐ๋Š” \(2^{19937} - 1\)์ด๊ณ  ์–ด๋Š ์‹œ์ ์—์„œ๋“  624๊ฐœ์˜ 32๋น„ํŠธ ๋ถ€ํ˜ธ ์—†๋Š” ์ •์ˆ˜์™€ ์ด โ€œ์—”ํŠธ๋กœํ”ผโ€๊ฐ€ ์–ผ๋งˆ๋‚˜ ๋งŽ์ด ์‚ฌ์šฉ๋˜์—ˆ๋Š”์ง€์— ๋Œ€ํ•œ ์œ„์น˜๋กœ ์„ค๋ช…ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

np.random.seed(0)
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
#       2481403966, 4042607538,  337614300, ... 614 more numbers...,
#       3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)

์ด ์˜์‚ฌ ๋‚œ์ˆ˜ ์ƒํƒœ ๋ฒกํ„ฐ๋Š” ๋‚œ์ˆ˜๊ฐ€ ํ•„์š”ํ•  ๋•Œ๋งˆ๋‹ค ๋ฐฑ๊ทธ๋ผ์šด๋“œ์—์„œ ์ž๋™์ ์œผ๋กœ ์—…๋ฐ์ดํŠธ๋˜์–ด Mersenne twister ์ƒํƒœ ๋ฒกํ„ฐ์˜ uint32 ์ค‘ 2๊ฐœ๋ฅผ โ€œ์†Œ๋น„โ€ํ•ฉ๋‹ˆ๋‹ค.

_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)

# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)

# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
#      4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)

The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's very easy to get things wrong when the details of entropy production and consumption are hidden from the end user.

The Mersenne Twister PRNG is also known to have a number of problems. It has a large 2.5 KB state, which leads to problematic initialization issues. It also fails modern BigCrush tests, and it is generally slow.

JAX PRNG

JAX instead implements an explicit PRNG, where entropy production and consumption are handled by explicitly passing and iterating the PRNG state. JAX uses a modern, splittable Threefry counter-based PRNG. This design allows us to fork the PRNG state into new PRNGs for use in parallel stochastic generation.

The random state is described by two unsigned 32-bit integers that we call a key:

from jax import random
key = random.PRNGKey(0)
key

JAX์˜ ์ž„์˜ ํ•จ์ˆ˜๋Š” PRNG ์ƒํƒœ์—์„œ ์˜์‚ฌ ๋‚œ์ˆ˜๋ฅผ ์ƒ์„ฑํ•˜์ง€๋งŒ ์ƒํƒœ๋ฅผ ๋ณ€๊ฒฝํ•˜์ง€๋Š” ์•Š์Šต๋‹ˆ๋‹ค!

๋™์ผํ•œ ์ƒํƒœ๋ฅผ ์žฌ์‚ฌ์šฉํ•˜๋Š” ํ–‰์œ„๋Š” ์Šฌํ””๊ณผ ๋‹จ์กฐ๋กœ์›€์„ ์œ ๋ฐœํ•˜๋ฉฐ ๊ฒฐ๊ตญ ์ตœ์ข… ์‚ฌ์šฉ์ž์—๊ฒŒ ํ˜ผ๋ˆ์„ ๋ถˆ์–ด๋„ฃ๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ดˆ๋ž˜ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค!
(Reusing the same state will cause sadness and monotony, depriving the end user of lifegiving chaos:)

print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)

Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)

We propagate the key and make new subkeys whenever we need a new random number:

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)

We can produce more than one subkey at a time:

key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))
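Putting this together, a function that needs several random draws can take a key as input, split it before each draw, and return the advanced key to the caller. This is a sketch under that convention. sample_normals is a hypothetical helper, not a JAX API:

```python
from jax import random
import jax.numpy as jnp

def sample_normals(key, n):
  # hypothetical helper: threads the key explicitly, splitting before each draw
  draws = []
  for _ in range(n):
    key, subkey = random.split(key)
    draws.append(random.normal(subkey))
  return jnp.stack(draws), key  # also return the advanced key for later use

samples, next_key = sample_normals(random.PRNGKey(0), 3)
again, _ = sample_normals(random.PRNGKey(0), 3)
print(samples)
print(bool(jnp.all(samples == again)))  # True: same starting key, same numbers
```

For independent draws from one distribution, a single call such as random.normal(subkey, shape=(n,)) is usually simpler. The loop is only there to show the key being threaded through.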

🔪 Control Flow

✔ Python control flow + autodiff ✔

If you just want to apply grad to your Python functions, you can use regular Python control-flow constructs with no problems, as if you were using Autograd (or PyTorch or TF Eager).

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!

Python control flow + JIT

Using control flow with jit is more complicated, and by default it has more constraints.

์ด ์˜ˆ์‹œ๋Š” ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค.

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))

์•„๋ž˜ ์˜ˆ์‹œ๋„ ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค.

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))

ํ•˜์ง€๋งŒ ์ด ์˜ˆ์‹œ๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ ๋™์ž‘ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
f(2)

What gives!?

When we jit-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.

For example, if we evaluate an @jit function on the array jnp.array([1., 2., 3.], jnp.float32), we might want to compile code that we can reuse to evaluate the function on jnp.array([4., 5., 6.], jnp.float32), saving on compile time.
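You can observe this caching by incrementing a Python counter inside the function. Python side effects only run while tracing (see the Pure functions section), so the counter counts traces. This is an instrumentation trick for illustration, not a recommended pattern:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # instrumentation only: a Python side effect that runs at trace time

@jax.jit
def double(x):
  global trace_count
  trace_count += 1
  return 2 * x

double(jnp.array([1., 2., 3.], jnp.float32))
double(jnp.array([4., 5., 6.], jnp.float32))  # same shape/dtype: cached code reused
print(trace_count)  # 1

double(jnp.array([1., 2.], jnp.float32))      # new shape: traced and compiled again
print(trace_count)  # 2
```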

To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple levels of abstraction, and different transformations use different abstraction levels.

๊ธฐ๋ณธ์ ์œผ๋กœ jit์€ ShapedArray ์ถ”์ƒํ™” ์ˆ˜์ค€์—์„œ ์ฝ”๋“œ๋ฅผ ์ถ”์ ํ•ฉ๋‹ˆ๋‹ค. ๊ฐ ์ถ”์ƒ ๊ฐ’์€ ๊ณ ์ •๋œ ๋ชจ์–‘๊ณผ dtype์„ ๊ฐ€์ง€๋Š” ๋ชจ๋“  ๋ฐฐ์—ด ๊ฐ’์˜ ์ง‘ํ•ฉ์„ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด ์ถ”์ƒ ๊ฐ’ ShapedAray((3,), jnp.float32)๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ถ”์ ํ•˜๋ฉด ํ•ด๋‹น ๋ฐฐ์—ด ์„ธํŠธ์˜ ๊ตฌ์ฒด์ ์ธ ๊ฐ’์— ๋Œ€ํ•ด ์žฌ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ํ•จ์ˆ˜์˜ ๋ทฐ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ฆ‰, ์ปดํŒŒ์ผ ์‹œ๊ฐ„์„ ์ค„์ผ ์ˆ˜ ์žˆ๋‹ค๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.

But there's a tradeoff here. Suppose we trace a Python function on ShapedArray((), jnp.float32), which isn't committed to a specific concrete value, and we hit a line like if x < 3. The expression x < 3 then evaluates to an abstract ShapedArray((), jnp.bool_) that represents the set {True, False}. When Python attempts to coerce that to a concrete True or False, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that higher levels of abstraction give us a more general view of the Python code (and thus save on re-compilations), but they require more constraints on the Python code to complete the trace.
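The sketch below reproduces the failure and catches it as jax.errors.ConcretizationTypeError (recent JAX versions raise the more specific subclass TracerBoolConversionError). It then shows one value-independent alternative: compute both branches and choose with jnp.where. This is one option among several, alongside static_argnums and lax.cond:

```python
import jax
import jax.numpy as jnp

def f(x):
  if x < 3:          # needs a concrete bool while tracing
    return 3. * x ** 2
  else:
    return -4 * x

try:
  jax.jit(f)(2.)
  failed = False
except jax.errors.ConcretizationTypeError as e:
  failed = True
  print(type(e).__name__)

# One value-independent alternative: compute both branches, then select
@jax.jit
def f_where(x):
  return jnp.where(x < 3, 3. * x ** 2, -4 * x)

print(f_where(2.), f_where(4.))  # 12.0 -16.0
```

jnp.where evaluates both branches. That costs little here, but it is wasteful if one branch is expensive.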

์ข‹์€ ์†Œ์‹์€, ์ด ํŠธ๋ ˆ์ด๋“œ์˜คํ”„๋ฅผ ์ง์ ‘ ์ œ์–ดํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋ณด๋‹ค ์ •๋ฐ€ํ•œ ์ถ”์ƒ ๊ฐ’์— ๋Œ€ํ•œ jit ์ถ”์ ์„ ํ†ตํ•ด ์ถ”์  ๊ฐ€๋Šฅ์„ฑ ์ œ์•ฝ์„ ์™„ํ™”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด jit์— static_argnums ์ธ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ผ๋ถ€ ์ธ์ˆ˜์˜ ๊ตฌ์ฒด์ ์ธ ๊ฐ’์„ ์ถ”์ ํ•˜๋„๋ก ์ง€์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‹ค์Œ์€ ์ด์— ๋Œ€ํ•œ ์˜ˆ์ œ ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnums=(0,))

print(f(2.))

Here's another example, this time involving a loop:

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnums=(1,))

f(jnp.array([2., 3., 4.]), 2)

In effect, with static_argnums the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped, but that is not currently the default for any transformation.

⚠️ Functions with argument-value-dependent shapes

These control-flow issues also come up in a more subtle way. Numerical functions we want to jit can't specialize the shapes of internal arrays on argument values (specializing on argument shapes is OK). As a trivial example, let's make a function whose output happens to depend on the input variable length.

def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))

static_argnums can be handy if length in our example rarely changes, but it would be disastrous if it changed a lot, since every new value triggers a recompilation!
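Since specializing on argument shapes is allowed, one alternative is to take the output shape from an array argument's shape instead of from an integer's value. Here is a minimal sketch. example_fun_shaped is a hypothetical variant of example_fun:

```python
import jax
import jax.numpy as jnp

def example_fun_shaped(template, val):
  # hypothetical variant: the output shape comes from template.shape, which is
  # static under jit, rather than from the *value* of an integer argument
  return jnp.ones(template.shape) * val

f = jax.jit(example_fun_shaped)
out = f(jnp.zeros(5), 4)
print(out)  # [4. 4. 4. 4. 4.]
```

Each new length is still a new input shape, so jit still compiles once per length. The only change is that the argument no longer needs to be marked static.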

Finally, if your function has global side effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit'd functions:

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
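The plain print calls above run only at trace time, and they show tracer objects rather than values. To print runtime values from inside a jit'd function, you can use jax.debug.print instead. A minimal sketch:

```python
import jax

@jax.jit
def f(x):
  print("trace-time x:", x)             # runs once per trace; prints a tracer
  jax.debug.print("runtime x = {}", x)  # runs on every call with the actual value
  y = 2 * x
  jax.debug.print("runtime y = {}", y)
  return y

f(2)
f(3)  # same shape and dtype: no new trace, but jax.debug.print still fires
```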

๊ตฌ์กฐ์  ์ œ์–ด ํ๋ฆ„ ๊ธฐ๋ณธ ์š”์†Œ#

JAX์—๋Š” ์ œ์–ด ํ๋ฆ„์— ๋Œ€ํ•œ ๋‹ค์–‘ํ•œ ์˜ต์…˜์ด ๋งŽ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด ์žฌ์ปดํŒŒ์ผ์„ ํ”ผํ•˜๊ณ  ์ถ”์  ๊ฐ€๋Šฅํ•œ ์ œ์–ด ํ๋ฆ„์„ ์‚ฌ์šฉํ•˜๋ฉด์„œ ํฐ ๋ฃจํ”„๋ฅผ ํ’€๊ณ  ์‹ถ์ง€ ์•Š๋‹ค๋ฉด ์•„๋ž˜์˜ 4๊ฐ€์ง€ ๊ตฌ์กฐ์  ์ œ์–ด ํ๋ฆ„ ๊ธฐ๋ณธ ๊ตฌ์กฐ๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  • lax.cond: differentiable

  • lax.while_loop: fwd-mode-differentiable

  • lax.fori_loop: fwd-mode-differentiable in general; fwd- and rev-mode differentiable if endpoints are static

  • lax.scan: differentiable

cond

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)

jax.lax์—๋Š” ๋™์  ์กฐ๊ฑด์— ๋”ฐ๋ผ ๋ถ„๊ธฐํ•  ์ˆ˜ ์žˆ๋Š” ๋‹ค๋ฅธ ๋‘ ๊ฐœ์˜ ํ•จ์ˆ˜๊ฐ€ ์ œ๊ณต๋ฉ๋‹ˆ๋‹ค.

  • lax.select is like a batched version of lax.cond, with the choices expressed as pre-computed arrays rather than as functions.

  • lax.switch is like lax.cond, but allows switching between any number of callable choices.

In addition, jax.numpy provides several NumPy-style interfaces to these functions:

  • jnp.where with three arguments is the NumPy-style wrapper of lax.select.

  • jnp.piecewise is a NumPy-style wrapper of lax.switch, but it switches on a list of boolean conditions instead of a single scalar index.

  • jnp.select has an API similar to jnp.piecewise, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to lax.select.
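Here is a short sketch of lax.switch, lax.select, and the equivalent three-argument jnp.where. The branch functions and arrays are arbitrary examples:

```python
import jax.numpy as jnp
from jax import lax

# lax.switch: pick one of several callables with a (possibly traced) integer index
branches = [lambda x: x + 1., lambda x: x * 10., lambda x: -x]
switched = lax.switch(1, branches, 2.)   # calls branches[1](2.)

# lax.select: choose elementwise between two precomputed arrays
pred = jnp.array([True, False, True])
on_true = jnp.array([1., 2., 3.])
on_false = jnp.array([-1., -2., -3.])
selected = lax.select(pred, on_true, on_false)

# jnp.where with three arguments gives the same result, NumPy style
where_result = jnp.where(pred, on_true, on_false)

print(switched, selected, where_result)
```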

while_loop

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)

fori_loop

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
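The list above also mentions lax.scan, which has no section of its own here. In the same style as the cond, while_loop and fori_loop sections, here are its Python reference semantics for a one-dimensional xs, followed by a running-sum example (the step function is an illustrative assumption):

```python
import jax.numpy as jnp
from jax import lax

def scan(f, init, xs):
  # Python reference semantics of lax.scan for a 1-D xs
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, jnp.stack(ys)

def step(carry, x):
  total = carry + x
  return total, total  # new carry, and the running total emitted at this step

xs = jnp.arange(1., 6.)
total, running = lax.scan(step, jnp.array(0.), xs)
ref_total, ref_running = scan(step, jnp.array(0.), xs)
print(total)    # 15.0
print(running)  # running sums 1, 3, 6, 10, 15
```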

Summary

์Šคํฌ๋ฆฐ์ƒท 2023-02-05 ์˜คํ›„ 10.08.19.png

๐Ÿ”ช ๋™์  ํ˜•ํƒœ (Dynamic Shapes)#

JAX code used within transforms like jax.jit, jax.vmap, jax.grad, etc. requires all output arrays and intermediate arrays to have static shapes. That is, a shape cannot depend on values within other arrays.

For example, if you were implementing your own version of jnp.nansum, you might start with something like this:

def nansum(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  x_without_nans = x[mask]
  return x_without_nans.sum()

Outside JIT and other transforms, this works as expected:

x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))

jax.jit ๋˜๋Š” ๋‹ค๋ฅธ ๋ณ€ํ™˜์„ ์ด ํ•จ์ˆ˜์— ์ ์šฉํ•˜๋ ค๊ณ  ํ•˜๋ฉด ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•ฉ๋‹ˆ๋‹ค.

jax.jit(nansum)(x)

The problem is that the size of x_without_nans depends on the values within x, which is another way of saying that its size is dynamic. In JAX it is often possible to avoid the need for dynamically-sized arrays by other means. For example, here we can use the three-argument form of jnp.where to replace the NaN values with zeros, which computes the same result while avoiding dynamic shapes:

@jax.jit
def nansum_2(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  return jnp.where(mask, x, 0).sum()

print(nansum_2(x))

๋™์  ๋ชจ์–‘์˜ ๋ฐฐ์—ด์ด ๋ฐœ์ƒํ•˜๋Š” ๋‹ค๋ฅธ ์ƒํ™ฉ์—์„œ๋„ ์œ ์‚ฌํ•œ ํŠธ๋ฆญ์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

🔪 NaNs

Debugging NaNs

ํ•จ์ˆ˜ ๋˜๋Š” ๊ทธ๋ž˜๋””์–ธํŠธ์—์„œ NaN์ด ๋ฐœ์ƒํ•˜๋Š” ์œ„์น˜๋ฅผ ์ถ”์ ํ•˜๋ ค๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์ด NaN ๊ฒ€์‚ฌ๊ธฐ๋ฅผ ์ผค ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  • set the JAX_DEBUG_NANS=True environment variable;

  • ๋ฉ”์ธ ํŒŒ์ผ ์ƒ๋‹จ์— ย fromย jax.configย importย configย ์™€ย config.update("jax_debug_nans",True)ย ๋ฅผ ์ถ”๊ฐ€ํ•˜์„ธ์š”.

  • ๋ฉ”์ธ ํŒŒ์ผ์—ย fromย jax.configย importย configย ์™€ย config.parse_flags_with_absl()ย ๋ฅผ ์ถ”๊ฐ€ํ•˜์„ธ์š”. ๊ทธ๋Ÿฐ ๋‹ค์Œ ๋ช…๋ น ์ค„ ํ”Œ๋ž˜๊ทธ์—ย -jax_debug_nans=True ์„ ์ด์šฉํ•˜์—ฌ ์˜ต์…˜์„ ์„ค์ •ํ•˜์„ธ์š”.

์ด๋กœ ์ธํ•ด NaN ์ƒ์„ฑ ์ฆ‰์‹œ ๊ณ„์‚ฐ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•ฉ๋‹ˆ๋‹ค. ์ด ์˜ต์…˜์„ ์ผœ๋ฉด XLA์—์„œ ์ƒ์„ฑ๋œ ๋ชจ๋“  ๋ถ€๋™ ์†Œ์ˆ˜์  ์œ ํ˜• ๊ฐ’์— nan ๊ฒ€์‚ฌ๊ฐ€ ์ถ”๊ฐ€๋ฉ๋‹ˆ๋‹ค. ์ฆ‰, @jit์— ์†ํ•˜์ง€ ์•Š๋Š” ๋ชจ๋“  ๊ธฐ๋ณธ ์ž‘์—…์— ๋Œ€ํ•ด ๊ฐ’์„ ๋‹ค์‹œ ํ˜ธ์ŠคํŠธ๋กœ ๊ฐ€์ ธ์™€ ndarry๋กœ ๊ฒ€์‚ฌํ•ฉ๋‹ˆ๋‹ค. @jit ํ•˜์œ„์— ์žˆ๋Š” ์ฝ”๋“œ์˜ ๊ฒฝ์šฐ ๋ชจ๋“  @jitํ•จ์ˆ˜์˜ ์ถœ๋ ฅ์„ ๊ฒ€์‚ฌํ•˜๊ณ  NaN์ด ์žˆ๋Š” ๊ฒฝ์šฐ ์ตœ์ ํ™”๋˜์ง€ ์•Š์€ op-by-op ๋ชจ๋“œ์—์„œ ํ•จ์ˆ˜๋ฅผ ๋‹ค์‹œ ์‹คํ–‰ํ•˜์—ฌ ํ•œ ๋ฒˆ์— ํ•œ ๋ ˆ๋ฒจ์”ฉ @jit์„ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค.

@jit์—์„œ๋งŒ ๋ฐœ์ƒํ•˜์ง€๋งŒ ์ตœ์ ํ™”๋˜์ง€ ์•Š์€ ๋ชจ๋“œ์—์„œ๋Š” ์ƒ์„ฑ๋˜์ง€ ์•Š๋Š” nan๊ณผ ๊ฐ™์€ ๊นŒ๋‹ค๋กœ์šด ์ƒํ™ฉ์ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ๊ฒฝ์šฐ ๊ฒฝ๊ณ  ๋ฉ”์‹œ์ง€๊ฐ€ ์ถœ๋ ฅ๋˜์ง€๋งŒ ์ฝ”๋“œ๋Š” ๊ณ„์† ์‹คํ–‰๋ฉ๋‹ˆ๋‹ค.

If the NaNs are produced in the backward pass of a gradient evaluation, the exception is raised several frames up in the stack trace, inside the backward_pass function. That function is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an IPython REPL with the command line env JAX_DEBUG_NANS=True ipython, then ran this:

import jax.numpy as jnp
jnp.divide(0., 0.)

์ƒ์„ฑ๋œ NaN์ด ์žกํ˜”์Šต๋‹ˆ๋‹ค. %debug๋ฅผ ์‹คํ–‰ํ•˜๋ฉด ์‚ฌํ›„ ๋””๋ฒ„๊ฑฐ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๊ฒƒ์€ ์•„๋ž˜ ์˜ˆ์ œ์™€ ๊ฐ™์ด @jit ์œผ๋กœ ๊ฐ์‹ธ์ง„ ํ•จ์ˆ˜์—์„œ๋„ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค.

In [1]: import jax.numpy as jnp

In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

... stack trace ...

.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
    103         py_val = device_buffer.to_py()
    104         if np.any(np.isnan(py_val)):
--> 105           raise FloatingPointError("invalid value")
    106         else:
    107           return DeviceArray(device_buffer, *result_shape)

FloatingPointError: invalid value
@jit
def f(x, y):
    a = x * y
    b = (x + y) / (x - y)
    c = a + 2
    return a + b * c
In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ...

When this code sees a NaN in the output of an @jit function, it calls into the de-optimized code, so we still get a clear stack trace. We can then run a post-mortem debugger with %debug to inspect all the values and figure out the error.

โš ๏ธ ๋””๋ฒ„๊น…ํ•˜์ง€ ์•Š๋Š” ๊ฒฝ์šฐ NaN ๊ฒ€์‚ฌ๊ธฐ๋ฅผ ์ผœ์„œ๋Š” ์•ˆ ๋ฉ๋‹ˆ๋‹ค. ๋งŽ์€ ์žฅ์น˜-ํ˜ธ์ŠคํŠธ ์™•๋ณต ๋ฐ ์„ฑ๋Šฅ ์ €ํ•˜๊ฐ€ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค!

โš ๏ธ NaN ๊ฒ€์‚ฌ๊ธฐ๋Š” pmap์—์„œ ์ž‘๋™ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. pmap ์ฝ”๋“œ์—์„œ nans๋ฅผ ๋””๋ฒ„๊น…ํ•˜๋ ค๋ฉด pmap์„ vmap์œผ๋กœ ๊ต์ฒดํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

🔪 Double (64-bit) precision

At the moment, JAX enforces single-precision numbers by default, to mitigate the NumPy API's tendency to aggressively promote operands to double. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!

x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype

To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.

There are a few ways to do this:

  1. You can enable 64-bit mode by setting the environment variable JAX_ENABLE_X64=True.

  2. You can manually set the jax_enable_x64 configuration flag at startup:

# again, this only works on startup!
import jax
jax.config.update("jax_enable_x64", True)
  1. absl.app.run(main)์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ช…๋ น์ค„ ํ”Œ๋ž˜๊ทธ๋ฅผ ํŒŒ์‹ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

import jax
jax.config.config_with_absl()
  4. If you want JAX to run absl parsing for you, i.e. you don't want to call absl.app.run(main), you can instead use:

import jax
if __name__ == '__main__':
  # calls jax.config.config_with_absl() *and* runs absl parsing
  jax.config.parse_flags_with_absl()

Methods 2-4 work for any of JAX's configuration options.

๊ทธ๋Ÿฐ ๋‹ค์Œ x64 ๋ชจ๋“œ๊ฐ€ ํ™œ์„ฑํ™”๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')

์ฃผ์˜์‚ฌํ•ญ#

โš ๏ธ XLA๋Š” ๋ชจ๋“  ๋ฐฑ์—”๋“œ์—์„œ 64๋น„ํŠธ ์ปจ๋ณผ๋ฃจ์…˜์„ ์ง€์›ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค!

๐Ÿ”ช NumPy์—์„œ ์œ ๋ž˜๋œ ์—ฌ๋Ÿฌ๊ฐ€์ง€ ํŒŒ์ƒ๋“ค (Miscellaneous Divergences from NumPy)#

While jax.numpy makes every attempt to replicate the behavior of NumPy's API, there are corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above. Here we list several other known places where the APIs diverge.

  • ๋ฐ”์ด๋„ˆ๋ฆฌ ์ž‘์—…์˜ ๊ฒฝ์šฐ JAX์˜ ์œ ํ˜• ์Šน๊ฒฉ ๊ทœ์น™์€ NumPy์—์„œ ์‚ฌ์šฉํ•˜๋Š” ๊ทœ์น™๊ณผ ๋‹ค์†Œ ๋‹ค๋ฆ…๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ Type Promotion Semantics๋ฅผ ์ฐธ์กฐํ•˜์‹ญ์‹œ์˜ค.

  • When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. NumPy lets you control the result in these scenarios via the casting argument (see np.ndarray.astype). JAX does not provide any such configuration, and instead directly inherits the behavior of XLA:ConvertElementType.

์—ฌ๊ธฐ์— NumPy์™€ JAX์˜ ์•ˆ์ „ํ•˜์ง€ ์•Š์€ ์บ์ŠคํŒ…์— ๋”ฐ๋ฅธ ๋‹ค๋ฅธ ๊ฒฐ๊ณผ์— ๋Œ€ํ•œ ์˜ˆ์ œ์ž…๋‹ˆ๋‹ค.

np.arange(254.0, 258.0).astype('uint8')
array([254, 255,   0,   1], dtype=uint8)

jnp.arange(254.0, 258.0).astype('uint8')
DeviceArray([254, 255, 255, 255], dtype=uint8)

This sort of mismatch typically arises when casting extreme values from floating-point to integer types or vice versa.
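You can also check the type-promotion difference from the first bullet above directly with the promote_types functions. For example, NumPy and JAX disagree on the result type of int32 combined with float16:

```python
import numpy as np
import jax.numpy as jnp

np_result = np.promote_types(np.int32, np.float16)
jax_result = jnp.promote_types(jnp.int32, jnp.float16)
print(np_result)   # float64
print(jax_result)  # float16
```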

Fin.

์—ฌ๊ธฐ์—์„œ ๋‹ค๋ฃจ์ง€ ์•Š์€ ๋ช‡๋ช‡ ๋‹น์‹ ์„ ํ™”๋‚˜๊ฒŒ ํ•˜๋Š” ์›์ธ๋“ค์„ ์ œ๋ณดํ•ด์ฃผ์‹œ๋ฉด ํ•ด๋‹น ํŠœํ† ๋ฆฌ์–ผ ํŽ˜์ด์ง€์— ๋ฐ˜์˜ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.