Quick Start#

번역: 김한빈, 박정현

Flax에 오신 것을 환영합니다!

Flax는 JAX 위에 구축된 오픈 소스 Python 신경망 라이브러리입니다. 이 튜토리얼은 Flax Linen API를 사용하여 간단한 합성곱 신경망(CNN)을 구축하고 MNIST 데이터셋에서 이미지 분류를 위해 해당 신경망을 훈련하는 방법을 보여줍니다.

1. Flax 설치하기#

!pip install -q flax

2. 데이터 로드하기#

Flax는 어떤 데이터 로딩 파이프라인이든 사용할 수 있으며, 이 예제에서는 TFDS를 활용하는 방법을 보여줍니다. MNIST 데이터셋을 로드하고 준비하는 함수를 정의하고, 샘플을 부동 소수점 숫자로 변환하는 함수입니다.

import tensorflow_datasets as tfds  # TFDS for MNIST
import tensorflow as tf             # TensorFlow operations

def get_datasets(num_epochs, batch_size):
  """Load MNIST train and test datasets into memory."""
  train_ds = tfds.load('mnist', split='train')
  test_ds = tfds.load('mnist', split='test')

  train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'],
                                                           tf.float32) / 255.,
                                          'label': sample['label']}) # normalize train set
  test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'],
                                                         tf.float32) / 255.,
                                        'label': sample['label']}) # normalize test set

  train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
  train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
  test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
  test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency

  return train_ds, test_ds

3. 네트워크 정의하기#

Flax Linen API를 사용하여 Flax Module을 서브클래싱하여 합성곱 신경망을 생성합니다. 이 예제에서 사용하는 아키텍처는 비교적 간단하므로-레이어를 단순히 쌓는 것- call 메소드 내에서 인라인 서브모듈을 직접 정의하고 @compact 데코레이터로 감싸는 방식으로 구현할 수 있습니다. Flax Linen @compact 데코레이터에 대해 자세히 알아보려면 “Setup vs Compact 가이드”를 참조하시기 바랍니다.

from flax import linen as nn  # Linen API

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

View model layers#

Flax Module의 인스턴스를 생성하고, Module.tabulate 메서드를 사용하여 모델 레이어의 테이블을 시각화합니다. 이를 위해 RNG 키와 템플릿 이미지 입력을 전달합니다.

import jax
import jax.numpy as jnp  # JAX NumPy

cnn = CNN()
print(cnn.tabulate(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1))))

4. Create a TrainState#

Flax에서 일반적인 패턴은 step number, parameters, optimizer state를 포함한 전체 훈련 상태를 나타내는 단일 데이터 클래스를 생성하는 것입니다.

이러한 패턴은 매우 일반적이므로, Flax는 대부분의 기본 사용 사례를 지원하는 flax.training.train_state.TrainState 클래스를 제공합니다.

!pip install -q clu
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses
import optax                           # Common loss functions and optimizers

메트릭을 계산하기 위해 clu 라이브러리를 사용할 것입니다. clu에 대한 자세한 내용은 레포지토리와 노트북을 참조하세요.

@struct.dataclass
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output('loss')

그런 다음 metrics를 포함하는 train_state.TrainState의 서브클래스를 작성하여야 합니다. 이렇게 하면 train_step()(조금 더 아래에 코드가 있습니다)과 같은 함수에 단일 인수를 전달하여 손실을 계산하고 매개변수를 업데이트하며 동시에 메트릭을 계산할 수 있는 이점이 있습니다.

class TrainState(train_state.TrainState):
  metrics: Metrics

def create_train_state(module, rng, learning_rate, momentum):
  """Creates an initial `TrainState`."""
  params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image
  tx = optax.sgd(learning_rate, momentum)
  return TrainState.create(
      apply_fn=module.apply, params=params, tx=tx,
      metrics=Metrics.empty())

5. Training step#

아래와 같은 기능을 수행하는 함수입니다:

  • TrainState.apply_fn (Module.apply 메소드(forward pass)를 포함하는)을 사용하여 매개변수와 일괄적인 입력 이미지로 신경망을 평가합니다.

  • 미리 정의된 optax.softmax_cross_entropy_with_integer_labels()를 사용하여 교차 엔트로피 손실을 계산합니다. 이 함수는 정수 레이블을 예상하므로 레이블을 원핫 인코딩으로 변환할 필요가 없습니다.

  • jax.grad를 사용하여 손실 함수의 기울기를 계산합니다.

  • 파라미터를 업데이트하기 위해 그래디언트의 pytree를 옵티마이저에 적용합니다.

JAX의 @jit 데코레이터를 사용하여 train_step 함수 전체를 추적하고 XLA로 JIT 컴파일하여 하드웨어 가속기에서 더 빠르고 효율적으로 실행되는 fused device 연산으로 변환합니다.

@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = state.apply_fn({'params': params}, batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()
    return loss
  grad_fn = jax.grad(loss_fn)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

6. Metric computation#

손실과 정확도 메트릭을 위한 별도의 함수를 작성합니다. 손실은 optax.softmax_cross_entropy_with_integer_labels 함수를 사용하여 계산하고, 정확도는 clu.metrics를 사용하여 계산합니다.

@jax.jit
def compute_metrics(*, state, batch):
  logits = state.apply_fn({'params': state.params}, batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()
  metric_updates = state.metrics.single_from_model_output(
    logits=logits, labels=batch['label'], loss=loss)
  metrics = state.metrics.merge(metric_updates)
  state = state.replace(metrics=metrics)
  return state

7. 데이터 다운로드#

num_epochs = 10
batch_size = 32

train_ds, test_ds = get_datasets(num_epochs, batch_size)

8. Seed randomness#

  • 데이터셋 셔플을 재현할 수 있도록 TF random seed를 설정합니다.(tf.data.Dataset.shuffle 사용)

  • PRNGKey를 사용해 매개변수를 초기화 합니다.( JAX PRNG 디자인PRNG chains에 대해 자세히 알아보기).

tf.random.set_seed(0)
init_rng = jax.random.PRNGKey(0)

9. TrainState 초기화#

create_train_state 함수는 모델 매개변수, 옵티마이저 및 메트릭을 초기화 합니다. 이는 학습 상태(training state) 데이터 클래스에 입력되고, 해당 데이터 클래스가 함수의 출력으로 반환됩니다.

learning_rate = 0.01
momentum = 0.9
state = create_train_state(cnn, init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

10. 학습 및 평가#

“셔플된” 데이터셋을 생성합니다.

  • 데이터셋은 학습 에폭 수만큼 반복됩니다.

  • 무작위 배치를 샘플링할 1,024 크기의 버퍼를 할당합니다. 해당 버퍼는 첫 1,024개의 샘플을 포함합니다.

    • 버퍼에서 샘플이 무작위로 추출될 때마다, 데이터셋 내 다하음 샘플이 버퍼에 로드됩니다.

학습 루프를 정의합니다.

  • 데이터셋에서 배치를 무작위 샘플링합니다.

  • 각 학습 배치마다 최적화 단계를 실행합니다.

  • 에폭의 각 배치마다 평균 학습 메트릭을 계산합니다.

  • 업데이트된 매개변수를 사용하여 테스트셋 메트릭을 계산합니다.

  • 시각화를 위해 학습 및 테스트 메트릭을 기록합니다.

10 에폭 뒤 학습 및 테스트가 완료되면, 대략 99%의 정확도가 달성된 것을 확인할 수 있습니다.

# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs
for step,batch in enumerate(train_ds.as_numpy_iterator()):

  # Run optimization steps over training batches and compute batch metrics
  state = train_step(state, batch) # get updated train state (which contains the updated parameters)
  state = compute_metrics(state=state, batch=batch) # aggregate batch metrics

  if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed
    for metric,value in state.metrics.compute().items(): # compute metrics
      metrics_history[f'train_{metric}'].append(value) # record metrics
    state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch

    # Compute metrics on the test set after each training epoch
    test_state = state
    for test_batch in test_ds.as_numpy_iterator():
      test_state = compute_metrics(state=test_state, batch=test_batch)

    for metric,value in test_state.metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)

    print(f"train epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['train_loss'][-1]}, "
          f"accuracy: {metrics_history['train_accuracy'][-1] * 100}")
    print(f"test epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['test_loss'][-1]}, "
          f"accuracy: {metrics_history['test_accuracy'][-1] * 100}")

11. 메트릭 시각화#

import matplotlib.pyplot as plt  # Visualization

# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train','test'):
  ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
  ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
plt.clf()

12. 테스트셋에서 추론 수행#

jit 컴파일된 추론 함수 pred_step을 정의합니다. 학습된 매개변수를 사용하여 테스트셋에서 모델 추론을 수행하고, 입력 이미지와 예측된 레이블을 시각화합니다.

@jax.jit
def pred_step(state, batch):
  logits = state.apply_fn({'params': state.params}, test_batch['image'])
  return logits.argmax(axis=1)

test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(state, test_batch)
fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
    ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
    ax.set_title(f"label={pred[i]}")
    ax.axis('off')