데이터 샤딩 : 모델 메모리를 효과적으로 사용하는 방법
딥러닝에서 OOM이 발생한다면 가장 효과적으로 처리하는 방법은 데이터 샤딩 이라는 기법을 활용해서 처리하는 방식입니다. 이번 블로그에서는 JAX를 통해 데이터 샤딩을 어떻게 하는지 알아보도록 하겠습니다.
RuntimeError: CUDA out of memory. Tried to allocate 122.00 MiB (GPU 0; 3.95 GiB total capacity; 3.08 GiB already allocated; 44.44 MiB free; 280.89 MiB cached)
라는 메세지가 나오게 됩니다. 해당 메세지를 분석하게 된다면 현재 GPU에는 3.95GB의 VRAM을 갖고 있으며 3.08GB는 이미 할당되어 있으며 280.89 MB는 캐시이며 44.44MB만 사용가능한데 우리가 필요로 한 할당량은 122.00 MB라는 점입니다.
딥러닝 연구자들은 OOM이 발생한다면 배치 사이즈를 조절한다거나 다른 프로세스를 끄는 방법등 다양한 방법들을 시도합니다. 그러나 모든 방법이 실패했을 때에는 어쩔 수 없이 더 좋은 GPU를 찾게 됩니다. 그런데 한가지 문제가 더 있습니다. 만일 여러분들이 사용하고 있는 모델이 현존하는 가장 큰 VRAM을 갖고 있는 A100 80GB로도 해결되지 않는다면 어떻게 진행하실건가요? 이런 여러분들을 위한 “데이터 샤딩”이라는 기법을 가져와봤습니다!
데이터 샤딩?
데이터 샤딩이라는 기법은 딥러닝이 아닌 데이터베이스 부분에서 나온 테크닉입니다. 데이터베이스에서의 샤딩은 전체 데이터셋을 여러 부분으로 나누고 각 샤드를 서로 다른 데이터 베이스 서버나 클러스터에서 독점적으로 관리하는 방식입니다. 데이터베이스에서 이런 방식을 채택한 이유는 샤딩 기법은 확장성을 갖고 오며 성능을 향상시킬 수 있습니다. 또한 각 샤드가 독립적으로 운영되기 때문에 한개의 샤드에서 문제가 발생해도 다른 샤드는 정상적으로 작동할 수 있습니다. 마지막으로 비싼 하드웨어 1개 대신 싼 하드웨어 2개를 사용해 비용을 절감할 수 있습니다. 결론적으로 데이터 샤딩은 데이터베이스에서 성능 향상과 독립적 운영, 비용 절감을 위해 전체 데이터셋을 여러 부분으로 나누어서 다른 서버나 클러스터에서 관리하는 시스템을 의미합니다.
그렇다면 딥러닝에서의 데이터 샤딩은 무엇일까요? 딥러닝에서의 샤딩은 데이터를 데이터베이스에 담는 것이 아닌 GPU에 담는 분산학습 시스템을 의미합니다. 그렇기에 샤딩의 의미가 조금 달라집니다. 딥러닝에서의 데이터 샤딩은 모델의 파라미터, 그레디언트, 최적화 상태 또는 데이터셋을 GPU 여러개에 분할하는 데 사용합니다. 이걸 사용하게 된다면 메모리의 사용량을 줄이고 계산 능력을 향상시키며 딥러닝 훈련을 효율적으로 만듭니다.
딥러닝 샤딩의 종류
딥러닝에서의 데이터 샤딩의 대표적인 케이스로 PyTorch Fully Sharded Data Parallel(FSDP) API가 있으며 jax.sharding API가 있습니다. PyTorch FSDP API는 DeepSpeed나 FairScale을 사용할 때 FSDP를 사용해 모델의 파라미터, 그레디언트, 최적화 상태값까지 데이터 병렬 작업자 간에 샤딩을 하는 기법입니다. 이 방식을 사용하게 되면 데이터 병렬성의 단순함을 유지하면서 확장성을 제공합니다. 이에 반해 JAX에서 제공하는 jax.sharding API는 샤딩을 보다 더 커스터마이징 할 수 있게 제공합니다. jax.shading API는 jax.array가 여러개의 GPU에 어떻게 배치되는지 임의로 조작할 수 있습니다. 패턴 또한 임의로 조작해서 다양한 병렬성에 대한 샤딩을 제공하며 XLA 컴파일도 가능해 속도 또한 상당히 우수하다고 볼 수 있습니다.
JAX에서 샤딩 사용하는 방법 알아보기
이번 블로그에서는 jax.sharding API를 활용해서 JAX에서는 여러개의 GPU를 사용할 때 어떻게 샤딩하는지 알아보도록 하겠습니다. 이번에 사용하는 내용은 JAX-KR에서 전체 내용을 확인할 수 있으며 이번 블로그에서는 특정 부분만 발췌해서 설명하겠습니다.
def get_replicated_train_state(devices):
# 모든 변수는 모든 디바이스에서 복제됩니다.
var_mesh = Mesh(devices, axis_names=("_"))
# NamedShading에서 언급되지 않는 축이 복제됩니다 (여기에서는 모든 축입니다.)
var_replication = NamedSharding(var_mesh, P())
# 모델 변수들을 분산환경에 적용합니다.
trainable_variables = jax.device_put(model.trainable_variables, var_replication)
non_trainable_variables = jax.device_put(
model.non_trainable_variables, var_replication
)
optimizer_variables = jax.device_put(optimizer.variables, var_replication)
# 튜플 1개에 모든 상태를 합칩니다.
return (trainable_variables, non_trainable_variables, optimizer_variables)
num_devices = len(jax.local_devices())
print(f"Running on {num_devices} devices: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))
위의 코드를 설명하면 데이터 샤딩을 진행하기 위해서는 비단 데이터만 나누는 것이 아닌 모델의 파라미터, 그레디언트, 최적화 변수까지 전부 나눠야 합니다. JAX에서는 순수 함수형 프로그래밍으로 만들어진 딥러닝 프레임워크이기 때문에 상태관리를 서로 복제해서 사용하게 만듭니다. 이러한 방법을 사용하게 될 경우에는 보다 명확하게 변수들을 적용하는 걸 확인할 수 있기 때문에 깔끔하게 사용할 수 있습니다. 현재 한 메소드의 결과값을 생각해보면 데이터를 샤딩하는것보다 우선적으로 상태를 전부 샤딩해서 관리하는걸 우선한다는 걸 알 수 있습니다.
# 데이터는 배치 축으로 분할됩니다.
data_mesh = Mesh(devices, axis_names=("batch",)) # 메시의 축 이름을 지정합니다.
data_sharding = NamedSharding(
data_mesh,
P(
"batch",
),
) # 샤딩된 파티션의 축의 이름을 지정합니다.
# 데이터 샤딩을 시각화합니다.
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))
train_state = get_replicated_train_state(devices)
모델의 상태를 전부 샤딩했다면 이제는 데이터를 샤딩할 차례입니다. 우선 데이터 매시의 축을 지정하고 샤딩된 파티션의 축 이름 또한 지정합니다. 그리고 데이터를 샤딩 진행해 시각화까지 진행합니다. 이때 여러분들이 2개 이상의 GPU를 사용하고 있다면 다음과 같은 결과값이 나오게 됩니다.
마무리하며
이번에 알아본 내용은 VRAM이 부족한 경우가 발생한다면 여러 옵션들을 사용할 수 있지만 가장 간편하게 사용할 수 있는 Multi-GPU에서의 데이터 샤딩을 통한 분산학습에 대해 알아보았습니다. 코드의 경우 JAX를 통해 알아보았으며 Keras core를 사용한다면 더 간편하게 만들어낼 수 있습니다. 사실 OOM을 해결하는 방법에는 Nvidia에서 나온 PagedOptimizer를 활용해서 GPU VRAM이 모자란 경우 CPU와 통신을 진행해 늦추는 방법도 있습니다. 다만 우리가 딥러닝의 추론을 최적화한다는 관점에서 본다면 데이터 샤딩이 가장 좋은 선택지라고 생각하며 비교적 코드가 어렵지 않아 공부하는 분들 입장에서 쉽게 사용할 수 있을거라 생각합니다.