JAX에서 병렬 평가(Parallel Evaluation)#

Open in Colab

저자 : Vladimir Mikulik & Roman Ring
번역 : 유현아
검수: 장혜선, 조영빈

이번 세션에서는 SPMD(single-program, multiple-data) 코드를 위해 JAX에 내장된 기능을 설명합니다.

SPMD는 동일한 계산(예: 신경망의 순전파(forward pass))이 병렬로 다른 입력 데이터(예: 배치의 다른 입력)에서 다른 디바이스(예: 여러 개의 TPU)에 대해 실행되는 병렬 처리 기술을 의미합니다.

개념적으로 이것은 동일한 작업이 동일한 디바이스의 다른 메모리 부분에서 병렬로 발생하는 벡터화와 크게 다르지 않습니다. JAX에서 프로그램 변환인 jax.vmap으로 벡터화가 지원되는 것을 이미 살펴보았습니다. JAX는 jax.pmap을 사용하여 하나의 디바이스를 대상으로 작성된 함수를 여러 디바이스에서 병렬로 실행되는 함수로 변환하여 디바이스 병렬화를 유사하게 지원합니다. 이번 세션에서 모든 것을 알려드립니다.

Colab TPU 설정(Setup)#

Google Colab에서 이 코드를 실행하는 경우 런타임런타임 유형 변경을 선택하고 하드웨어 가속기 메뉴에서 TPU를 선택해야 합니다.

하드웨어가속기설정.png

이 작업이 완료되면 다음을 실행하여 JAX와 함께 사용할 Colab TPU를 설정할 수 있습니다.

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

다음을 실행하여 사용 가능한 TPU 기기를 확인합니다.

import jax
jax.devices()
[TpuDevice(id=0, host_id=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, host_id=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, host_id=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, host_id=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, host_id=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, host_id=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, host_id=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, host_id=0, coords=(1,1,0), core_on_chip=1)]

기본(The basics)#

jax.pmap의 가장 기본적인 사용법은 jax.vmap과 완전히 유사하므로 Vectorisation notebook 에서 다룬 컨볼루션(convolution) 예제로 돌아가보겠습니다.

import numpy as np
import jax.numpy as jnp

x = np.arange(5)
w = np.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)
DeviceArray([11., 20., 29.], dtype=float32)

이제 convolve 함수를 데이터의 전체 배치에서 실행되도록 변환해 보겠습니다. 배치를 여러 디바이스에 분산시킬 것을 대비하여 배치 크기를 디바이스의 수와 동일하게 만들겠습니다.

n_devices = jax.local_device_count() 
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([w] * n_devices)

xs
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29],
       [30, 31, 32, 33, 34],
       [35, 36, 37, 38, 39]])
ws
array([[2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.]])

이전과 마찬가지로 jax.vmap을 사용하여 벡터화할 수 있습니다.

jax.vmap(convolve)(xs, ws)
DeviceArray([[ 11.,  20.,  29.],
             [ 56.,  65.,  74.],
             [101., 110., 119.],
             [146., 155., 164.],
             [191., 200., 209.],
             [236., 245., 254.],
             [281., 290., 299.],
             [326., 335., 344.]], dtype=float32)

여러 디바이스에 계산을 분산하려면 jax.vmapjax.pmap으로 바꾸면 됩니다.

jax.pmap(convolve)(xs, ws)
ShardedDeviceArray([[ 11.,  20.,  29.],
                    [ 56.,  65.,  74.],
                    [101., 110., 119.],
                    [146., 155., 164.],
                    [191., 200., 209.],
                    [236., 245., 254.],
                    [281., 290., 299.],
                    [326., 335., 344.]], dtype=float32)

병렬화된 convolve 함수는 ShardedDeviceArray를 반환한다는 것에 유의해주세요. 이는 이 배열의 요소가 병렬 처리에 사용되는 모든 디바이스에 분산되기 때문입니다. 만약, 다른 병렬 계산을 실행하는 경우, 요소는 디바이스 간 통신 비용을 발생시키지 않고 각의 디바이스에 유지됩니다.

jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws))
ShardedDeviceArray([[   78.,   138.,   198.],
                    [ 1188.,  1383.,  1578.],
                    [ 3648.,  3978.,  4308.],
                    [ 7458.,  7923.,  8388.],
                    [12618., 13218., 13818.],
                    [19128., 19863., 20598.],
                    [26988., 27858., 28728.],
                    [36198., 37203., 38208.]], dtype=float32)

내부 jax.pmap(convolve)의 출력은 외부 jax.pmap(convolve)에 입력될 때 디바이스를 떠나지 않았습니다.

‘in_axes’ 지정#

vmap과 마찬가지로 in_axes를 사용하여 병렬화된 함수의 인수가 브로드캐스트(None)할지 또는 주어진 축을 따라 분할할지 여부를 지정할 수 있습니다. 단, vmap과 달리 pmap은 이 가이드 작성 시점에서 선행 축(0)만 지원한다는 점에 유의하십시오.

jax.pmap(convolve, in_axes=(0, None))(xs, w)
ShardedDeviceArray([[ 11.,  20.,  29.],
                    [ 56.,  65.,  74.],
                    [101., 110., 119.],
                    [146., 155., 164.],
                    [191., 200., 209.],
                    [236., 245., 254.],
                    [281., 290., 299.],
                    [326., 335., 344.]], dtype=float32)

위에서 ws를 만들 때 w를 수동으로 복제한 jax.pmap(convolve)(xs, ws)과 동일한 출력을 얻는 방법에 주목하십시오. 여기서는 in_axes에서 None으로 지정하여 브로드캐스팅으로 복제하였습니다.

변환된 함수를 호출할 때, 인수의 지정된 축의 크기는 호스트에서 사용 가능한 다바이스 수를 초과해서는 안 된다는 것을 기억해주세요.

pmapjit#

jax.pmap은 작업의 일부로 주어진 함수를 JIT 컴파일하므로 jax.jit를 추가로 할 필요가 없습니다.

디바이스 간 통신#

위의 내용으로는 단순한 병렬 연산, 예를 들어 다수의 디바이스에서 간단한 MLP 순전파를 배치하는 것을 수행하기에 충분합니다. 그러나 때로는 디바이스 간 정보를 전달해야 하는 경우도 있습니다. 예를 들어, 각 디바이스의 출력을 정규화하여 합이 1이 되도록 하는 데 관심이 있을 수 있습니다.

이를 위해서는 특수한 집합 연산(collective ops)(예: jax.lax.p* ops psum, pmean, pmax, …)를 사용할 수 있습니다. 집합 연산(collective ops)을 사용하려면, axis_name 인수를 통해 pmap을 적용된 축의 이름을 지정한 다음 연산을 호출할 때 참조해야 합니다. 방법은 다음과 같습니다.

def normalized_convolution(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  output = jnp.array(output)
  return output / jax.lax.psum(output, axis_name='p')

jax.pmap(normalized_convolution, axis_name='p')(xs, ws)
ShardedDeviceArray([[0.00816024, 0.01408451, 0.019437  ],
                    [0.04154303, 0.04577465, 0.04959785],
                    [0.07492582, 0.07746479, 0.07975871],
                    [0.10830861, 0.10915492, 0.10991956],
                    [0.14169139, 0.14084506, 0.14008042],
                    [0.17507419, 0.17253521, 0.17024128],
                    [0.20845698, 0.20422535, 0.20040214],
                    [0.24183977, 0.23591548, 0.23056298]], dtype=float32)

번역하기 어렵네요 The axis_name is just a string label that allows collective operations like jax.lax.psum to refer to the axis bound by jax.pmap. It can be named anything you want – in this case, p. This name is essentially invisible to anything but those functions, and those functions use it to know which axis to communicate across.

axis_namejax.lax.psum 같은 집단 연산이 jax.pmap 이 바인딩하는 축을 참조하도록 하는 문자열 레이블입니다. 이 경우에는 p로 지정했습니다. 이 이름은 본질적으로 해당 기능 외에는 보이지 않으며 해당 함수들은 이를 사용하여 통신할 축을 알아냅니다.

jax.vmap또한 axis_name을 지원합니다. 이는 jax.lax.p*연산이 jax.pmap과 동일한 방식으로 jax.lax.p*의 벡터화 문맥에서 사용될 수 있도록 합니다.

jax.vmap(normalized_convolution, axis_name='p')(xs, ws)
DeviceArray([[0.00816024, 0.01408451, 0.019437  ],
             [0.04154303, 0.04577465, 0.04959785],
             [0.07492582, 0.07746479, 0.07975871],
             [0.10830861, 0.10915492, 0.10991956],
             [0.14169139, 0.14084506, 0.14008042],
             [0.17507419, 0.17253521, 0.17024128],
             [0.20845698, 0.20422535, 0.20040214],
             [0.24183977, 0.23591548, 0.23056298]], dtype=float32)

normalized_convolution은 더 이상 jax.pmap 또는 jax.vmap에 의해 변환되지 않으면 작동하지 않는데, jax.lax.psum이 명명된 축(p, 이 경우)이 있을 것으로 예상하고, 이 두 변환 방법이 하나로 바인딩하는 유일한 방법이기 때문입니다.

jax.pmapjax.vmap의 중첩#

The reason we specify axis_name as a string is so we can use collective operations when nesting jax.pmap and jax.vmap. For example:

axis_name을 문자열로 지정하는 이유는 jax.pmapjax.vmap을 중첩할 때 집합 연산(collective operations)을 사용할 수 있기 때문입니다. 예:

jax.vmap(jax.pmap(f, axis_name='i'), axis_name='j')

fjax.lax.psum(..., axis_name='i')axis_name을 공유하므로 pmapped 축만 참조합니다.

일반적으로 jax.pmapjax.vmap은 임의의 순서로 중첩될 수 있습니다. 예를 들어 다른 pmap 내부에 pmap이 있을 수 있습니다.

예제(Example)#

다음은 각 배치가 별도의 디바이스에서 평가되는 서브 배치로 분할되는 데이터 병렬 처리가 있는 회귀 훈련 루프(regression training loop)의 예입니다.

다음 두 가지 사항에 주의해야 합니다:

  • update() 함수

  • 매개변수(parameters) 복제 및 디바이스 간 데이터 분할.

이 예제가 너무 복잡하다면, 다음 노트북 State in JAX에서 병렬 처리가 없는 동일한 예제를 찾을 수 있습니다. 이 예제가 이해되면 병렬화가 어떻게 다른지 비교하여 병렬 처리가 어떻게 변경되는지 이해할 수 있습니다.

from typing import NamedTuple, Tuple
import functools

class Params(NamedTuple):
  weight: jnp.ndarray
  bias: jnp.ndarray


def init(rng) -> Params:
  """Returns the initial model params."""
  weights_key, bias_key = jax.random.split(rng)
  weight = jax.random.normal(weights_key, ())
  bias = jax.random.normal(bias_key, ())
  return Params(weight, bias)


def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
  """Computes the least squares error of the model's predictions on x against y."""
  pred = params.weight * xs + params.bias
  return jnp.mean((pred - ys) ** 2)

LEARNING_RATE = 0.005

# So far, the code is identical to the single-device case. Here's what's new:


# Remember that the `axis_name` is just an arbitrary string label used
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
  """Performs one SGD update step on params using the given data."""

  # Compute the gradients on the given minibatch (individually on each device).
  loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)

  # Combine the gradient across all devices (by taking their mean).
  grads = jax.lax.pmean(grads, axis_name='num_devices')

  # Also combine the loss. Unnecessary for the update, but useful for logging.
  loss = jax.lax.pmean(loss, axis_name='num_devices')

  # Each device performs its own update, but since we start with the same params
  # and synchronise gradients, the params stay in sync.
  new_params = jax.tree_map(
      lambda param, g: param - g * LEARNING_RATE, params, grads)

  return new_params, loss

다음은 update() 작동 방식입니다.

update()는 데코레이션이 되지않고 pmean이 없는 [batch, ...] 형태의 데이터 텐서를 가져와 해당 배치에 대한 손실 함수를 계산하고 기울기를 평가합니다.

We want to spread the batch dimension across all available devices. To do that, we add a new axis using pmap. The arguments to the decorated update() thus need to have shape [num_devices, batch_per_device, ...]. So, to call the new update(), we’ll need to reshape data batches so that what used to be batch is reshaped to [num_devices, batch_per_device]. That’s what split() does below. Additionally, we’ll need to replicate our model parameters, adding the num_devices axis. This reshaping is how a pmapped function knows which devices to send which data.

사용 가능한 모든 디바이스에 ‘batch’ 차원을 확산하려고 합니다. 이를 위해 pmap을 사용하여 새 축을 추가합니다. 따라서 데코레이션된 update()에 대한 인수는 [num_devices, batch_per_device, ...] 형태를 가져야 합니다. 따라서 새로운 update()를 호출하려면 batch였던 것이 [num_devices, batch_per_device]로 재구성되도록 데이터 배치를 재구성해야 합니다. 이것이 split()이 아래에서 하는 일입니다. 또한 모델 매개변수를 복제하여 num_devices 축을 추가해야 합니다. 이 재구성은 pmapped된 함수가 어떤 디바이스에 어떤 데이터를 보낼지 아는 방법입니다.

업데이트 단계 중 어느 시점에서 각 디바이스에서 계산된 그래디언트를 결합해야 합니다. 그렇지 않으면 각 디바이스에서 수행하는 업데이트가 달라집니다. 그래서 jax.lax.pmean을 사용하여 num_devices 축 전체의 평균을 계산하여 배치의 평균 그래디언트를 제공합니다. 그 평균 그래디언트는 우리가 업데이트를 계산하는 데 사용하는 것입니다.

이름을 짓는 것 외에도, 여기서는 jax.pmap을 도입하는 동안 교훈적인 명확성을 위해 axis_namenum_devices를 사용합니다. 그러나 어떤 의미에서 그것은 너무 자명합니다. pmap에 의해 도입된 모든 축은 여러 디바이스를 나타냅니다. 따라서 batch, data(데이터 병렬화를 나타냄) 또는 model(모델 병렬화를 나타냄)과 같이 의미상 의미 있는 것으로 축 이름을 지정하는 것이 일반적입니다.

# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise

# Initialise parameters and replicate across devices.
params = init(jax.random.PRNGKey(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)

지금까지는 선행 차원이 추가된 배열을 구성했습니다. 매개변수는 여전히 모두 호스트(CPU)에 있습니다. update()가 처음 호출될 때 pmap은 이들을 다바이스로 통신시키고 각 사본은 이후에 자체 디바이스에 남게됩니다. 그것들은 ShardedDeviceArray가 아니라 DeviceArray이기 때문에 알 수 있습니다.

type(replicated_params.weight)
jax.interpreters.xla._DeviceArray

매개변수(params)는 pmapped된 update()에서 반환되면 ShardedDeviceArray로 변환될 것입니다. (자세한 내용은 뒷부분을 참고하세요).

데이터에 대해서도 동일한 작업을 수행합니다.

def split(arr):
  """Splits the first axis of `arr` evenly across the number of devices."""
  return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])

# Reshape xs and ys for the pmapped `update()`.
x_split = split(xs)
y_split = split(ys)

type(x_split)
numpy.ndarray

데이터는 단순히 재구성된 바닐라 NumPy 배열입니다. 따라서 NumPy는 CPU에서만 실행되므로 호스트 이외의 위치에 있을 수 없습니다. 그것을 수정하지 않기 때문에, 일반적으로 각 단계에서 CPU에서 디바이스로 데이터가 스트리밍되는 실제 파이프라인에서와 같이 각 update 호출마다 디바이스로 전송됩니다.

def type_after_update(name, obj):
  print(f"after first `update()`, `{name}` is a", type(obj))

# Actual training loop.
for i in range(1000):

  # This is where the params and data gets communicated to devices:
  replicated_params, loss = update(replicated_params, x_split, y_split)

  # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,
  # indicating that they're on the devices.
  # `x_split`, of course, remains a NumPy array on the host.
  if i == 0:
    type_after_update('replicated_params.weight', replicated_params.weight)
    type_after_update('loss', loss)
    type_after_update('x_split', x_split)

  if i % 100 == 0:
    # Note that loss is actually an array of shape [num_devices], with identical
    # entries, because each device returns its copy of the loss.
    # So, we take the first element to print it.
    print(f"Step {i:3d}, loss: {loss[0]:.3f}")


# Plot results.

# Like the loss, the leaves of params have an extra leading dimension,
# so we take the params from the first device.
params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params))
after first `update()`, `replicated_params.weight` is a <class 'jax.interpreters.pxla.ShardedDeviceArray'>
after first `update()`, `loss` is a <class 'jax.interpreters.pxla.ShardedDeviceArray'>
after first `update()`, `x_split` is a <class 'numpy.ndarray'>
Step   0, loss: 0.228
Step 100, loss: 0.228
Step 200, loss: 0.228
Step 300, loss: 0.228
Step 400, loss: 0.228
Step 500, loss: 0.228
Step 600, loss: 0.228
Step 700, loss: 0.228
Step 800, loss: 0.228
Step 900, loss: 0.228
import matplotlib.pyplot as plt
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend()
plt.show()

Aside: JAX에서의 호스트와 디바이스#

TPU에서 실행할 때 ‘호스트’라는 개념이 중요해집니다. 호스트는 여러 디바이스를 관리하는 CPU입니다. 단일 호스트가 관리할 수 있는 디바이스 수(일반적으로 8개)는 한정적이기 때문에 대규모 병렬 프로그램을 실행할 때 여러 호스트가 필요하며 이를 관리하려면 약간의 기술이 필요합니다.

jax.devices()
[TpuDevice(id=0, host_id=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, host_id=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, host_id=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, host_id=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, host_id=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, host_id=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, host_id=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, host_id=0, coords=(1,1,0), core_on_chip=1)]

CPU에서 실행할 때는 --xla_force_host_platform_device_count XLA 플래그를 사용하여 임의의 수의 디바이스를 에뮬레이션할 수 있습니다. 예를 들어 JAX를 가져오기 전에 다음을 실행하면 됩니다.

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
jax.devices()
[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

이것은 CPU 런타임이 (재)시작하는 것이 더 빠르기 때문에 로컬에서 디버깅 및 테스트하거나 Colab에서 프로토타이핑하는 데 특히 유용합니다.