JAX의 고급 자동 미분#

Open in Colab

저자: Vlatimir Miklik & Matteo Hessel

역자 : 이영빈

검수 : 김한빈, 박정현

기울기(gradient)를 계산하는 것은 현대 머신러닝 기법에서 중요한 영역입니다. 이번 섹션은 현대 머신러닝에 관련된 자동 미분의 몇가지 고급 주제를 다룹니다.

JAX를 사용하는 대부분의 경우, 자동 미분 동작 방식을 이해할 필요는 없습니다. 하지만 더 깊은 이해를 위해 이 영상을 시청하는 것을 권장합니다.

자동 미분 쿡북 섹션은 JAX 백엔드에서 자동 미분 아이디어가 구현되는 방식에 대해 더 높은 수준의 자세한 설명을 제공합니다. 이는 JAX를 이해하기 위해 필수 조건은 아닙니다. 그러나 맞춤 미분 정의(defining custom derivatives)와 같은 몇가지 기능들은 이것을 이해해야 사용할 수 있습니다. 따라서 그러한 기능들을 사용해야할 때를 대비하여 알아둘 가치가 있습니다.

고계도함수(Higher-order derivatives)#

JAX의 자동미분을 통해 고계도함수를 쉽게 계산할 수 있다. 도함수를 계산하는 함수 그 자체가 미분가능한 함수이기 때문이다. 그러므로 고계도함수는 변환을 쌓는것처럼 쉽게 구현할 수 있습니다.

이는 단일 변수 사례를 통해 확인할 수 있습니다.

\(f(x) = x^3 + 2x^2 - 3x + 1\)의 도함수는 다음과 같이 계산됩니다.

import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx = jax.grad(f)

\(f\)의 고차 미분은 다음과 같습니다.

\[\begin{split} \begin{array}{l} f'(x) = 3x^2 + 4x -3\\ f''(x) = 6x + 4\\ f'''(x) = 6\\ f^{iv}(x) = 0 \end{array} \end{split}\]

JAX에서 이 모든 계산은 grad 함수를 연쇄적으로 사용하는 것만으로 쉽게 해결됩니다.

d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

위의 내용을 \(x=1\)이라 넣고 계산하면 다음과 같습니다.:

\[\begin{split} \begin{array}{l} f'(1) = 4\\ f''(1) = 10\\ f'''(1) = 6\\ f^{iv}(1) = 0 \end{array} \end{split}\]

JAX를 사용하면:

print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
4.0
10.0
6.0
0.0

다변수인 경우, 고계도함수는 더 복잡합니다. 어떤 함수의 2계도함수는 해당 함수의 헤시안 행렬로 표현될 수 있습니다. 이는 다음과 같이 정의됩니다.

\[(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.\]

여러 변수의 실수 함수의 헤시안은 (\(f: \mathbb R^n\to\mathbb R\))은 함수의 그레디언트의 자코비안 행렬과 동일하게 볼 수 있습니다. JAX는 함수의 자코비안을 계산하기 위해 jax.jacfwdjax.jacrev라는 2가지 변환을 제공합합니다. jax.jacfwd는 순뱡향 자동미분이며 jax.jacrev는 역방향향 자동미분이다. 이 변환들은 같은 답을 제공하지만 환경에 따른 효율성 차이가 있습니다. -자세한 내용은 자동미분에 대한 비디오를 참고하세요.

def hessian(f):
  return jax.jacfwd(jax.grad(f))

접곱에서도 맞는지 다시 한번 확인해보자. \(f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}\).

if \(i=j\), \(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2\). Else, \(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0\).

import jax.numpy as jnp

def f(x):
  return jnp.dot(x, x)

hessian(f)(jnp.array([1., 2., 3.]))
---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

<ipython-input-3-5df217671693> in <module>
      4   return jnp.dot(x, x)
      5 
----> 6 hessian(f)(jnp.array([1., 2., 3.]))


NameError: name 'hessian' is not defined

한편 헤시안 행렬 전체를 항상 계산할 필요는 없으며, 이는 매우 비효율적이기도 합니다. 자동미분 쿡북에서 헤시안-벡터곱(Hessian-vector product)과 같은 몇가지 트릭을 설명합니다. 헤시안-벡터곱은 헤시안 행렬 전체를 구현하지 않으면서 헤시안을 사용합니다.

만일 JAX에서 고계도함수를 사용하고자 한다면, 자동미분 쿡북을 일독하는 것을 강력히 권장합니다.

고차 최적화(Higher order optimization)

Model-Agnostic Meta-Learning(MAML) 과 같은 메타러닝 기술들은 기울기 업데이트를 통한 미분이 필요합니다. 다른 프레임워크에서 이는 꽤 번거롭지만, JAX에서는 매우 간단합니다.

def meta_loss_fn(params, data):
  """Computes the loss after one step of SGD."""
  grads = jax.grad(loss_fn)(params, data)
  return loss_fn(params - lr * grads, data)

meta_grads = jax.grad(meta_loss_fn)(params, data)

기울기 중지#

자동미분은 함수의 입력에 대한 기울기를 자동으로 계산합니다. 그러나 때때로 추가적인 제어가 필요합니다. 예를 들어 계산 그래프 중 일부에서 기울기 역전파를 원하지 않을 수도 있습니다.

예를 들어 TD(0)(temporal difference) 강화학습 업데이트를 예로 들겠습니다. 이 업데이트는 어떤 환경과 상호작용하는 경험을 통해 그 환경에서 상태의 가치를 추정할 때 사용됩니다. 상태 \(s_{t-1}\)에 있는 값 추정치 \(v_{\theta}(s_{t-1}\))가 선형 함수에 의해 파라미터화된다는걸 가정해봅시다.

# 가치함수 와 초기 매개변수
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])

우리가 보상 \(r_t\)를 관찰하고 있을때 상태 \(s_{t-1}\)에서 상태 \(s_t\)로 이동한다고 가정해봅시다.

# 전환예시
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

네트워크 매개변수들을 업데이트한 TD(0)는 다음과 같습니다.

\[ \Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1}) \]

이 업데이트는 어떠한 손실 함수의 그레디언트가 아닙니다.

그러나 만일 파라미터 \(\theta\)에서 타겟인 \(r_t + v_{\theta}(s_t)\)의 의존성을 무시한다면 해당 업데이트는 가짜 손실함수의 그레디언트로 쓰일수도 있습니다.

\[ L(\theta) = [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2 \]

어떻게 하면 JAX로 이를 구현할 수 있을까요? 우리가 유사 손실을 나이브하게 작성한다면 아래와 같습니다.

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (target - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
DeviceArray([ 2.4, -2.4,  2.4], dtype=float32)

그러나 td_update는 TD(0) 업데이트를 계산하지 않습니다.. 왜냐하면 그레디언트 계산은 \(\theta\)에서 target의 의존성을 포함할 것이기 때문입니다.

우리는 jax.lax.stop_gradient를 이용해 JAX가 \(\theta\)에서 타겟의 의존성을 무시되도록 강제할 수 있습니다.

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (jax.lax.stop_gradient(target) - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
DeviceArray([-2.4, -4.8,  2.4], dtype=float32)

이것은 마치 파라미터 \(\theta\)의존하지 않고 파라미터로 정확한 업데이트를 계산하는 것처럼 target을 계산합니다.

jax.lax.stop_gradient는 다른 상황에서도 매우 유용할 수도 있습니다. 예를 들어 만일 당신이 뉴럴 네트워크의 파라미터중 일부에게만 영향을 주기 위해 몇몇개의 손실로부터 그레디언트를 원한다면 유용하게 사용할 수 있다. 왜냐하면 다른 파라미터들은 다른 손실을을 사용해 훈련할 수 있기 떄문이다.

stop_gradient를 이용한 Straight-through 측정기 (STE)#

STE는 STE를 사용하지 않으면 미분불가능한 함수의 ‘그레디언트’를 정의할 때 쓰는 트릭입니다다. 미분불가능한 함수 \(f : \mathbb{R}^n \to \mathbb{R}^n\) 가 우리가 더 큰 함수의 일부이며 우리가 그 함수의 그레디언트를 찾는것이라고 가정해봅시다. 우리는 단순히 역전파 하는 동안에 \(f\)가 항등함수로 간주합니다. 이는 jax.lax.stop_gradient를 사용해 깔끔하게 구현됩니다.

def f(x):
  return jnp.round(x)  #미분불가능합니다.

def straight_through_f(x):
  #정확히 한 개의 기울기를 가진 Sterbenz 보조정리를 사용하여 정확히 0인 식을 만듭니다.
  zero = x - jax.lax.stop_gradient(x)
  return zero + jax.lax.stop_gradient(f(x))

print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
f(x):  3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0

샘플별(Per-example) 기울기#

대부분의 머신러닝 시스템은 계산 효율성 또는 분산 감소를 위해 데이터 배치(batches)로부터 기울기와 업데이트를 계산합니다. 하지만 때로는 배치(batch)의 특정 샘플과 연관된 기울기 및 업데이트에 접근해야 합니다.

예를 들어 기울기 크기에 따라 데이터 우선 순위를 정하거나 각 샘플 단위로 클리핑(clipping), 정규화를 적용하기 위해 필요합니다.

Pytorch, TF, Theano 등의 프레임워크에서 샘플별 기울기를 계산하는 것은 작은 일이 아닙니다. 해당 라이브러리들이 배치의 기울기를 직접 누적하기 때문입니다. 샘플별 손실을 각각 계산한 후 결과로 나온 기울기를 집계하는 것과 같은 나이브한 차선책은 매우 비효율적입니다.

JAX에서는 쉽고 효율적인 방법으로 샘플별 기울기 계산을 정의할 수 있습니다.

jit, vmap 그리고 grad 변환을 같이 조합하기만 하면 됩니다.

perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

# 테스트해봅시다.
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

이 변환을 한 번에 하나씩 살펴봅시다.

우선 td_lossjax.grad를 적용해 (배치 적용 되지 않은) 단일입력의 매개변수에 대한 손실의 기울기를 계산하는 함수를 얻습니다.

dtdloss_dtheta = jax.grad(td_loss)

dtdloss_dtheta(theta, s_tm1, r_t, s_t)
DeviceArray([-2.4, -4.8,  2.4], dtype=float32)

이 함수는 위 배열의 한 행을 계산합니다.

그리고나서 우리는 jax.vmap을 사용해 이 함수를 벡터화합니다. jax.vmap은 배치 차원에 모든 입력과 출력이 추가됩니다. 지금 입력의 배치가 있다고 하면 우리는 출력력의 배치를 얻는다. 배치의 각 출력은 입력 배치의 원소에 대한 기울기와 대응합니다.

almost_perex_grads = jax.vmap(dtdloss_dtheta)

batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

이것은 우리가 원하는 것이 아닙니다. 왜냐하면 우리는 이 함수를 수동으로 theta의 배치들을 줘야 하는데 우리는 실질적으로 하나의 theta를 사용하고 싶기 때문입니다. 우리는 in_axesjax.vmap을 추가해서 해결할 수 있습니다. 이때 theta는 None으로 치고 다른 매개변수들은 0으로 지정합니다. 이 방식은 결과로 나온 함수를 만들고 다른 매개변수들만 추가 축을 더하고 theta를 배치가 되지 않은채로 뺀다.

inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))

inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

거의 다 왔습니다! 하지만 이대로는 목표로했던 것보다 느립니다. 컴파일된 효율적인 버전을 얻기 위해 전부 jax.jit으로 감쌉니다.

perex_grads = jax.jit(inefficient_perex_grads)

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
10.6 ms ± 4.24 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
53.9 µs ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)