데이터 분석/머신러닝, 딥러닝

Transformer

fullfish 2025. 12. 10. 16:35

1. Transformer의 등장 배경과 핵심 아이디어

기존의 RNN(순환 신경망) 모델은 단어를 순서대로 하나씩 처리해야 했기 때문에 속도가 느리고, 문장이 길어지면 앞부분의 정보를 잊어버리는 단점이 있었습니다.

 
  • 핵심 아이디어: RNN처럼 순차적으로 처리하지 않고, 문장 전체를 한 번에 입력받아 병렬로 처리합니다.
  • 장점:
    • 병렬 연산이 가능하여 GPU 효율이 높습니다.
    • 문장 내에서 거리가 먼 단어들 사이의 관계도 한 번에 파악할 수 있습니다.
       
  • 구조: 오직 Attention(어텐션) 메커니즘만으로 이루어진 Encoder-Decoder 구조입니다.

2. 가장 중요한 부품: Self-Attention (자가 주의 집중)

Transformer가 문맥을 이해하는 핵심 원리입니다. 문장 내의 각 단어가 다른 단어들과 어떤 연관이 있는지를 계산합니다.

 
  • Query (Q): "내가 지금 무엇을 찾고 있는지" (질문 던지는 역할)
  • Key (K): "내가 어떤 정보를 가진 단어인지" (신분증 역할)
  • Value (V): "내가 실제로 들려줄 정보" (내용물 역할)
  • Q, K, V의 직관적 이해: 각 단어는 3가지 역할로 변신하여 계산에 참여합니다.
  • 작동 원리:
    1. 한 단어의 Q와 모든 단어의 K를 비교(내적)하여 유사도(Score)를 구합니다.
    2. 이 점수를 Softmax를 통해 확률(가중치)로 바꿉니다. 즉, "어떤 단어에 얼마나 집중할지"를 결정합니다.
    3. 이 가중치를 각 단어의 V에 곱해서 더합니다(Weighted Sum). 이렇게 하면 문맥이 반영된 새로운 단어 표현이 만들어집니다.
       

3. 더 똑똑하게 보기: Multi-Head Attention

Self-Attention을 한 번만 수행하는 것이 아니라, 여러 개의 'Head'로 나누어 동시에 수행합니다.

  • 이유: 한 번만 계산하면 하나의 관점으로만 문장을 보게 됩니다. 여러 번 수행하면 다양한 관점(예: 주어-동사 관계, 형용사-명사 관계 등)을 동시에 포착하여 표현력이 풍부해집니다.
  • 방법: 여러 개의 Q, K, V 쌍을 만들어 각각 어텐션을 수행한 뒤, 결과들을 이어 붙입니다(Concat).

4. 순서 알려주기: Positional Encoding

Transformer는 문장을 한 번에 '집합'처럼 처리하기 때문에, 단어의 순서(어순)를 모른다는 단점이 있습니다.

  • 문제: "나는 밥을 먹었다"와 "밥을 나는 먹었다"를 똑같이 인식할 위험이 있습니다.
  • 해결: 단어 벡터에 **위치 정보를 담은 벡터(Positional Encoding)**를 더해줍니다. 이를 통해 모델이 "이 단어가 몇 번째에 위치하는지"를 알게 됩니다.
  • 구현: 주로 사인(Sin)과 코사인(Cos) 함수를 이용하여 위치별로 고유한 값을 만들어 더합니다.

5. 전체 구조: Encoder와 Decoder

Transformer는 크게 정보를 읽는 Encoder와 결과를 생성하는 Decoder로 나뉩니다.

Encoder (입력 처리)

  • 입력 문장의 문맥을 파악하여 압축된 정보를 만듭니다.
  • 블록 구조: Multi-Head Self-AttentionResidual(잔차 연결) + LayerNormFeed-Forward Network(FFN)Residual + LayerNorm 순서로 반복됩니다.

Decoder (출력 생성)

  • Encoder가 준 정보와 지금까지 만든 단어들을 바탕으로 다음 단어를 예측합니다.
  • Encoder와의 차이점:
     
    1. Masked Self-Attention: 자기 자신을 볼 때, 미래의 단어(아직 예측하지 않은 단어)는 보지 못하도록 가립니다(Masking).
    2. Encoder-Decoder Attention (Cross-Attention): 디코더의 Query가 인코더의 Key, Value를 참조합니다. 즉, "번역할 때 원문의 어느 부분을 참고할지" 결정합니다.
# self_attention

import torch
import torch.nn.functional as F

# 시드 고정으로 결과의 재현성 확보
torch.manual_seed(42)


def simple_self_attention(x):
    # x: 입력 시퀀스 텐서 (seq_len, d_model) 형태.
    # 각 행은 문장의 단어 임베딩을 나타냅니다.
    d_model = x.size(-1)

    # 가중치 행렬 초기화 (실제 모델에서는 nn.Linear로 학습됩니다)
    W_Q = torch.randn(d_model, d_model)
    W_K = torch.randn(d_model, d_model)
    W_V = torch.randn(d_model, d_model)

    # Query (Q), Key (K), Value (V) 텐서 생성
    # Q = 입력 x와 W_Q의 행렬 곱. '내가 찾고 있는 것'
    Q = x @ W_Q
    # K = 입력 x와 W_K의 행렬 곱. '내가 가진 정보'
    K = x @ W_K
    V = x @ W_V
    # Q, K, V 모두 (seq_len, d_model) 형태를 가집니다.

    # 1. Score 계산: Q와 K의 내적(dot product) 사용 (유사도 측정)
    # Q @ K.transpose(0, 1)는 (seq_len, seq_len) 형태를 반환합니다.
    # i번째 단어의 Q가 j번째 단어의 K와 얼마나 관련 있는지 점수 매김.
    scores = Q @ K.transpose(0, 1)

    # 스케일링 (Scaling): 분산이 커지는 것을 막기 위해 d_k의 제곱근으로 나눕니다.
    scores = scores / (d_model**0.5)

    # 2. Softmax로 정규화: Score를 Attention Weight(가중치)인 확률 분포로 변환
    # dim=-1 에 대해 적용하여 각 행의 합이 1이 되도록 합니다.
    # 각 행은 해당 단어가 문장 내의 다른 단어(K)들에 얼마나 집중해야 하는지(가중치)를 나타냅니다.
    attn_weights = F.softmax(scores, dim=-1)

    # 3. Weighted Sum (가중합): 새로운 은닉 표현 생성
    # 계산된 가중치(attn_weights)를 모든 단어의 Value(V)와 곱하여 합산합니다.
    out = attn_weights @ V

    return out, attn_weights


# --- 예제 실행 ---
# 입력 데이터 설정: seq_len = 4, d_model = 8
seq_len, d_model = 4, 8

# 랜덤 입력 데이터 생성 (임베딩된 문장이라고 가정)
x = torch.randn(seq_len, d_model)

# Self-Attention 계산 수행
out, attn_w = simple_self_attention(x)

print("out shape:", out.shape)
print("attn_w shape:", attn_w.shape)
print("attention weights:\n", attn_w)

'''
out shape: torch.Size([4, 8])
attn_w shape: torch.Size([4, 4])
attention weights:
 tensor([[9.9999e-01, 2.2638e-08, 4.5571e-07, 1.1432e-05],
        [1.1078e-02, 9.8871e-01, 1.2271e-04, 8.5188e-05],
        [1.5654e-06, 6.1913e-04, 9.7855e-01, 2.0830e-02],
        [3.0778e-07, 8.7853e-08, 9.9291e-01, 7.0907e-03]])'''

 

 

# multi_head_attention

import torch
import torch.nn as nn

# 시드 고정으로 결과의 재현성 확보
torch.manual_seed(42)

# Multi-Head Attention에 사용할 입력 데이터의 크기 설정
batch_size = 2  # 배치 크기 (동시에 처리할 문장의 개수)
seq_len = 5  # 시퀀스 길이 (문장 내 단어의 개수)
d_model = 16  # 임베딩 차원 (단어 벡터의 크기, embed_dim)
num_heads = 4  # 어텐션 헤드의 개수

# nn.MultiheadAttention 객체 초기화
mha = nn.MultiheadAttention(
    embed_dim=d_model,  # 입력 임베딩 차원
    num_heads=num_heads,  # 병렬로 수행할 어텐션 헤드의 개수
    batch_first=True,  # 입력 텐서의 순서를 (batch, seq_len, d_model)로 설정
)

# 랜덤 입력 텐서 (x) 생성
# x는 임베딩된 문장 데이터 (batch_size, seq_len, d_model) 형태
x = torch.randn(batch_size, seq_len, d_model)

# Self-Attention 계산 수행
# Query, Key, Value 모두 같은 입력 텐서(x)를 사용하면 Self-Attention이 됩니다.
attn_out, attn_weights = mha(x, x, x)

# 출력 형태 확인
print("attn_out shape:", attn_out.shape)
# attn_out: 문맥 정보가 반영된 출력 텐서의 형태. (batch, seq_len, d_model) -> (2, 5, 16)

print("attn_weights shape:", attn_weights.shape)
# attn_weights: 어텐션 가중치 텐서의 형태. (batch, seq_len, seq_len) -> (2, 5, 5)

'''
attn_out shape: torch.Size([2, 5, 16])
attn_weights shape: torch.Size([2, 5, 5])'''

 

# positional_encoding

import torch
import torch.nn as nn
import math


def get_positional_encoding(max_len, d_model):
    """
    Positional Encoding 텐서를 생성하는 함수입니다.

    반환: (max_len, d_model) 크기의 텐서
    """

    # Positional Encoding을 저장할 빈 텐서를 (max_len, d_model) 크기로 생성
    pe = torch.zeros(max_len, d_model)  #

    # 0부터 max_len-1까지의 위치(pos)를 나타내는 텐서를 생성하고 차원 확장 (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)

    # 분모 (Denominator)에 해당하는 스케일링 값들을 미리 계산
    # 이는 'pos' 값에 곱해져서 사인/코사인 함수의 주기를 조절합니다.
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )

    # 짝수 차원(0, 2, 4, ...)에는 사인(sin) 함수 적용
    pe[:, 0::2] = torch.sin(position * div_term)

    # 홀수 차원(1, 3, 5, ...)에는 코사인(cos) 함수 적용
    pe[:, 1::2] = torch.cos(position * div_term)

    return pe  # (max_len, d_model)


# --- Positional Encoding 적용 예제 ---

# 모델 및 데이터 크기 설정
batch_size = 2
seq_len = 5
d_model = 16
vocab_size = 100

# 단어 임베딩 레이어 정의
embedding = nn.Embedding(vocab_size, d_model)

# 입력 토큰 인덱스 생성 (예시)
x_idx = torch.randint(0, vocab_size, (batch_size, seq_len))  # (batch, seq_len)

# 1. 단어 임베딩 수행
x_embed = embedding(x_idx)  # (batch, seq_len, d_model)

# 2. Positional Encoding 생성
pe = get_positional_encoding(max_len=seq_len, d_model=d_model)  # (seq_len, d_model)

# 3. 배치를 위해 차원 확장
# (seq_len, d_model) -> (1, seq_len, d_model)로 확장하여 브로드캐스팅 준비
pe = pe.unsqueeze(0)

# 4. 단어 임베딩과 위치 인코딩을 더하여 최종 입력 벡터 생성
# 브로드캐스팅(broadcasting)을 통해 모든 배치에 동일한 pe가 더해집니다.
x_with_pos = x_embed + pe  # (batch, seq_len, d_model)


### 2. Transformer 입력 모듈 클래스 정의


class SimpleTransformerInput(nn.Module):
    def __init__(self, vocab_size, d_model, max_len):
        super().__init__()

        self.d_model = d_model
        # 단어 임베딩 레이어
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Positional Encoding 텐서를 미리 계산하여 저장
        self.pos_encoding = get_positional_encoding(
            max_len, d_model
        )  # (max_len, d_model)

    def forward(self, x):
        """
        입력 토큰 인덱스를 받아 임베딩 후 위치 인코딩을 더해 반환합니다.

        x: (batch, seq_len) 토큰 인덱스
        """

        # 입력 텐서의 크기 추출
        batch_size, seq_len = x.size()

        # 1. 단어 임베딩 수행
        emb = self.embedding(x)  # (batch, seq_len, d_model)

        # 2. Positional Encoding 추출 및 차원 확장
        # 현재 시퀀스 길이(seq_len)에 맞는 PE 부분만 잘라 사용하고, 배치 차원(1)을 추가
        pe = self.pos_encoding[:seq_len].unsqueeze(0)  # (1, seq_len, d_model)

        # 3. 단어 임베딩에 위치 인코딩을 더함
        emb = emb + pe

        return emb  # (batch, seq_len, d_model)


# --- 실행 및 결과 확인 부분 추가 ---
print("-" * 30)
print("SimpleTransformerInput 모듈 실행")

# 모듈 인스턴스화
# 최대 시퀀스 길이(max_len)는 여기서는 seq_len과 동일하게 5로 설정
max_len = 5
model = SimpleTransformerInput(vocab_size, d_model, max_len)

# 입력 데이터 (x_idx) 사용
output_tensor = model(x_idx)

# 출력 결과 확인
print("입력 토큰 인덱스 (x_idx) shape:", x_idx.shape)
print("최종 출력 텐서 (output_tensor) shape:", output_tensor.shape)
print("출력 텐서의 일부 값 (첫 번째 배치의 첫 번째 토큰):\n", output_tensor[0, 0, :4])
print("-" * 30)

'''
------------------------------
SimpleTransformerInput 모듈 실행
입력 토큰 인덱스 (x_idx) shape: torch.Size([2, 5])
최종 출력 텐서 (output_tensor) shape: torch.Size([2, 5, 16])
출력 텐서의 일부 값 (첫 번째 배치의 첫 번째 토큰):
 tensor([ 0.9177,  1.0161, -0.4875,  0.1554], grad_fn=<SliceBackward0>)
------------------------------'''

 

# 실습 예제

import math
import random
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


# ============================
# 1. 설정값 (하이퍼파라미터)
# ============================


# vocab: <pad>, <sos>, <eos>, a, b, c, d, e, f, g, h
NUM_LETTERS = 8  # a~h
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2
VOCAB_TOKENS = ["<pad>", "<sos>", "<eos>"] + [
    chr(ord("a") + i) for i in range(NUM_LETTERS)
]
VOCAB_SIZE = len(VOCAB_TOKENS)  # 3 + NUM_LETTERS

# 시퀀스 길이 및 데이터 수
MIN_SEQ_LEN = 3
MAX_SEQ_LEN = 8
NUM_TRAIN_SAMPLES = 20000
NUM_VALID_SAMPLES = 200

# Transformer 하이퍼파라미터
D_MODEL = 16
NHEAD = 4
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2
DIM_FEEDFORWARD = 32
DROPOUT = 0.1

# 학습 하이퍼파라미터
BATCH_SIZE = 32
NUM_EPOCHS = 15
LEARNING_RATE = 1e-3


# ============================
# 2. 데이터셋 정의
# ============================


class CopyDataset(Dataset):
    """
    랜덤 문자 시퀀스를 만들고,
    타깃은 [<sos>] + src + [<eos>] 형태로 생성하는 Dataset.
    """

    def __init__(
        self,
        num_samples: int,
        min_len: int,
        max_len: int,
        vocab_start: int,
        vocab_end: int,
        sos_idx: int,
        eos_idx: int,
    ):
        super().__init__()
        self.num_samples = num_samples
        self.min_len = min_len
        self.max_len = max_len
        self.vocab_start = vocab_start
        self.vocab_end = vocab_end
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx

        self.data = [self._make_sample() for _ in range(num_samples)]

    def _make_sample(self) -> Tuple[torch.Tensor, torch.Tensor]:
        length = random.randint(self.min_len, self.max_len)
        # 문자 영역: [vocab_start, vocab_end] (정수 토큰)
        src_tokens = [
            random.randint(self.vocab_start, self.vocab_end) for _ in range(length)
        ]
        src = torch.tensor(src_tokens, dtype=torch.long)
        # trg: [<sos>, x1, ..., xL, <eos>]
        trg = torch.tensor(
            [self.sos_idx] + src_tokens + [self.eos_idx], dtype=torch.long
        )
        return src, trg

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.data[idx]


def collate_fn(batch: List[Tuple[torch.Tensor, torch.Tensor]]):
    """
    batch: [(src, trg), (src, trg), ...]
      src: (src_len,)
      trg: (trg_len,)  -> [<sos>, x1, x2, ..., xL, <eos>]

    리턴:
      - src_batch : (B, S)
      - trg_input : (B, T) [<sos>, x1, ..., xL]
      - trg_output: (B, T) [x1, ..., xL, <eos>]
      - src_key_padding_mask: (B, S)  (PAD 위치=True)
      - tgt_key_padding_mask: (B, T)  (PAD 위치=True)
    """
    src_list, trg_list = zip(*batch)

    # (1) src 패딩
    src_batch = pad_sequence(
        src_list, batch_first=True, padding_value=PAD_IDX
    )  # (B, S)

    # (2) trg_input, trg_output 분리
    trg_input_list = []
    trg_output_list = []
    for trg in trg_list:
        # trg: [<sos>, x1, ..., xL, <eos>]
        trg_input_list.append(trg[:-1])  # [<sos>, x1, ..., xL]
        trg_output_list.append(trg[1:])  # [x1, ..., xL, <eos>]

    trg_input = pad_sequence(
        trg_input_list, batch_first=True, padding_value=PAD_IDX
    )  # (B, T)
    trg_output = pad_sequence(
        trg_output_list, batch_first=True, padding_value=PAD_IDX
    )  # (B, T)

    # Transformer는 key_padding_mask에서
    # "True = 가려야 할 위치(PAD)" 로 사용
    src_key_padding_mask = src_batch == PAD_IDX  # (B, S)
    tgt_key_padding_mask = trg_input == PAD_IDX  # (B, T)

    return src_batch, trg_input, trg_output, src_key_padding_mask, tgt_key_padding_mask


# ============================
# 3. Positional Encoding 정의
# ============================


class PositionalEncoding(nn.Module):
    """
    sin/cos 기반 Positional Encoding.
    입력: (B, S, D)
    출력: (B, S, D)  (PE를 더한 결과)
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # (max_len, d_model) 모양의 위치 임베딩 미리 계산
        pe = torch.zeros(max_len, d_model)  # (L, D)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (L, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (D/2,)

        # 짝수/홀수 차원에 sin/cos 번갈아가며 넣기
        pe[:, 0::2] = torch.sin(
            position * div_term
        )  # 짝수 (position * div_term: (10, 8))
        pe[:, 1::2] = torch.cos(position * div_term)  # 홀수

        pe = pe.unsqueeze(0)  # (1, L, D)  → batch 차원 추가
        self.register_buffer(
            "pe", pe
        )  # 학습 파라미터는 아니지만, 모델과 함께 저장되도록

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, S, D)
        """
        seq_len = x.size(1)
        # pe[:, :seq_len, :] : (1, S, D) → broadcast 되어 더해짐
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)


# ============================
# 4. Transformer 기반 Copy 모델
# ============================


class TransformerCopyModel(nn.Module):
    """
    nn.Transformer를 사용한 Encoder-Decoder Copy 모델
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        dropout: float,
        pad_idx: int,
    ):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx

        self.tok_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # (B, S, D) 형태 사용
        )

        self.fc_out = nn.Linear(d_model, vocab_size)

    def generate_square_subsequent_mask(self, sz: int):
        """
        디코더용 causal mask 생성 (미래 토큰 가리기).
        shape: (sz, sz)  ; True 위치는 attention 불가능
        """
        mask = torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)
        # torch.triu = upper triangular (상삼각) 부분만 남기고 나머지는 0/False로 만드는 함수.
        # diagonal=1 이라는 건:
        # 메인 대각선보다 한 칸 위부터를 상삼각으로 보겠다는 뜻.

        return mask

    def forward(
        self,
        src: torch.Tensor,  # (B, S)
        tgt_input: torch.Tensor,  # (B, T)  [<sos>, x1, ..., xL]
        src_key_padding_mask: torch.Tensor,  # (B, S)  (True=PAD)
        tgt_key_padding_mask: torch.Tensor,  # (B, T)  (True=PAD)
    ):
        B, S = src.size()
        B2, T = tgt_input.size()
        # assert B == B2

        # 1) 토큰 임베딩 + 스케일 + Positional Encoding
        src_emb = self.tok_embedding(src) * math.sqrt(self.d_model)  # (B, S, D)
        # 토큰 임베딩 값의 스케일을 키워서, positional encoding에 비해 너무 작아지지 않게 하려고
        # √d_model을 곱해주는 것
        src_emb = self.pos_encoder(src_emb)  # (B, S, D)

        tgt_emb = self.tok_embedding(tgt_input) * math.sqrt(self.d_model)  # (B, T, D)
        tgt_emb = self.pos_encoder(tgt_emb)  # (B, T, D)

        # 2) 디코더용 causal mask (자기보다 뒤 시점은 못 보도록)
        tgt_mask = self.generate_square_subsequent_mask(T)  # (T, T)

        # 3) Transformer 호출
        #   - src_key_padding_mask: (B, S)  True=무시할 위치
        #   - tgt_key_padding_mask: (B, T)
        #   - memory_key_padding_mask: encoder output에 대한 mask (대부분 src와 동일)
        out = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )  # (B, T, D)

        logits = self.fc_out(out)  # (B, T, vocab_size)
        return logits


# ============================
# 5. 유틸 함수들
# ============================


def indices_to_tokens(indices: List[int]) -> List[str]:
    return [VOCAB_TOKENS[i] for i in indices]


def print_example(src, trg_output, pred_indices):
    """
    src: (S,)
    trg_output: (T,)
    pred_indices: (T,)
    """
    src_tokens = indices_to_tokens(src)
    trg_tokens = indices_to_tokens(trg_output)
    pred_tokens = indices_to_tokens(pred_indices)

    print("-------------------------------------------------")
    print("SRC        :", " ".join(src_tokens))
    print("TRG (gold) :", " ".join(trg_tokens))
    print("PRED       :", " ".join(pred_tokens))
    print("-------------------------------------------------")


# ============================
# 6. 학습 루프
# ============================


def train():
    # 1) 데이터셋 / 데이터로더
    train_dataset = CopyDataset(
        num_samples=NUM_TRAIN_SAMPLES,
        min_len=MIN_SEQ_LEN,
        max_len=MAX_SEQ_LEN,
        vocab_start=3,
        vocab_end=VOCAB_SIZE - 1,
        sos_idx=SOS_IDX,
        eos_idx=EOS_IDX,
    )
    valid_dataset = CopyDataset(
        num_samples=NUM_VALID_SAMPLES,
        min_len=MIN_SEQ_LEN,
        max_len=MAX_SEQ_LEN,
        vocab_start=3,
        vocab_end=VOCAB_SIZE - 1,
        sos_idx=SOS_IDX,
        eos_idx=EOS_IDX,
    )

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
    )

    # 2) 모델 준비
    model = TransformerCopyModel(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        num_decoder_layers=NUM_DECODER_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=DROPOUT,
        pad_idx=PAD_IDX,
    )

    # 3) 옵티마이저, 손실함수
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    def run_one_epoch(dataloader, train_mode: bool):
        if train_mode:
            model.train()
        else:
            model.eval()

        total_loss = 0.0
        total_tokens = 0

        with torch.set_grad_enabled(train_mode):
            for (
                src,
                trg_input,
                trg_output,
                src_key_padding_mask,
                tgt_key_padding_mask,
            ) in dataloader:

                src_key_padding_mask = src_key_padding_mask
                tgt_key_padding_mask = tgt_key_padding_mask

                logits = model(
                    src, trg_input, src_key_padding_mask, tgt_key_padding_mask
                )
                # logits: (B, T, V)
                B, T, V = logits.size()

                loss = criterion(logits.view(B * T, V), trg_output.reshape(-1))

                if train_mode:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                non_pad = (trg_output != PAD_IDX).sum().item()
                total_loss += loss.item() * non_pad
                total_tokens += non_pad

        if total_tokens == 0:
            return 0.0
        return total_loss / total_tokens

    # 4) Epoch 루프
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss = run_one_epoch(train_loader, train_mode=True)
        valid_loss = run_one_epoch(valid_loader, train_mode=False)

        print(
            f"[Epoch {epoch:02d}] "
            f"train_loss={train_loss:.4f}  valid_loss={valid_loss:.4f}"
        )

    # 5) 학습 후 예시 출력
    model.eval()
    with torch.no_grad():
        src, trg_input, trg_output, src_key_padding_mask, tgt_key_padding_mask = next(
            iter(valid_loader)
        )

        logits = model(src, trg_input, src_key_padding_mask, tgt_key_padding_mask)
        preds = logits.argmax(dim=-1)  # (B, T)

        num_show = min(3, src.size(0))
        for i in range(num_show):
            src_i = src[i]
            trg_i = trg_output[i]
            pred_i = preds[i]

            # PAD 제거
            src_i_nopad = src_i[src_i != PAD_IDX]
            trg_i_nopad = trg_i[trg_i != PAD_IDX]
            pred_i_nopad = pred_i[pred_i != PAD_IDX]

            # 길이 맞추기 (gold 길이에 맞춰 잘라서 보기)
            T_eff = len(trg_i_nopad)
            pred_i_cut = pred_i_nopad[:T_eff]

            print_example(
                src_i_nopad.tolist(), trg_i_nopad.tolist(), pred_i_cut.tolist()
            )


if __name__ == "__main__":
    train()

'''
  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
[Epoch 01] train_loss=2.2555  valid_loss=2.0528
[Epoch 02] train_loss=1.9942  valid_loss=1.8613
[Epoch 03] train_loss=1.8317  valid_loss=1.6800
[Epoch 04] train_loss=1.6940  valid_loss=1.5236
[Epoch 05] train_loss=1.5974  valid_loss=1.4131
[Epoch 06] train_loss=1.5238  valid_loss=1.3042
[Epoch 07] train_loss=1.4569  valid_loss=1.2217
[Epoch 08] train_loss=1.4119  valid_loss=1.1614
[Epoch 09] train_loss=1.3669  valid_loss=1.0942
[Epoch 10] train_loss=1.3257  valid_loss=1.0413
[Epoch 11] train_loss=1.2880  valid_loss=0.9935
[Epoch 12] train_loss=1.2664  valid_loss=0.9713
[Epoch 13] train_loss=1.2434  valid_loss=0.9266
[Epoch 14] train_loss=1.2119  valid_loss=0.8876
[Epoch 15] train_loss=1.1763  valid_loss=0.8588
-------------------------------------------------
SRC        : c c e
TRG (gold) : c c e <eos>
PRED       : c e e c
-------------------------------------------------
-------------------------------------------------
SRC        : e a b a b
TRG (gold) : e a b a b <eos>
PRED       : e a b a <eos> <eos>
-------------------------------------------------
-------------------------------------------------
SRC        : e d a c b d b
TRG (gold) : e d a c b d b <eos>
PRED       : d d b b b <eos> <eos> <eos>
-------------------------------------------------'''

'데이터 분석 > 머신러닝, 딥러닝' 카테고리의 다른 글

Attention  (0) 2025.12.09
시퀀스 모델  (0) 2025.12.05
단어 임베딩  (0) 2025.12.03
딥러닝 텍스트 전처리  (0) 2025.12.01
K-means Document Clustering  (0) 2025.11.28