인공지능(LLM·Vision)

nn.CrossEntropy 사용법

인폼에잇연 2025. 6. 27. 14:14

참고 : 왜 이렇게 하나? (Operation상의 이득은 없음)

이미지 분류: [batch, class, height, width]
시퀀스 분류: [batch, class, seq_len]
일반 분류: [batch, class]
처럼 일관성 있게 구현하기 위해서 2번째에 Class를 넣어서 CrossEntropy를 구현. 

 

중요한 것은 2번째에 Class가 들어가야 한다는 것인데, 

2차원으로 입력할 때와 3차원으로 입력할 때 맞춰줘야 한다는 점입니다. 

좀 헷갈리는 점인데,  

 

2차원 입력:
input: [N, C] (N: 샘플 수, C: 클래스 수)
target: [N] (각 샘플의 정답 인덱스)

3차원 이상 입력:
input: [N, C, d1, d2, ...] (C: 클래스 차원은 항상 두 번째)
target: [N, d1, d2, ...] (정답 인덱스만)

즉,
2차원(flatten): [N, C]와 [N]
3차원 이상(예: NLP): [batch, C, seq_len]와 [batch, seq_len]

**
여기서 C(vocab_size)가 두 번째 차원이어야 함왜 혼동이 생기나? 대부분의 NLP 모델은 [batch, seq_len, vocab_size]로 출력합니다. 하지만 CrossEntropyLoss는 클래스(vocab_size) 차원이 두 번째에 와야 하므로,

3차원으로 입력하는 경우, [batch, vocab_size, seq_len]로 permute해서 넣거나,
아니면 2차원으로 입력하는 경우에 [batch * seq_len, vocab_size]로 flatten해서 넣을 수도 있습니다.

결론
2차원(flatten) 입력: [N, C]에서 클래스 차원은 마지막이어도 됩니다.
3차원 이상 입력: [N, C, ...]에서 클래스 차원은 반드시 두 번째여야 합니다.
그래서 "vocab이 두 번째에 와야 한다"는 설명이 나오는 것입니다.

3차원이상으로 입력을 넣을 때에는 

# logits: [batch, seq_len, vocab_size]
logits_perm = logits.permute(0, 2, 1)        # [batch, vocab_size, seq_len]
loss = nn.CrossEntropyLoss()(logits_perm, target)

 

2차원으로 만들 때에는 

# logits: [batch, seq_len, vocab_size]
logits_flat = logits.view(-1, vocab_size)    # [batch * seq_len, vocab_size]
target_flat = target.view(-1)                # [batch * seq_len]
loss = nn.CrossEntropyLoss()(logits_flat, target_flat)

 

이렇게 만들어 넣으면 됩니다. 

헷갈리지 마세욥 

 

※ 3차원일 떄 crossentropy 계산하는 법 

 

logits (모델 출력): [batch=2, classes=3, seq_len=2]

batch0 = [[1.0, 2.0],   # class0의 시퀀스 [pos0, pos1]
          [0.5, 1.5],   # class1의 시퀀스
          [-0.5, -1.0]] # class2의 시퀀스

batch1 = [[0.2, -0.1],
          [1.0, 0.5],
          [0.1, 0.2]]

 

target (정답 인덱스): [batch=2, seq_len=2]

target = [[0, 1],  # batch0: pos0→class0, pos1→class1
          [1, 0]]  # batch1: pos0→class1, pos1→class0

 


계산 과정 (단계별)
1. Softmax 계산 (클래스 차원 기준)
각 배치와 시퀀스 위치별로 클래스 차원에 대해 softmax 적용:

batch0, pos0: [1.0, 0.5, -0.5]
exp(1.0)=2.718, exp(0.5)=1.649, exp(-0.5)=0.606
합: 2.718 + 1.649 + 0.606 = 4.973
softmax: [2.718/4.973≈0.547, 1.649/4.973≈0.332, 0.606/4.973≈0.122]

batch0, pos1: [2.0, 1.5, -1.0]
exp(2.0)=7.389, exp(1.5)=4.482, exp(-1.0)=0.368
합: 12.239
softmax: [7.389/12.239≈0.604, 4.482/12.239≈0.366, 0.368/12.239≈0.030]

batch1, pos0: [0.2, 1.0, 0.1]
exp(0.2)=1.221, exp(1.0)=2.718, exp(0.1)=1.105
합: 5.044
softmax: [1.221/5.044≈0.242, 2.718/5.044≈0.539, 1.105/5.044≈0.219]

batch1, pos1: [-0.1, 0.5, 0.2]
exp(-0.1)=0.905, exp(0.5)=1.649, exp(0.2)=1.221
합: 3.775
softmax: [0.905/3.775≈0.240, 1.649/3.775≈0.437, 1.221/3.775≈0.323]

2. 정답 확률 추출
target 인덱스에 해당하는 확률만 선택:
batch0, pos0: target=0 → 0.547
batch0, pos1: target=1 → 0.366
batch1, pos0: target=1 → 0.539
batch1, pos1: target=0 → 0.240

3. Negative Log-Likelihood 계산
각 확률에 대해 -log(prob) 적용:
0.547 → -log(0.547) ≈ 0.603
0.366 → -log(0.366) ≈ 1.005
0.539 → -log(0.539) ≈ 0.618
0.240 → -log(0.240) ≈ 1.427

 

(0.603 + 1.005 + 0.618 + 1.427) / 4 = 3.653 / 4 ≈ 0.913

 

import torch
import torch.nn as nn

# 입력 데이터
logits = torch.tensor([
    [[1.0, 2.0], [0.5, 1.5], [-0.5, -1.0]],  # batch0
    [[0.2, -0.1], [1.0, 0.5], [0.1, 0.2]]    # batch1
])
target = torch.tensor([[0, 1], [1, 0]])

# CrossEntropyLoss 적용 (클래스 차원=1)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, target)
print(loss.item())  # 출력: 약 0.913