배치 정규화 (Batch normalization)#
번역: 장혜선
이 가이드에서는 flax.linen.BatchNorm을 사용하여 배치 정규화를 적용하는 방법을 알아볼 것입니다.
배치 정규화는 학습 속도를 높이고 수렴을 개선하기 위해 사용되는 정규화 기법입니다. 학습 중, 특징 차원을 기준으로 이동 평균을 계산합니다. 이로 인해 비미분 (non-differentiable) 상태의 새로운 형태가 추가되므로 적절히 처리해야 합니다.
가이드에서 Flax의 BatchNorm 을 사용한 코드 예제와 사용하지 않은 코드 예제를 비교할 수 있습니다.
배치 정규화를 사용하여 모델 정의하기#
Flax에서 BatchNorm은 학습과 추론 사이에 서로 다른 런타임(runtime, 실행 시간?) 동작을 나타내는 flax.linen.Module입니다. 이는 아래 예시와 같이 use_running_average 인수를 사용하여 명시적으로 지정할 수 있습니다.
일반적인 패턴은 부모(parent) Flax 모듈(Module)에서 학습(학습 중인) 인수를 받아와 BatchNorm의 use_running_average 인수를 정의하는 것 입니다.
참고: Pytorch나 TensorFlow (Keras)와 같은 다른 머신 러닝 프레임워크에서는 이를 가변 상태 또는 호출 플래그(flag)를 통해 지정합니다. (예를 들어, torch.nn.Module.eval이나 tf.keras.Model에서 학습 플래그를 설정하는 것)
No BatchNorm
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(features=4)(x)
x = nn.relu(x)
x = nn.Dense(features=1)(x)
return x
With BatchNorm
class MLP(nn.Module):
@nn.compact
def __call__(self, x, train: bool):
x = nn.Dense(features=4)(x)
x = nn.BatchNorm(use_running_average=not train)(x)
x = nn.relu(x)
x = nn.Dense(features=1)(x)
return x
모델을 생성한 후에는 flax.linen.init()를 호출하여 variable 구조를 초기화합니다. 여기서 BatchNorm을 사용하지 않는 코드와 BatchNorm을 사용하는 코드의 주요 차이점은 train인수가 꼭 제공되어야 한다는 것입니다.
batch_stats 컬렉션(collection)#
BatchNorm은 params 컬렉션 외에도 배치 통계의 이동 평균을 포함하는 batch_stats 컬렉션을 추가합니다.
참고: flax.linen 변수 API 문서에서 더 자세한 내용을 확인할 수 있습니다.
batch_stats 컬렉션은 나중에 사용할 수 있도록 variables에서 꼭 추출되어야 합니다.
No BatchNorm
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.PRNGKey(0), x)
params = variables['params']
jax.tree_util.tree_map(jnp.shape, variables)
With BatchNorm
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.PRNGKey(0), x, train=False)
params = variables['params']
batch_stats = variables['batch_stats']
jax.tree_util.tree_map(jnp.shape, variables)
Flax의 BatchNorm은 총 4개의 변수를 추가합니다: batch_stats 컬렉션에 있는 mean과 var, 그리고 params 컬렉션에 있는 scale과 bias 입니다.
No BatchNorm
FrozenDict({
'params': {
'Dense_0': {
'bias': (4,),
'kernel': (3, 4),
},
'Dense_1': {
'bias': (1,),
'kernel': (4, 1),
},
},
})
With BatchNorm
FrozenDict({
'batch_stats': {
'BatchNorm_0': {
'mean': (4,),
'var': (4,),
},
},
'params': {
'BatchNorm_0': {
'bias': (4,),
'scale': (4,),
},
'Dense_0': {
'bias': (4,),
'kernel': (3, 4),
},
'Dense_1': {
'bias': (1,),
'kernel': (4, 1),
},
},
})
flax.linen.apply 수정하기#
flax.linen.apply를 사용하여 train==True 인수와 모델을 실행할 때 (즉, BatchNorm 호출에서 use_running_average==False로 설정한 경우) 다음 사항을 고려해야 합니다.:
batch_stats는 입력 변수로 전달되어야 합니다.batch_stats컬렉션은mutable=['batch_stats]로 설정하여 가변으로 표시되어야 합니다.변경된 변수는 두 번째 출력으로 반환됩니다. 업데이트 된
batch_stats는 여기에서 추출되어야 합니다.
No BatchNorm
y = mlp.apply(
{'params': params},
x,
)
...
With BatchNorm
y, updates = mlp.apply(
{'params': params, 'batch_stats': batch_stats},
x,
train=True, mutable=['batch_stats']
)
batch_stats = updates['batch_stats']
학습 및 평가#
BatchNorm을 사용하는 모델을 학습 루프에 통합할 때 가장 큰 문제는 추가 batch_stats 상태를 처리하는 것 입니다. 이를 수행하기 위해서는 다음과 같이 해야합니다.:
사용자 정의
flax.training.train_state.TrainState클래스에batch_stats필드를 추가합니다.batch_stats값을train_state.TrainState.create메서드에 전달합니다.
No BatchNorm
from flax.training import train_state
state = train_state.TrainState.create(
apply_fn=mlp.apply,
params=params,
tx=optax.adam(1e-3),
)
With BatchNorm
from flax.training import train_state
class TrainState(train_state.TrainState):
batch_stats: Any
state = TrainState.create(
apply_fn=mlp.apply,
params=params,
batch_stats=batch_stats,
tx=optax.adam(1e-3),
)
또한 이러한 변경 사항을 반영하도록 train_step 함수를 업데이트 해야합니다.:
모든 새 매개변수를
flax.linen.apply에 전달합니다. (이전 설명대로)batch_stats에 대한updates는loss_fn에서 전파되어야 합니다.TrainState의batch_stats를 업데이트 해야합니다.
No BatchNorm
@jax.jit
def train_step(state: TrainState, batch):
"""Train for a single step."""
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
x=batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics
With BatchNorm
@jax.jit
def train_step(state: TrainState, batch):
"""Train for a single step."""
def loss_fn(params):
logits, updates = state.apply_fn(
{'params': params, 'batch_stats': state.batch_stats},
x=batch['image'], train=True, mutable=['batch_stats'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
return loss, (logits, updates)
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, (logits, updates)), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
state = state.replace(batch_stats=updates['batch_stats'])
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics
eval_step은 훨씬 간단합니다. 왜냐하면 batch_stats가 가변적이지 않기 때문에 업데이트를 전파할 필요가 없습니다. 다만 flax.linen.apply에 batch_stats를 전달하고 train 인수가 False로 설정되어 있는지 확인해야 합니다.:
No BatchNorm
@jax.jit
def eval_step(state: TrainState, batch):
"""Train for a single step."""
logits = state.apply_fn(
{'params': params},
x=batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics
With BatchNorm
@jax.jit
def eval_step(state: TrainState, batch):
"""Train for a single step."""
logits = state.apply_fn(
{'params': params, 'batch_stats': state.batch_stats},
x=batch['image'], train=False)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics