2023년 Vision Transformer 정리
Vision Transformer(ViT)가 처음 등장한 지 3년이 지난 2023년, 이 분야는 눈부신 발전을 이루었습니다. CNN의 아성을 위협하던 ViT는 이제 컴퓨터 비전의 주류로 자리잡았습니다. 이 글에서는 2023년 중반 기준으로 ViT 계열의 발전 흐름과 실무 경험을 정리합니다.
ViT의 원래 문제점
2020년 Google이 발표한 원조 ViT는 혁신적이었지만, 실무에서 사용하기에는 여러 문제가 있었습니다.
# 원조 ViT의 복잡도 문제class OriginalViTComplexity: """ Self-Attention의 계산 복잡도: O(n²) n = (H * W) / P² (패치 수)
예시: 224x224 이미지, 16x16 패치 → n = (224 * 224) / (16 * 16) = 196 패치
1024x1024 이미지라면? → n = (1024 * 1024) / (16 * 16) = 4096 패치 → Attention matrix: 4096 x 4096 = 16M elements → GPU 메모리 폭발 """
@staticmethod def memory_usage_estimate(image_size, patch_size, batch_size, hidden_dim): num_patches = (image_size // patch_size) ** 2 # Attention matrix memory attention_memory = batch_size * num_patches * num_patches * 4 # float32 # Hidden states memory hidden_memory = batch_size * num_patches * hidden_dim * 4
return { "attention_matrix_gb": attention_memory / (1024**3), "hidden_states_gb": hidden_memory / (1024**3), "num_patches": num_patches }
# 메모리 사용량 비교MEMORY_COMPARISON = { "224x224": OriginalViTComplexity.memory_usage_estimate(224, 16, 32, 768), "512x512": OriginalViTComplexity.memory_usage_estimate(512, 16, 32, 768), "1024x1024": OriginalViTComplexity.memory_usage_estimate(1024, 16, 32, 768),}
# 결과:# 224x224: ~0.05GB attention, 196 patches# 512x512: ~0.8GB attention, 1024 patches# 1024x1024: ~13GB attention, 4096 patches (단일 배치도 어려움)고해상도 이미지를 처리하려면 GPU가 여러 대 필요했고, 추론 속도도 느렸습니다. 실용적이지 않았습니다.
Swin Transformer: 게임 체인저
Swin Transformer는 이 문제를 우아하게 해결했습니다. 윈도우 단위로 어텐션을 계산해서 복잡도를 O(n)으로 줄였습니다.
import torchimport torch.nn as nnimport torch.nn.functional as F
class WindowAttention(nn.Module): """ Window-based Multi-head Self-Attention (W-MSA)
핵심 아이디어: - 전체 이미지가 아닌 작은 윈도우 내에서만 어텐션 계산 - 복잡도: O(n² * M²) → O(n * M²) where M = window size """
def __init__( self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0. ): super().__init__() self.dim = dim self.window_size = window_size self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5
# Relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) )
# QKV projection self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop)
# Initialize relative position bias coords_h = torch.arange(window_size) coords_w = torch.arange(window_size) coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) coords_flatten = torch.flatten(coords, 1) relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] += window_size - 1 relative_coords[:, :, 1] += window_size - 1 relative_coords[:, :, 0] *= 2 * window_size - 1 relative_position_index = relative_coords.sum(-1) self.register_buffer("relative_position_index", relative_position_index)
nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self, x, mask=None): B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads) qkv = qkv.permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale attn = (q @ k.transpose(-2, -1))
# Add relative position bias relative_position_bias = self.relative_position_bias_table[ self.relative_position_index.view(-1) ].view( self.window_size * self.window_size, self.window_size * self.window_size, -1 ) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) attn = attn + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N)
attn = F.softmax(attn, dim=-1) attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x
class ShiftedWindowAttention(nn.Module): """ Shifted Window-based Multi-head Self-Attention (SW-MSA)
문제: 윈도우 경계를 넘는 정보 전달이 안 됨 해결: 윈도우를 shift해서 번갈아가며 적용 """
def __init__(self, dim, window_size, shift_size, num_heads): super().__init__() self.window_size = window_size self.shift_size = shift_size self.attention = WindowAttention(dim, window_size, num_heads)
def forward(self, x): H, W = x.shape[1], x.shape[2]
# Cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x
# Partition into windows x_windows = window_partition(shifted_x, self.window_size)
# Window attention attn_windows = self.attention(x_windows)
# Merge windows shifted_x = window_reverse(attn_windows, self.window_size, H, W)
# Reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x
return x
def window_partition(x, window_size): """ 이미지를 윈도우로 분할
Args: x: (B, H, W, C) window_size: int
Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows
def window_reverse(windows, window_size, H, W): """ 윈도우를 원래 이미지로 복원
Args: windows: (num_windows*B, window_size, window_size, C) window_size: int H, W: int
Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return xSwin Transformer의 계층적 구조도 중요합니다.
class SwinTransformerStage(nn.Module): """ Swin Transformer의 계층적 구조
Stage 1: H/4 x W/4 (저해상도, 채널 적음) Stage 2: H/8 x W/8 (patch merging) Stage 3: H/16 x W/16 Stage 4: H/32 x W/32 (고해상도 특징)
→ CNN의 feature pyramid와 유사한 효과 → Object detection, segmentation에 바로 적용 가능 """
def __init__(self, dim, depth, num_heads, window_size, downsample=True): super().__init__() self.blocks = nn.ModuleList([ SwinTransformerBlock( dim=dim, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2 ) for i in range(depth) ])
# Patch Merging: 2x2 → 1x1, dim → 2*dim if downsample: self.downsample = PatchMerging(dim) else: self.downsample = None
def forward(self, x): for block in self.blocks: x = block(x) if self.downsample is not None: x = self.downsample(x) return x
class PatchMerging(nn.Module): """ Patch Merging: 해상도를 절반으로 줄이고 채널을 2배로 늘림 → CNN의 strided conv와 유사한 역할 """
def __init__(self, dim): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = nn.LayerNorm(4 * dim)
def forward(self, x): B, H, W, C = x.shape
# 2x2 패치를 하나로 합침 x0 = x[:, 0::2, 0::2, :] x1 = x[:, 1::2, 0::2, :] x2 = x[:, 0::2, 1::2, :] x3 = x[:, 1::2, 1::2, :] x = torch.cat([x0, x1, x2, x3], -1)
x = self.norm(x) x = self.reduction(x)
return x실제로 써보니 기존 ViT 대비 메모리 사용량이 확연히 줄었습니다.
SWIN_VS_VIT_COMPARISON = { "memory_usage_1024x1024": { "ViT-Base": "Out of Memory (24GB GPU)", "Swin-Base": "~8GB" }, "throughput_224x224": { "ViT-Base": "~300 img/s", "Swin-Base": "~450 img/s" }, "imagenet_accuracy": { "ViT-Base": "81.8%", "Swin-Base": "83.5%" }}Flash Attention
Flash Attention은 2023년에 널리 사용되기 시작한 기술입니다. 알고리즘 자체가 아니라 메모리 접근 패턴을 최적화한 것입니다.
import torchfrom torch.nn.functional import scaled_dot_product_attention
class FlashAttentionDemo(nn.Module): """ Flash Attention의 핵심 아이디어: - GPU HBM(High Bandwidth Memory)과 SRAM의 속도 차이 활용 - 전체 attention matrix를 한 번에 계산하지 않고 블록 단위로 처리 - I/O 병목을 줄여서 실제 속도 2-4배 향상 """
def __init__(self, dim, num_heads): super().__init__() self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3) self.proj = nn.Linear(dim, dim)
def forward(self, x): B, N, C = x.shape
# QKV 계산 qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2]
# PyTorch 2.0+ Flash Attention (자동 적용) # is_causal=False for non-autoregressive models out = scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale )
out = out.transpose(1, 2).reshape(B, N, C) return self.proj(out)
# Flash Attention 사용 가능 여부 확인def check_flash_attention_available(): """PyTorch 2.0+ 에서 Flash Attention 지원 확인""" if not torch.cuda.is_available(): return False
# CUDA capability check device = torch.cuda.current_device() capability = torch.cuda.get_device_capability(device)
# Flash Attention requires SM80+ (Ampere or newer) if capability[0] >= 8: return True
# FlashAttention-2 on older GPUs via xformers try: import xformers return True except ImportError: return False
return False
# 성능 비교 테스트def benchmark_attention(seq_len=4096, dim=768, num_heads=12, batch_size=8): """Standard vs Flash Attention 벤치마크""" import time
device = "cuda" x = torch.randn(batch_size, seq_len, dim, device=device)
# Standard Attention standard_attn = StandardAttention(dim, num_heads).to(device)
# Flash Attention flash_attn = FlashAttentionDemo(dim, num_heads).to(device)
# Warmup for _ in range(10): _ = standard_attn(x) _ = flash_attn(x)
torch.cuda.synchronize()
# Standard Attention timing start = time.time() for _ in range(100): _ = standard_attn(x) torch.cuda.synchronize() standard_time = time.time() - start
# Flash Attention timing start = time.time() for _ in range(100): _ = flash_attn(x) torch.cuda.synchronize() flash_time = time.time() - start
return { "standard_attention_ms": standard_time * 10, "flash_attention_ms": flash_time * 10, "speedup": standard_time / flash_time }
# 실제 측정 결과 (A100 GPU)BENCHMARK_RESULTS = { "seq_len_1024": {"standard": "12ms", "flash": "4ms", "speedup": "3x"}, "seq_len_4096": {"standard": "180ms", "flash": "45ms", "speedup": "4x"}, "seq_len_8192": {"standard": "OOM", "flash": "160ms", "speedup": "∞"}}Flash Attention을 적용하면 코드 한 줄만 바꿔도 학습 속도가 2배 이상 빨라집니다.
# 기존 코드 (PyTorch 1.x)def attention_v1(q, k, v, scale): attn = torch.matmul(q, k.transpose(-2, -1)) * scale attn = F.softmax(attn, dim=-1) out = torch.matmul(attn, v) return out
# Flash Attention 적용 (PyTorch 2.0+)def attention_v2(q, k, v, scale): return F.scaled_dot_product_attention(q, k, v, scale=scale)
# 이게 전부입니다!멀티모달 통합: CLIP과 그 이후
CLIP은 텍스트와 이미지를 같은 임베딩 공간에 매핑한다는 간단한 아이디어로 판도를 바꿨습니다.
import torchimport torch.nn as nnimport torch.nn.functional as F
class CLIPLike(nn.Module): """ CLIP의 핵심 구조 재현
이미지 인코더: ViT 텍스트 인코더: Transformer 학습: Contrastive Learning (대조 학습) """
def __init__( self, embed_dim: int = 512, image_encoder: nn.Module = None, text_encoder: nn.Module = None, temperature: float = 0.07 ): super().__init__() self.image_encoder = image_encoder self.text_encoder = text_encoder self.temperature = nn.Parameter(torch.ones([]) * temperature)
# Projection heads self.image_projection = nn.Linear(embed_dim, embed_dim) self.text_projection = nn.Linear(embed_dim, embed_dim)
def encode_image(self, image): features = self.image_encoder(image) features = self.image_projection(features) return F.normalize(features, dim=-1)
def encode_text(self, text): features = self.text_encoder(text) features = self.text_projection(features) return F.normalize(features, dim=-1)
def forward(self, image, text): image_features = self.encode_image(image) text_features = self.encode_text(text)
# Cosine similarity as logits logits_per_image = image_features @ text_features.t() / self.temperature logits_per_text = logits_per_image.t()
return logits_per_image, logits_per_text
def contrastive_loss(self, image, text): """ InfoNCE Loss (Contrastive Loss)
목표: 매칭되는 이미지-텍스트 쌍의 유사도를 높이고, 매칭되지 않는 쌍의 유사도를 낮춤 """ logits_per_image, logits_per_text = self.forward(image, text)
batch_size = image.shape[0] labels = torch.arange(batch_size, device=image.device)
# 대칭적 cross entropy loss loss_i2t = F.cross_entropy(logits_per_image, labels) loss_t2i = F.cross_entropy(logits_per_text, labels)
return (loss_i2t + loss_t2i) / 2
# CLIP 활용 예시: Zero-shot Classificationclass CLIPClassifier: """ 학습 없이 새로운 클래스 분류 가능 """
def __init__(self, clip_model, class_names): self.model = clip_model self.class_names = class_names
# 클래스별 텍스트 임베딩 미리 계산 text_prompts = [f"a photo of a {name}" for name in class_names] self.text_features = self.model.encode_text(text_prompts)
def classify(self, image): image_features = self.model.encode_image(image)
# 코사인 유사도 계산 similarities = image_features @ self.text_features.t()
# 가장 유사한 클래스 반환 probs = F.softmax(similarities, dim=-1) predicted_class = probs.argmax(dim=-1)
return self.class_names[predicted_class], probs.max()
# 실제 사용 예시def zero_shot_classification_demo(): """ CLIP의 Zero-shot 능력 데모
놀라운 점: - 학습 데이터에 없던 클래스도 분류 가능 - 프롬프트만 바꾸면 다른 태스크 수행 가능 """ # ImageNet에 없는 클래스들 custom_classes = [ "a corgi dog", "a persian cat", "a grilled cheese sandwich", "the Eiffel Tower", "a person riding a bicycle" ]
classifier = CLIPClassifier(clip_model, custom_classes)
# 이미지 분류 result, confidence = classifier.classify(image) print(f"Predicted: {result} (confidence: {confidence:.2%})")CLIP을 기반으로 한 다양한 응용이 등장했습니다.
CLIP_BASED_APPLICATIONS = { "DALL-E": { "description": "텍스트에서 이미지 생성", "how": "CLIP으로 텍스트-이미지 정합성 평가" }, "Stable Diffusion": { "description": "오픈소스 이미지 생성", "how": "CLIP text encoder를 조건으로 사용" }, "BLIP / BLIP-2": { "description": "이미지 캡셔닝, VQA", "how": "CLIP 구조 확장 + 언어 모델" }, "OpenCLIP": { "description": "CLIP의 오픈소스 재현", "how": "LAION 데이터셋으로 학습" }, "SAM (Segment Anything)": { "description": "범용 이미지 분할", "how": "CLIP과 유사한 대규모 사전학습" }}실무에서의 Vision Transformer
2023년에 느낀 점은 “효율성”이 진짜 중요하다는 것입니다.
# 실무 관점의 모델 선택 기준MODEL_SELECTION_CRITERIA = { "accuracy_first": { "recommendation": "Swin-Large, ViT-Large", "use_case": "오프라인 배치 처리, 정확도가 최우선", "trade_off": "속도와 비용 포기" }, "balanced": { "recommendation": "Swin-Base, ConvNeXt-Base", "use_case": "대부분의 프로덕션 환경", "trade_off": "정확도 약간 손해, 합리적인 속도" }, "speed_first": { "recommendation": "EfficientNetV2-S, MobileViT", "use_case": "실시간 처리, 엣지 디바이스", "trade_off": "정확도 손해, 빠른 속도" }, "edge_device": { "recommendation": "MobileNetV3, EfficientNet-Lite", "use_case": "모바일, IoT", "trade_off": "정확도 많이 손해, 낮은 리소스 사용" }}
# 실제 프로젝트에서의 선택 과정class ModelSelectionProcess: """ 실무에서 모델을 선택하는 과정 """
@staticmethod def analyze_requirements(): return { "latency_requirement": "< 50ms", "accuracy_requirement": "> 85% top-1", "batch_size": 32, "gpu_budget": "1x V100", "deployment": "cloud API" }
@staticmethod def benchmark_candidates(requirements): candidates = [ {"name": "ViT-Base", "latency": "35ms", "accuracy": "81.8%"}, {"name": "Swin-Base", "latency": "28ms", "accuracy": "83.5%"}, {"name": "ConvNeXt-Base", "latency": "25ms", "accuracy": "83.8%"}, {"name": "EfficientNetV2-M", "latency": "15ms", "accuracy": "85.1%"}, ]
# 요구사항 만족 여부 확인 suitable = [] for c in candidates: latency_ok = float(c["latency"].replace("ms", "")) < 50 accuracy_ok = float(c["accuracy"].replace("%", "")) > 85
if latency_ok and accuracy_ok: suitable.append(c)
return suitable # EfficientNetV2-M만 둘 다 만족
@staticmethod def final_decision(): """ 최종 결정 시 고려사항 """ return { "selected": "EfficientNetV2-M", "reasons": [ "요구 정확도 만족 (85.1%)", "요구 지연시간 만족 (15ms)", "학습 코드/체크포인트 풍부", "커뮤니티 지원 활발" ], "alternatives_considered": [ "Swin-Base: 정확도 미달", "ConvNeXt-Base: 정확도 미달", "ViT-Base: 정확도 미달" ] }최적화 기법들
실제 배포 시 사용한 최적화 기법들입니다.
import torchfrom torch import nn
# 1. Mixed Precision Trainingdef train_with_mixed_precision(model, dataloader, optimizer): scaler = torch.cuda.amp.GradScaler()
for batch in dataloader: optimizer.zero_grad()
# FP16 forward pass with torch.cuda.amp.autocast(): outputs = model(batch["image"]) loss = criterion(outputs, batch["label"])
# Scaled backward pass scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
# 2. Torch Compile (PyTorch 2.0+)def optimize_with_compile(model): """ PyTorch 2.0의 torch.compile로 자동 최적화 - 커널 퓨전 - 메모리 최적화 - GPU 활용 최적화 """ optimized_model = torch.compile( model, mode="reduce-overhead", # 또는 "max-autotune" fullgraph=True ) return optimized_model
# 3. ONNX 변환 및 최적화def export_to_onnx(model, input_shape, output_path): """ ONNX로 변환하여 다양한 런타임에서 사용 """ import torch.onnx
dummy_input = torch.randn(*input_shape)
torch.onnx.export( model, dummy_input, output_path, export_params=True, opset_version=17, do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes={ 'input': {0: 'batch_size'}, 'output': {0: 'batch_size'} } )
# 4. TensorRT 최적화def optimize_with_tensorrt(onnx_path): """ TensorRT로 추론 최적화 (NVIDIA GPU) """ import tensorrt as trt
logger = trt.Logger(trt.Logger.WARNING) builder = trt.Builder(logger) network = builder.create_network( 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) ) parser = trt.OnnxParser(network, logger)
with open(onnx_path, 'rb') as f: parser.parse(f.read())
config = builder.create_builder_config() config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
# FP16 최적화 if builder.platform_has_fast_fp16: config.set_flag(trt.BuilderFlag.FP16)
engine = builder.build_serialized_network(network, config) return engine
# 최적화 효과 비교OPTIMIZATION_RESULTS = { "Swin-Base (224x224, batch=32)": { "baseline_pytorch": "45ms", "mixed_precision": "28ms (1.6x)", "torch_compile": "22ms (2.0x)", "tensorrt_fp16": "12ms (3.8x)" }}2024년 전망
더 효율적인 아키텍처가 계속 나올 것입니다.
FUTURE_TRENDS = { "efficient_architectures": { "description": "더 작고 빠른 모델", "examples": ["EfficientViT", "FastViT", "MobileViT v2"], "target": "엣지 디바이스에서 ViT 성능" }, "foundation_models": { "description": "대규모 사전학습 모델", "examples": ["DINOv2", "SAM", "ImageBind"], "impact": "다운스트림 태스크 성능 향상" }, "multimodal_expansion": { "description": "비전을 넘어선 통합", "examples": ["GPT-4V", "Gemini", "LLaVA"], "direction": "Vision + Language + Audio 통합" }, "hardware_codesign": { "description": "하드웨어와 함께 설계", "examples": ["Apple Neural Engine 최적화 모델"], "goal": "특정 칩에서 최적 성능" }}