DeepSpeed Activation Checkpointing OOM (Out-of-Memory) 디버깅 마스터: GPU 메모리 사용량 최적화 및 초거대 모델 학습 전략

초거대 모델 학습 중 빈번하게 발생하는 OOM 오류, 특히 DeepSpeed Activation Checkpointing 사용 시 발생하는 문제를 해결하는 궁극적인 가이드입니다. GPU 메모리 사용량을 획기적으로 줄여 학습 효율을 극대화하고, 더 큰 모델을 더 빠르게 학습시키는 전략을 소개합니다.

1. The Challenge / Context

초거대 모델(LLM)의 학습은 막대한 GPU 메모리를 요구합니다. 파라미터 수가 증가함에 따라, 모델의 활성화 값(activation)을 저장하는 데 필요한 메모리 역시 기하급수적으로 증가합니다. Activation Checkpointing은 이러한 문제를 해결하는 핵심 기술이지만, 잘못된 설정이나 이해 부족은 오히려 OOM(Out-of-Memory) 오류를 야기할 수 있습니다. 현재 AI 모델 개발 경쟁이 치열해짐에 따라 더 크고 강력한 모델을 학습시키는 것이 중요하며, Activation Checkpointing을 효과적으로 활용하는 능력은 경쟁 우위를 확보하는 데 필수적입니다. 텐서 코어 사용량 극대화, 효율적인 메모리 관리, 안정적인 학습 환경 구축 모두가 이 기술을 제대로 이해하고 활용하는 데 달려있습니다.

2. Deep Dive: DeepSpeed Activation Checkpointing

DeepSpeed Activation Checkpointing (또는 Gradient Checkpointing)은 모델의 순전파(forward pass) 과정에서 활성화 값을 모두 저장하는 대신, 필요한 활성화 값을 다시 계산하여 메모리 사용량을 줄이는 기술입니다. 역전파(backward pass) 시에 필요한 활성화 값을 순전파에서 다시 계산함으로써, 메모리 사용량과 계산량 사이의 트레이드오프를 제공합니다. DeepSpeed는 이 기능을 더욱 효율적으로 구현하고, 다양한 최적화 옵션을 제공합니다.

작동 원리: 모델의 레이어를 여러 개의 세그먼트로 나누고, 각 세그먼트의 입력 활성화 값만 저장합니다. 역전파 단계에서 각 세그먼트의 활성화 값은 저장된 입력 값을 이용하여 다시 계산됩니다. 이 방식은 전체 활성화 값을 저장하는 것보다 훨씬 적은 메모리를 사용하지만, 순전파 계산을 일부 반복해야 하므로 계산 시간이 약간 늘어납니다.

핵심 기능:

  • Selective Activation Checkpointing: 모든 레이어에 Activation Checkpointing을 적용하는 대신, 메모리 사용량이 높은 특정 레이어에만 적용하여 성능 저하를 최소화할 수 있습니다.
  • CPU Offloading: 활성화 값을 CPU 메모리로 오프로딩하여 GPU 메모리 부족 문제를 완화할 수 있습니다.
  • Distributed Checkpointing: Activation Checkpointing 계산을 여러 GPU에 분산시켜 계산 시간을 단축할 수 있습니다.

3. Step-by-Step Guide / Implementation

DeepSpeed Activation Checkpointing을 효과적으로 사용하고 OOM 오류를 디버깅하는 단계별 가이드입니다.

Step 1: DeepSpeed 설정 확인 및 구성

DeepSpeed를 올바르게 설치하고 구성해야 합니다. deepspeed_config.json 파일을 통해 다양한 옵션을 설정할 수 있습니다.


    {
      "train_batch_size": 16,
      "train_micro_batch_size_per_gpu": 4,
      "gradient_accumulation_steps": 4,
      "optimizer": {
        "type": "AdamW",
        "params": {
          "lr": 0.0001,
          "weight_decay": 0.01
        }
      },
      "scheduler": {
        "type": "WarmupLR",
        "params": {
          "warmup_min_lr": 0.00001,
          "warmup_max_lr": 0.0001,
          "warmup_num_steps": 1000
        }
      },
      "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
      },
      "activation_checkpointing": {
        "partition_activations": true,
        "contiguous_memory_optimization": true,
        "cpu_checkpointing": false
      },
      "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "reduce_scatter": true,
        "contiguous_gradients": true,
        "allgather_partitions": true
      }
    }
    

설명:

  • partition_activations: 활성화 값을 여러 GPU에 분산하여 저장합니다.
  • contiguous_memory_optimization: 활성화 값을 연속적인 메모리 공간에 저장하여 메모리 단편화를 줄입니다.
  • cpu_checkpointing: 활성화 값을 CPU 메모리로 오프로딩합니다 (OOM 해결에 유용).
  • zero_optimization: ZeRO 최적화는 모델 파라미터, gradients, optimizer states를 분산하여 메모리를 더욱 절약합니다. Stage 2 이상을 권장합니다.

Step 2: 모델에 Activation Checkpointing 적용

DeepSpeed Activation Checkpointing을 모델에 적용하는 방법은 크게 두 가지입니다. 첫 번째는 DeepSpeed 엔진을 사용하는 것이고, 두 번째는 직접 구현하는 것입니다.

2.1. DeepSpeed 엔진 사용

가장 간단한 방법은 DeepSpeed 엔진을 사용하여 모델을 래핑하는 것입니다.


    import deepspeed
    import torch

    model = ...  # Your PyTorch model
    config = "deepspeed_config.json"

    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        config=config
    )

    # 이제 model_engine을 사용하여 학습 진행
    

DeepSpeed 엔진은 설정 파일(deepspeed_config.json)에 따라 자동으로 Activation Checkpointing을 적용합니다.

2.2. 직접 구현

보다 세밀한 제어가 필요한 경우, torch.utils.checkpoint를 사용하여 Activation Checkpointing을 직접 구현할 수 있습니다.


    import torch
    from torch.utils.checkpoint import checkpoint

    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = torch.nn.Linear(10, 20)
            self.layer2 = torch.nn.Linear(20, 30)
            self.layer3 = torch.nn.Linear(30, 40)

        def forward(self, x):
            x = torch.relu(self.layer1(x))
            x = checkpoint(torch.relu, self.layer2(x))  # Activation Checkpointing 적용
            x = torch.relu(self.layer3(x))
            return x
    

설명: checkpoint 함수는 지정된 함수의 입력 활성화 값을 저장하지 않고, 역전파 시에 다시 계산합니다. 위 예제에서는 layer2에 Activation Checkpointing이 적용되었습니다.

Step 3: OOM 오류 발생 시 디버깅

Activation Checkpointing을 적용했음에도 불구하고 OOM 오류가 발생할 수 있습니다. 이 경우, 다음 단계를 통해 문제를 해결할 수 있습니다.

  1. 배치 크기 줄이기: 가장 기본적인 해결책은 train_batch_size 또는 train_micro_batch_size_per_gpu를 줄이는 것입니다.
  2. CPU Offloading 활성화: deepspeed_config.json 파일에서 cpu_checkpointingtrue로 설정합니다. 또한 ZeRO Offload 옵션도 함께 사용하는 것이 좋습니다.
  3. Selective Activation Checkpointing 활용: 모든 레이어에 Activation Checkpointing을 적용하는 대신, 메모리 사용량이 높은 레이어에만 적용합니다. profiling 툴을 사용하여 메모리 사용량이 높은 레이어를 식별할 수 있습니다.
  4. Gradient Accumulation Step 조정: Gradient Accumulation Step을 늘리면 더 큰 배치 크기를 효과적으로 사용할 수 있지만, 메모리 사용량도 증가할 수 있습니다. 적절한 균형을 찾아야 합니다.
  5. Mixed Precision Training (FP16): FP16을 활성화하면 메모리 사용량을 절반으로 줄일 수 있습니다. fp16 설정을 deepspeed_config.json 파일에서 확인하고 활성화합니다.
  6. Garbage Collection: 명시적으로 `torch.cuda.empty_cache()`를 호출하여 GPU 메모리를 확보합니다.
  7. PyTorch Profiler 활용: PyTorch Profiler를 사용하여 메모리 누수를 진단하고, 메모리 사용량이 많은 연산을 식별합니다.

4. Real-world Use Case / Example

저는 한 번은 20B 파라미터 모델을 사용하여 텍스트 생성 작업을 수행하던 중 지속적인 OOM 오류에 직면했습니다. 초기에는 배치 크기를 줄이고 Gradient Accumulation Step을 늘리는 방법을 시도했지만, 근본적인 해결책은 아니었습니다. PyTorch Profiler를 사용하여 메모리 사용량을 분석한 결과, 특정 Transformer 블록이 비정상적으로 많은 메모리를 사용하는 것을 확인했습니다. 해당 블록에만 Activation Checkpointing을 적용하고, CPU Offloading을 활성화한 결과, OOM 오류 없이 학습을 진행할 수 있었습니다. 또한, FP16을 사용하여 학습 속도를 2배로 향상시킬 수 있었습니다. 이 경험을 통해 Selective Activation Checkpointing의 중요성과 profiling 툴의 유용성을 깨달았습니다.

5. Pros & Cons / Critical Analysis

  • Pros:
    • GPU 메모리 사용량 감소
    • 더 큰 모델 학습 가능
    • 더 큰 배치 크기 사용 가능
  • Cons:
    • 계산 시간 증가 (순전파 재계산 필요)
    • 설정 및 디버깅 복잡성 증가
    • 모든 모델에 효과적인 것은 아님 (메모리 사용량 프로파일링 필요)

6. FAQ

  • Q: Activation Checkpointing을 사용하면 항상 메모리 사용량이 줄어드나요?
    A: 항상 그런 것은 아닙니다. Activation Checkpointing은 계산 시간과 메모리 사용량 사이의 트레이드오프를 제공합니다. 메모리 사용량이 높은 모델의 특정 레이어에 적용하는 것이 효과적입니다.
  • Q: CPU Offloading을 활성화하면 성능이 저하되나요?
    A: 네, CPU Offloading은 GPU 메모리를 확보하는 대신 CPU와 GPU 간의 데이터 전송 오버헤드를 발생시킵니다. 하지만 OOM 오류를 해결할 수 있다면 성능 저하를 감수할 만한 가치가 있을 수 있습니다.
  • Q: DeepSpeed와 PyTorch의 내장 Activation Checkpointing (torch.utils.checkpoint)의 차이점은 무엇인가요?
    A: DeepSpeed는 Activation Checkpointing을 더욱 효율적으로 구현하고, 다양한 최적화 옵션 (예: Selective Activation Checkpointing, CPU Offloading)을 제공합니다. PyTorch의 내장 Activation Checkpointing은 기본적인 기능만 제공합니다.

7. Conclusion

DeepSpeed Activation Checkpointing은 초거대 모델 학습 시 발생하는 OOM 오류를 해결하고 GPU 메모리 사용량을 최적화하는 강력한 기술입니다. 이 가이드에서 제시된 단계별 방법과 디버깅 전략을 통해 더 큰 모델을 더 효율적으로 학습시키고, AI 모델 개발 경쟁에서 우위를 점할 수 있을 것입니다. 지금 바로 DeepSpeed를 설정하고, Activation Checkpointing을 적용하여 여러분의 모델 학습 효율을 극대화해보세요! 자세한 내용은 DeepSpeed 공식 문서를 참고하시기 바랍니다.