파라미터와 상태 관리하기#

번역: 이영빈 \

이번 세션에서는 다음과 같은 내용을 살펴볼 것이다.

  • 초기화부터 업데이트까지 변수 관리하기

  • 파라미터와 상태를 분리하고 재결합하기

  • 배치 종속적인 상태와 함께 vmap 사용하기


! pip install --upgrade pip jax jaxlib
! pip install --upgrade git+https://github.com/google/flax.git
! pip install optax
from functools import partial

import jax
from jax import numpy as jnp
from jax import random
import optax

from flax import linen as nn



# 랜덤 변수 초기화하기
dummy_input = jnp.ones((32, 5))

X = random.uniform(random.PRNGKey(0), (128, 5),  minval=0.0, maxval=1.0)
noise = random.uniform(random.PRNGKey(0), (),  minval=0.0, maxval=0.1)
X += noise

W = random.uniform(random.PRNGKey(0), (5, 1),  minval=0.0, maxval=1.0)
b = random.uniform(random.PRNGKey(0), (),  minval=0.0, maxval=1.0)

Y = jnp.matmul(X, W) + b

num_epochs = 5
class BiasAdderWithRunningMean(nn.Module):
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    is_initialized = self.has_variable('batch_stats', 'mean')
    mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      mean.value = (self.momentum * mean.value +
                    (1.0 - self.momentum) * jnp.mean(x, axis=0, keepdims=True))
    return mean.value + bias

이번 예시 모델은 파라미터(self.param)와 상태 변수(self.variables)를 모두 포함하는 간단한 모델이다.

여기서 초기화할 때 까다로운 부분은 상태 변수와 최적화할 매개변수를 분리해야 한다는 점입니다.

먼저 update_step을 다음과 같이 정의한다.(더미 손실 함수를 사용하여 사용자 함수로 대체해야 함):

@partial(jax.jit, static_argnums=(0,))
def update_step(apply_fn, x, opt_state, params, state):
  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum() # Replace with your loss here.
    return l, updated_state

  (l, updated_state), grads = jax.value_and_grad(
      loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state

실제 학습 코드를 작성할 수 있게 된다.

model = BiasAdderWithRunningMean()
variables = model.init(random.PRNGKey(0), dummy_input)
# 옵티마이저에 의해 업데이트 된 상태변수와 파라미터들을 나눈다.

state, params = variables.pop('params')
del variables  # 리소스를 낭비하는 변수를 제거한다.
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for epoch_num in range(num_epochs):
  opt_state, params, state = update_step(
      model.apply, dummy_input, opt_state, params, state)

##배치 차원에 걸친 vmap

vmap을 사용하면서 배치 차원에 따라 달라지는 상태를 관리하는 경우(예: BatchNorm을 사용하는 경우)에는 위의 설정을 약간 수정해야 한다. 배치 차원에 따라 상태가 달라지는 레이어는 엄밀히 말해 벡터화할 수 없기 때문이다. 배치 차원에 대한 통계의 평균을 내기 위해 lax.pmean()을 사용하여 배치의 각 항목에 대해 상태가 동기화되도록 해야 한다.

이를 위해서는 두 가지 작은 변경점이 있다. 먼저 모델 정의에서 배치 축의 이름을 지정해야 한다. 여기서는 BatchNormaxis_name 인수를 지정하여 이 작업을 수행한다. 사용자 코드에서는 lax.pmean()axis_name 인수를 직접 지정해야 할 수도 있다.

class MLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x, train=False):
    norm = partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        axis_name="batch", # 배치 차원의 이름을 짓는다.
    )

    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    y = nn.Dense(self.out_size)(x)

    return y

둘째, 훈련 코드에서 vmap을 호출할 때 동일한 이름을 지정해야 한다.

@partial(jax.jit, static_argnums=(0,))
def update_step(apply_fn, x_batch, y_batch, opt_state, params, state):

  def batch_loss(params):
    def loss_fn(x, y):
      pred, updated_state = apply_fn(
        {'params': params, **state},
        x, mutable=list(state.keys())
      )
      return (pred - y) ** 2, updated_state

    loss, updated_state = jax.vmap(
      loss_fn, out_axes=(0, None),  # `updated_state`을 vmap하면 안된다.
      axis_name='batch'  # 배치 차원의 이름을 짓는다.
    )(x_batch, y_batch)  # `state`는 빼고 `x`와 `y`에 vmap을 설정해라
    return jnp.mean(loss), updated_state

  (loss, updated_state), grads = jax.value_and_grad(
    batch_loss, has_aux=True
  )(params)

  updates, opt_state = tx.update(grads, opt_state)  # 아래와 같이 정의한다.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state, loss

Note that we also need to specify that the model state does not have a batch dimension. Now we are able to train the model:

모델 상태에 배치 차원이 없음을 명시해야 한다. 이제 모델을 훈련할 수 있다.

model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.PRNGKey(0), dummy_input)
# 옵티마이저에 의해 업데이트 된 상태변수와 파라미터들을 나눈다.
state, params = variables.pop('params')
del variables # 리소스를 낭비하는 변수를 제거한다.
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for epoch_num in range(num_epochs):
  opt_state, params, state, loss = update_step(
      model.apply, X, Y, opt_state, params, state)
  print(f"Loss for epoch {epoch_num + 1}:", loss)