데이터 분석/머신러닝, 딥러닝

Attention

fullfish 2025. 12. 9. 15:50

1. Attention의 필요성

전통적인 Encoder-Decoder 모델의 문제점

Attention 메커니즘이 등장하기 이전의 기본적인 Seq2Seq 모델(RNN 기반의 Encoder-Decoder)은 다음과 같은 문제점을 가지고 있었습니다.

  • Encoder: 긴 문장을 읽고 그 모든 정보를 하나의 요약 벡터(thought vector)인 마지막 hidden state에 압축합니다.
  • Decoder: 이 요약 벡터만을 보고 순서대로 단어를 생성합니다.
  • 문제:
    • 문장이 길어질수록, 모든 정보를 단 하나의 hidden state에 압축해야 하므로 압축의 정도가 심해집니다.
    • 압축이 심해질수록, 문장 앞쪽의 단어 정보가 손실되기 쉽습니다. (Long-term Dependency 문제)
    • LSTM/GRU와 같은 개선된 RNN 구조가 도입되었지만, 이 문제를 완전히 해결하지는 못했습니다.
  • 결과:
    • 짧은 문장은 잘 처리되지만, 긴 문장의 경우 앞부분 내용이 번역에서 빠지거나, 틀리거나, 문맥이 이상해지는 현상이 발생했습니다.

Attention의 핵심 아이디어

Attention은 이 압축 및 정보 손실 문제를 해결하기 위해 도입되었습니다.

  • 핵심 원리: Decoder가 단어를 하나 생성할 때마다, Encoder의 최종 요약 벡터만 보는 것이 아니라, Encoder의 모든 hidden state(h_1, h_2, ..., h_T)를 다시 참고하고, 현재 시점의 단어 생성에 중요한 위치에 더 큰 가중치를 줍니다.

 

2. Bahdanau (Additive) vs Luong (Multiplicative) Attention

import random
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt
import numpy as np

NUM_LETTERS = 8
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2
VOCAB_TOKENS = ["<pad>", "<sos>", "<eos>"] + [
    chr(ord("a") + i) for i in range(NUM_LETTERS)
]

VOCAB_SIZE = len(VOCAB_TOKENS)

MIN_SEQ_LEN = 3
MAX_SEQ_LEN = 7
NUM_TRAIN_SAMPLES = 2000
NUM_VALID_SAMPLES = 200

EMBED_DIM = 32
HIDDEN_SIZE = 64
ATTN_DIM = 64

BATCH_SIZE = 32
NUM_EPOCHS = 15
LEARNING_RATE = 1e-3


class CopyDataset(Dataset):
    def __init__(
        self,
        num_samples: int,
        min_len: int,
        max_len: int,
        vocab_start: int,
        vocab_end: int,
        sos_index: int,
        eos_index: int,
    ):
        super().__init__()
        self.num_samples = num_samples
        self.min_len = min_len
        self.max_len = max_len
        self.vocab_start = vocab_start
        self.vocab_end = vocab_end
        self.sos_idx = sos_index
        self.eos_idx = eos_index

        self.data = [self._make_sample() for _ in range(self.num_samples)]

    def _make_sample(self) -> Tuple[torch.tensor, torch.tensor]:
        length = random.randint(self.min_len, self.max_len)
        src_tokens = [
            random.randint(self.vocab_start, self.vocab_end) for _ in range(length)
        ]

        src = torch.tensor(src_tokens, dtype=torch.long)

        trg = torch.tensor(
            [self.sos_idx] + src_tokens + [self.eos_idx], dtype=torch.long
        )

        return src, trg

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, index: int) -> Tuple[torch.tensor, torch.tensor]:
        return self.data[index]


def collate_fn(batch: List[Tuple[torch.tensor, torch.tensor]]):
    src_list, trg_list = zip(*batch)

    src_batch = pad_sequence(src_list, batch_first=True, padding_value=PAD_IDX)

    trg_input_list = []
    trg_output_list = []

    for trg in trg_list:
        trg_input_list.append(trg[:-1])  # [sos_idx, w1, w2, .., wn]
        trg_output_list.append(trg[1:])  # [w1, w2, w3...wn, eos_idx]

    trg_input = pad_sequence(trg_input_list, batch_first=True, padding_value=PAD_IDX)

    trg_output = pad_sequence(trg_output_list, batch_first=True, padding_value=PAD_IDX)

    src_mask = (src_batch != PAD_IDX).long()
    return src_batch, trg_input, trg_output, src_mask


class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)

        self.rnn = nn.GRU(embed_dim, hidden_size, batch_first=True, bidirectional=False)

    def forward(self, src, src_lengths=None):
        emb = self.embedding(src)
        outputs, hidden = self.rnn(emb)  # output:(b, s, h), hidden:(1, b, h)
        return outputs, hidden.squeeze(0)  # hidden.squeeze(0)변경


class AdditiveAttention(nn.Module):
    def __init__(self, hidden_size_enc, hidden_size_dec, attn_dim):
        super().__init__()
        self.W_h = nn.Linear(hidden_size_enc, attn_dim, bias=False)
        self.W_s = nn.Linear(hidden_size_dec, attn_dim, bias=False)
        self.v_a = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, encoder_hidden, decoder_hidden, mask=None):
        Wh = self.W_h(encoder_hidden)
        Ws = self.W_s(decoder_hidden).unsqueeze(1)
        score = self.v_a(torch.tanh(Wh + Ws)).squeeze(-1)

        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(score, dim=-1)

        context = torch.bmm(
            attn_weights.unsqueeze(1),  # (B, 1, S)
            encoder_hidden,  # (B, S, H)  =>  (B, 1, H) => (B, H)
        ).squeeze(1)

        return context, attn_weights


class DecoderWithAttention(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)

        self.rnn = nn.GRU(embed_dim + hidden_size, hidden_size, batch_first=True)

        self.fc_out = nn.Linear(hidden_size, vocab_size)
        self.attention = AdditiveAttention(
            hidden_size_enc=hidden_size, hidden_size_dec=hidden_size, attn_dim=ATTN_DIM
        )

    def forward(self, trg_input, encoder_outputs, encoder_mask, hidden):

        B, T_in = trg_input.size()
        emb = self.embedding(trg_input)

        outputs = []
        attn_list = []

        decoder_hidden = hidden
        input_step = emb[:, 0, :]  # (B, E)

        for t in range(1, T_in):
            context, attn_weights = self.attention(  # context.shape(B, S)
                encoder_outputs, decoder_hidden, encoder_mask
            )

            attn_list.append(attn_weights.unsqueeze(1))

            rnn_input = torch.cat([input_step, context], dim=-1).unsqueeze(
                1
            )  # (B, E), (B, H) => (B, E+H) => (B, 1, E+H)

            output, new_hidden = self.rnn(
                rnn_input, decoder_hidden.unsqueeze(0)  #  (B, 1, E+H)  # (1, B, H)
            )  # output :(B, 1, H), new_hidden : (1, B, H)

            decoder_hidden = new_hidden.squeeze(0)
            logits = self.fc_out(output.squeeze(1))  # (B, V)

            outputs.append(logits.unsqueeze(1))  # (B, 1, V)

            input_step = emb[:, t, :]

        logits_all = torch.cat(outputs, dim=1)  # (B, T_in-1, V)
        attn_weights_all = torch.cat(attn_list, dim=1)  # (B, T_in-1, S)

        return logits_all, attn_weights_all


class Seq2Seq(nn.Module):

    def __init__(self, encoder: Encoder, decoder: DecoderWithAttention):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg_input, src_mask):
        encoder_outputs, enc_hidden = self.encoder(src)
        logits, attn_weights = self.decoder(
            trg_input, encoder_outputs, src_mask, enc_hidden
        )

        return logits, attn_weights


def indices_to_tokens(indices: List[int]) -> List[str]:
    return [VOCAB_TOKENS[i] for i in indices]


def print_example(src, trg_input, trg_output, pred_indices):
    """
    src: (S,)
    trg_input: (T,)
    trg_output: (T,)  # [x1, ..., xL, <eos>]
    pred_indices: (T,) # 예측 토큰 인덱스
    """
    src_tokens = indices_to_tokens(src)
    trg_tokens = indices_to_tokens(trg_output)
    pred_tokens = indices_to_tokens(pred_indices)

    print("-------------------------------------------------")
    print("SRC        :", " ".join(src_tokens))
    print("TRG (gold) :", " ".join(trg_tokens))
    print("PRED       :", " ".join(pred_tokens))
    print("-------------------------------------------------")


def train():
    # 1) 데이터셋 / 데이터로더
    train_dataset = CopyDataset(
        num_samples=NUM_TRAIN_SAMPLES,
        min_len=MIN_SEQ_LEN,
        max_len=MAX_SEQ_LEN,
        vocab_start=3,
        vocab_end=VOCAB_SIZE - 1,
        sos_index=SOS_IDX,
        eos_index=EOS_IDX,
    )
    valid_dataset = CopyDataset(
        num_samples=NUM_VALID_SAMPLES,
        min_len=MIN_SEQ_LEN,
        max_len=MAX_SEQ_LEN,
        vocab_start=3,
        vocab_end=VOCAB_SIZE - 1,
        sos_index=SOS_IDX,
        eos_index=EOS_IDX,
    )

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
    )

    # 2) 모델 준비
    encoder = Encoder(
        vocab_size=VOCAB_SIZE,
        embed_dim=EMBED_DIM,
        hidden_size=HIDDEN_SIZE,
        pad_idx=PAD_IDX,
    )
    decoder = DecoderWithAttention(
        vocab_size=VOCAB_SIZE,
        embed_dim=EMBED_DIM,
        hidden_size=HIDDEN_SIZE,
        pad_idx=PAD_IDX,
    )

    model = Seq2Seq(encoder, decoder)

    # 3) 옵티마이저, 손실함수
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    def train_one_epoch(dataloader, train_mode: bool):
        if train_mode:
            model.train()
        else:
            model.eval()

        total_loss = 0.0
        total_tokens = 0

        with torch.set_grad_enabled(train_mode):
            for src, trg_input, trg_output, src_mask in dataloader:

                logits, _ = model(src, trg_input, src_mask)
                # logits: (B, T-1, V)
                B, Tm1, V = logits.size()

                # trg_output도 T-1 길이에 맞춰 자르기 (padding 때문에 길이가 다를 수는 없지만 안전용)
                trg_out_cut = trg_output[:, :Tm1]

                # PyTorch nn.CrossEntropyLoss는 2D/1D 형태를 기대
                # input  : (N, C)  → N개 샘플, 각 샘플마다 C개 클래스 점수
                # target : (N,)    → 각 샘플의 정답 클래스 인덱스
                loss = criterion(logits.view(B * Tm1, V), trg_out_cut.reshape(-1))
                # loss : “배치 안의 모든 문장, 그 안의 모든 타임스텝(단어)에 대해
                # CrossEntropyLoss를 계산해서 평균낸 값”

                if train_mode:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # 마스크 기반 토큰 수 집계
                non_pad = (trg_out_cut != PAD_IDX).sum().item()
                total_loss += loss.item() * non_pad
                total_tokens += non_pad

        if total_tokens == 0:
            return 0.0
        return total_loss / total_tokens

    # 4) Epoch 루프
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss = train_one_epoch(train_loader, train_mode=True)
        valid_loss = train_one_epoch(valid_loader, train_mode=False)

        print(
            f"[Epoch {epoch:02d}] "
            f"train_loss={train_loss:.4f}  valid_loss={valid_loss:.4f}"
        )

        # =======================================================================
        # attn_weight 시각화
        # =======================================================================
    model.eval()
    with torch.no_grad():
        # 검증 데이터에서 한 배치 뽑기
        src, trg_input, trg_output, src_mask = next(iter(valid_loader))

        logits, attn_weights = model(src, trg_input, src_mask)

        b = 0
        src_b = src[b]  # (S,)
        trg_out_b = trg_output[b]  # (T,)
        attn_b = attn_weights[b]  # (T-1, S)

        # PAD 제거
        src_no_pad = src_b[src_b != PAD_IDX]
        trg_no_pad = trg_out_b[trg_out_b != PAD_IDX]

        # trg_output = [x1, ..., xL, <eos>] 구조.
        # 디코더 타임스텝(T-1) 개수에 맞춰 잘라주기 (길이 안 맞는 문제 방지)
        Tm1 = attn_b.shape[0]  # 디코더가 실제로 예측한 step 수
        trg_no_pad = trg_no_pad[:Tm1]

        # attn_b도 (T, S)로 자르기 (혹시 PAD 때문에 길이 차이가 나면 조정)
        S = len(src_no_pad)
        attn_mat = attn_b[: len(trg_no_pad), :S]  # (T, S)

        plot_attention_heatmap(
            src_indices=src_no_pad.tolist(),
            trg_indices=trg_no_pad.tolist(),
            attn_matrix=attn_mat,
            idx2token=VOCAB_TOKENS,
            title="Sample 0 Attention",
        )


def plot_attention_heatmap(
    src_indices, trg_indices, attn_matrix, idx2token, title="Attention Heatmap"
):
    """
    src_indices : (S,)  입력 문장 토큰 인덱스 (PAD 제거된 것)
    trg_indices : (T,)  출력(혹은 타겟) 토큰 인덱스 (PAD 제거된 것, <sos> 뺀 상태)
    attn_matrix: (T, S) numpy array, 해당 샘플의 attention 가중치
    idx2token  : 인덱스를 문자열 토큰으로 바꿔주는 리스트(VOCAB_TOKENS)
    """

    # 인덱스를 실제 토큰 문자열로 변환
    src_tokens = [idx2token[i] for i in src_indices]
    trg_tokens = [idx2token[i] for i in trg_indices]

    plt.figure(figsize=(len(src_tokens) * 0.7, len(trg_tokens) * 0.7))
    plt.imshow(attn_matrix, aspect="auto", cmap="viridis")

    # y축: 디코더 타임스텝(출력 토큰)
    plt.yticks(ticks=np.arange(len(trg_tokens)), labels=trg_tokens, fontsize=10)
    # x축: 인코더 타임스텝(입력 토큰)
    plt.xticks(
        ticks=np.arange(len(src_tokens)),
        labels=src_tokens,
        rotation=45,
        ha="right",
        fontsize=10,
    )

    plt.colorbar(label="Attention weight")
    plt.xlabel("Source tokens (encoder)")
    plt.ylabel("Target tokens (decoder)")
    plt.title(title)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    train()

'데이터 분석 > 머신러닝, 딥러닝' 카테고리의 다른 글

Transformer  (0) 2025.12.10
시퀀스 모델  (0) 2025.12.05
단어 임베딩  (0) 2025.12.03
딥러닝 텍스트 전처리  (0) 2025.12.01
K-means Document Clustering  (0) 2025.11.28