nn.CrossEntropy 사용법
참고 : 왜 이렇게 하나? (Operation상의 이득은 없음)
이미지 분류: [batch, class, height, width]
시퀀스 분류: [batch, class, seq_len]
일반 분류: [batch, class]
처럼 일관성 있게 구현하기 위해서 2번째에 Class를 넣어서 CrossEntropy를 구현.
중요한 것은 2번째에 Class가 들어가야 한다는 것인데,
2차원으로 입력할 때와 3차원으로 입력할 때 맞춰줘야 한다는 점입니다.
좀 헷갈리는 점인데,
2차원 입력:
input: [N, C] (N: 샘플 수, C: 클래스 수)
target: [N] (각 샘플의 정답 인덱스)
3차원 이상 입력:
input: [N, C, d1, d2, ...] (C: 클래스 차원은 항상 두 번째)
target: [N, d1, d2, ...] (정답 인덱스만)
즉,
2차원(flatten): [N, C]와 [N]
3차원 이상(예: NLP): [batch, C, seq_len]와 [batch, seq_len]
**
여기서 C(vocab_size)가 두 번째 차원이어야 함왜 혼동이 생기나? 대부분의 NLP 모델은 [batch, seq_len, vocab_size]로 출력합니다. 하지만 CrossEntropyLoss는 클래스(vocab_size) 차원이 두 번째에 와야 하므로,
3차원으로 입력하는 경우, [batch, vocab_size, seq_len]로 permute해서 넣거나,
아니면 2차원으로 입력하는 경우에 [batch * seq_len, vocab_size]로 flatten해서 넣을 수도 있습니다.
결론
2차원(flatten) 입력: [N, C]에서 클래스 차원은 마지막이어도 됩니다.
3차원 이상 입력: [N, C, ...]에서 클래스 차원은 반드시 두 번째여야 합니다.
그래서 "vocab이 두 번째에 와야 한다"는 설명이 나오는 것입니다.
3차원이상으로 입력을 넣을 때에는
# logits: [batch, seq_len, vocab_size]
logits_perm = logits.permute(0, 2, 1) # [batch, vocab_size, seq_len]
loss = nn.CrossEntropyLoss()(logits_perm, target)
2차원으로 만들 때에는
# logits: [batch, seq_len, vocab_size]
logits_flat = logits.view(-1, vocab_size) # [batch * seq_len, vocab_size]
target_flat = target.view(-1) # [batch * seq_len]
loss = nn.CrossEntropyLoss()(logits_flat, target_flat)
이렇게 만들어 넣으면 됩니다.
헷갈리지 마세욥
※ 3차원일 떄 crossentropy 계산하는 법
logits (모델 출력): [batch=2, classes=3, seq_len=2]
batch0 = [[1.0, 2.0], # class0의 시퀀스 [pos0, pos1]
[0.5, 1.5], # class1의 시퀀스
[-0.5, -1.0]] # class2의 시퀀스
batch1 = [[0.2, -0.1],
[1.0, 0.5],
[0.1, 0.2]]
target (정답 인덱스): [batch=2, seq_len=2]
target = [[0, 1], # batch0: pos0→class0, pos1→class1
[1, 0]] # batch1: pos0→class1, pos1→class0
계산 과정 (단계별)
1. Softmax 계산 (클래스 차원 기준)
각 배치와 시퀀스 위치별로 클래스 차원에 대해 softmax 적용:
batch0, pos0: [1.0, 0.5, -0.5]
exp(1.0)=2.718, exp(0.5)=1.649, exp(-0.5)=0.606
합: 2.718 + 1.649 + 0.606 = 4.973
softmax: [2.718/4.973≈0.547, 1.649/4.973≈0.332, 0.606/4.973≈0.122]
batch0, pos1: [2.0, 1.5, -1.0]
exp(2.0)=7.389, exp(1.5)=4.482, exp(-1.0)=0.368
합: 12.239
softmax: [7.389/12.239≈0.604, 4.482/12.239≈0.366, 0.368/12.239≈0.030]
batch1, pos0: [0.2, 1.0, 0.1]
exp(0.2)=1.221, exp(1.0)=2.718, exp(0.1)=1.105
합: 5.044
softmax: [1.221/5.044≈0.242, 2.718/5.044≈0.539, 1.105/5.044≈0.219]
batch1, pos1: [-0.1, 0.5, 0.2]
exp(-0.1)=0.905, exp(0.5)=1.649, exp(0.2)=1.221
합: 3.775
softmax: [0.905/3.775≈0.240, 1.649/3.775≈0.437, 1.221/3.775≈0.323]
2. 정답 확률 추출
target 인덱스에 해당하는 확률만 선택:
batch0, pos0: target=0 → 0.547
batch0, pos1: target=1 → 0.366
batch1, pos0: target=1 → 0.539
batch1, pos1: target=0 → 0.240
3. Negative Log-Likelihood 계산
각 확률에 대해 -log(prob) 적용:
0.547 → -log(0.547) ≈ 0.603
0.366 → -log(0.366) ≈ 1.005
0.539 → -log(0.539) ≈ 0.618
0.240 → -log(0.240) ≈ 1.427
(0.603 + 1.005 + 0.618 + 1.427) / 4 = 3.653 / 4 ≈ 0.913
import torch
import torch.nn as nn
# 입력 데이터
logits = torch.tensor([
[[1.0, 2.0], [0.5, 1.5], [-0.5, -1.0]], # batch0
[[0.2, -0.1], [1.0, 0.5], [0.1, 0.2]] # batch1
])
target = torch.tensor([[0, 1], [1, 0]])
# CrossEntropyLoss 적용 (클래스 차원=1)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, target)
print(loss.item()) # 출력: 약 0.913