JAX로 fit()에서 일어나는 일 사용자 정의하기 (Customizing what happens in fit() with JAX)#

저자: fchollet
역자: 조현석
검수: 이영빈, 박정현
생성 날짜: 2023/06/27
마지막 수정: 2023/06/27
설명: 모델 클래스의 학습 단계를 JAX로 오버라이드합니다.

View in Colab GitHub source

소개 (Introduction)#

지도 학습을 할 때 fit()을 사용하면 모든 것이 원활하게 작동합니다.

모든 세부 사항을 제어해야 하는 경우, 자신만의 학습 루프를 완전히 처음부터 작성할 수 있습니다.

하지만 커스텀 학습 알고리즘이 필요하지만 여전히 콜백, 기본 제공 배포 지원과 같은 fit()의 편리한 기능을 활용하고 싶다면 어떻게 해야 할까요? 또는 단계 융합과 같은 편리한 기능을 활용하고 싶으신가요?

Keras의 핵심 원칙은 복잡성의 점진적 공개입니다. 사용자는 항상 점진적인 방식으로 로우 레벨의 워크플로우에 진입할 수 있어야 합니다. 하이 레벨의 기능이 사용 사례와 정확히 일치하지 않는다고 해서 갑자기 로우 레벨로 바뀌면 안 됩니다. 높은 수준의 편의성을 유지하면서 작은 세부 사항을 더 잘 제어할 수 있어야 합니다.

fit()의 기능을 커스텀해야 하는 경우, Model 클래스의 학습 단계 함수를 오버라이드해야 합니다. 이 함수는 모든 데이터 배치에 대해 fit()에 의해 호출되는 함수입니다. 그러면 평소처럼 fit()을 호출할 수 있으며, 자체 학습 알고리즘이 실행됩니다.

이 패턴은 함수형 API로 모델을 빌드하는 것을 방해하지 않습니다. ‘시퀀셜’ 모델, 함수형 API 모델 또는 하위 클래스 모델을 빌드하든 상관없이 이 작업을 수행할 수 있습니다.

어떻게 작동하는지 살펴보겠습니다.

설정 (Setup)#

import os

# 이 가이드는 JAX 백엔드에서만 실행할 수 있습니다.
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras_core as keras
import numpy as np
Using JAX backend.

첫 번째 간단한 예시 (A first simple example)#

간단한 예제부터 시작하겠습니다:

  • 우리는 keras.Model을 상속하는 새로운 클래스를 생성합니다.

  • 모델의 비학습 가능 변수에 대한 업데이트된 값과 손실을 계산하기 위해 완전 스테이트리스 compute_loss_and_updates() 메서드를 구현합니다. 내부적으로는 stateless_call()과 내장된 compute_loss()를 호출합니다.

  • 완전 스테이트리스 train_step() 메서드를 구현하여 현재 메트릭 값(손실 포함)과 학습 가능한 변수, 옵티마이저 변수, 메트릭 변수에 대한 업데이트된 값을 계산합니다.

참고로 sample_weight 인수를 다음과 같이 고려할 수도 있습니다:

  • 데이터를 x, y, sample_weight = data로 언패킹합니다.

  • sample_weight compute_loss()`에 전달합니다.

  • sample_weight yy_pred 함께 전달합니다. 함께 stateless_update_state()`의 메트릭에 전달하기

class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.compute_loss(x, y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # 그라데이션 함수를 가져옵니다.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # 그라디언트를 계산합니다.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            training=True,
        )

        # 학습 가능한 변수 및 최적화 변수 업데이트합니다.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # 메트릭을 업데이트합니다.
        new_metrics_vars = []
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # 메트릭 로그와 업데이트된 상태 변수를 반환합니다.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

한번 사용해 보겠습니다:

# CustomModel의 인스턴스를 생성하고 컴파일합니다.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# 평소처럼 'fit'을 사용하세요.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)

Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - mae: 0.4484 - loss: 0.2870
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - mae: 0.4020 - loss: 0.2704
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - mae: 0.4182 - loss: 0.2542





<keras_core.src.callbacks.history.History at 0x7be1c8068400>

로우 레벨로 해보기 (Going lower-level)#

당연히 compile()에서 손실 함수를 전달하는 것을 건너뛰고 대신 train_step에서 모든 것을 수동으로 할 수 있습니다. 메트릭도 마찬가지입니다.

다음은 compile()만 사용하여 옵티마이저를 구성하는 로우 레벨의 예제입니다:

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.loss_fn(y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # 그라데이션 함수를 가져옵니다.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # 그라디언트를 계산합니다.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            training=True,
        )

        # 학습 가능한 변수 및 최적화 변수 업데이트합니다.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # 메트릭을 업데이트합니다.
        loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]
        mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]

        loss_tracker_vars = self.loss_tracker.stateless_update_state(
            loss_tracker_vars, loss
        )
        mae_metric_vars = self.mae_metric.stateless_update_state(
            mae_metric_vars, y, y_pred
        )

        logs = {}
        logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(
            loss_tracker_vars
        )
        logs[self.mae_metric.name] = self.mae_metric.stateless_result(mae_metric_vars)

        new_metrics_vars = loss_tracker_vars + mae_metric_vars

        # 메트릭 로그와 업데이트된 상태 변수를 반환합니다.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

    @property
    def metrics(self):
        # 여기에 `Metric` 객체를 나열하여 `reset_states()`가
        # 각 에포크가 시작될 때 자동으로 호출되거나
        # 또는 `evaluate()`가 시작될 때 자동으로 호출될 수 있도록 합니다.
        return [self.loss_tracker, self.mae_metric]


# CustomModel의 인스턴스를 생성하고 컴파일합니다.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# 여기서는 손실이나 지표를 전달하지 않습니다.
model.compile(optimizer="adam")

# 평소처럼 `fit`을 사용하면 됩니다. 콜백 등을 사용할 수 있습니다.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)

Epoch 1/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.3712 - mae: 0.4860
Epoch 2/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.2637 - mae: 0.4173
Epoch 3/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.2385 - mae: 0.4012
Epoch 4/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.2407 - mae: 0.3952
Epoch 5/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.2145 - mae: 0.3782





<keras_core.src.callbacks.history.History at 0x7be1b02f3a90>

자체 평가 단계 만들기 (Providing your own evaluation step)#

model.evaluate() 호출에 대해 동일한 작업을 수행하려면 어떻게 해야 할까요? 그렇다면 test_step`을 정확히 같은 방식으로 오버라이드하면 됩니다. 이렇게 하면 됩니다:

class CustomModel(keras.Model):
    def test_step(self, state, data):
        # 데이터 언팩킹합니다.
        x, y = data
        (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = state

        # 예측과 손실을 계산합니다.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=False,
        )
        loss = self.compute_loss(x, y, y_pred)

        # 메트릭을 업데이트합니다.
        new_metrics_vars = []
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # 메트릭 로그와 업데이트된 상태 변수를 반환합니다.
        state = (
            trainable_variables,
            non_trainable_variables,
            new_metrics_vars,
        )
        return logs, state


# CustomModel의 인스턴스를 생성합니다.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# 사용자 정의 test_step으로 평가하기
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)

32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - mae: 0.6693 - loss: 0.6276





[0.6276098489761353, 0.6762693524360657]

이게 전부입니다!