1. Attention의 필요성
전통적인 Encoder-Decoder 모델의 문제점
Attention 메커니즘이 등장하기 이전의 기본적인 Seq2Seq 모델(RNN 기반의 Encoder-Decoder)은 다음과 같은 문제점을 가지고 있었습니다.
- Encoder: 긴 문장을 읽고 그 모든 정보를 하나의 요약 벡터(thought vector)인 마지막 hidden state에 압축합니다.
- Decoder: 이 요약 벡터만을 보고 순서대로 단어를 생성합니다.
- 문제:
- 문장이 길어질수록, 모든 정보를 단 하나의 hidden state에 압축해야 하므로 압축의 정도가 심해집니다.
- 압축이 심해질수록, 문장 앞쪽의 단어 정보가 손실되기 쉽습니다. (Long-term Dependency 문제)
- LSTM/GRU와 같은 개선된 RNN 구조가 도입되었지만, 이 문제를 완전히 해결하지는 못했습니다.
- 결과:
- 짧은 문장은 잘 처리되지만, 긴 문장의 경우 앞부분 내용이 번역에서 빠지거나, 틀리거나, 문맥이 이상해지는 현상이 발생했습니다.
Attention의 핵심 아이디어
Attention은 이 압축 및 정보 손실 문제를 해결하기 위해 도입되었습니다.
- 핵심 원리: Decoder가 단어를 하나 생성할 때마다, Encoder의 최종 요약 벡터만 보는 것이 아니라, Encoder의 모든 hidden state(h_1, h_2, ..., h_T)를 다시 참고하고, 현재 시점의 단어 생성에 중요한 위치에 더 큰 가중치를 줍니다.

2. Bahdanau (Additive) vs Luong (Multiplicative) Attention


import random
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt
import numpy as np
NUM_LETTERS = 8
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2
VOCAB_TOKENS = ["<pad>", "<sos>", "<eos>"] + [
chr(ord("a") + i) for i in range(NUM_LETTERS)
]
VOCAB_SIZE = len(VOCAB_TOKENS)
MIN_SEQ_LEN = 3
MAX_SEQ_LEN = 7
NUM_TRAIN_SAMPLES = 2000
NUM_VALID_SAMPLES = 200
EMBED_DIM = 32
HIDDEN_SIZE = 64
ATTN_DIM = 64
BATCH_SIZE = 32
NUM_EPOCHS = 15
LEARNING_RATE = 1e-3
class CopyDataset(Dataset):
def __init__(
self,
num_samples: int,
min_len: int,
max_len: int,
vocab_start: int,
vocab_end: int,
sos_index: int,
eos_index: int,
):
super().__init__()
self.num_samples = num_samples
self.min_len = min_len
self.max_len = max_len
self.vocab_start = vocab_start
self.vocab_end = vocab_end
self.sos_idx = sos_index
self.eos_idx = eos_index
self.data = [self._make_sample() for _ in range(self.num_samples)]
def _make_sample(self) -> Tuple[torch.tensor, torch.tensor]:
length = random.randint(self.min_len, self.max_len)
src_tokens = [
random.randint(self.vocab_start, self.vocab_end) for _ in range(length)
]
src = torch.tensor(src_tokens, dtype=torch.long)
trg = torch.tensor(
[self.sos_idx] + src_tokens + [self.eos_idx], dtype=torch.long
)
return src, trg
def __len__(self) -> int:
return self.num_samples
def __getitem__(self, index: int) -> Tuple[torch.tensor, torch.tensor]:
return self.data[index]
def collate_fn(batch: List[Tuple[torch.tensor, torch.tensor]]):
src_list, trg_list = zip(*batch)
src_batch = pad_sequence(src_list, batch_first=True, padding_value=PAD_IDX)
trg_input_list = []
trg_output_list = []
for trg in trg_list:
trg_input_list.append(trg[:-1]) # [sos_idx, w1, w2, .., wn]
trg_output_list.append(trg[1:]) # [w1, w2, w3...wn, eos_idx]
trg_input = pad_sequence(trg_input_list, batch_first=True, padding_value=PAD_IDX)
trg_output = pad_sequence(trg_output_list, batch_first=True, padding_value=PAD_IDX)
src_mask = (src_batch != PAD_IDX).long()
return src_batch, trg_input, trg_output, src_mask
class Encoder(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_size, pad_idx):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
self.rnn = nn.GRU(embed_dim, hidden_size, batch_first=True, bidirectional=False)
def forward(self, src, src_lengths=None):
emb = self.embedding(src)
outputs, hidden = self.rnn(emb) # output:(b, s, h), hidden:(1, b, h)
return outputs, hidden.squeeze(0) # hidden.squeeze(0)변경
class AdditiveAttention(nn.Module):
def __init__(self, hidden_size_enc, hidden_size_dec, attn_dim):
super().__init__()
self.W_h = nn.Linear(hidden_size_enc, attn_dim, bias=False)
self.W_s = nn.Linear(hidden_size_dec, attn_dim, bias=False)
self.v_a = nn.Linear(attn_dim, 1, bias=False)
def forward(self, encoder_hidden, decoder_hidden, mask=None):
Wh = self.W_h(encoder_hidden)
Ws = self.W_s(decoder_hidden).unsqueeze(1)
score = self.v_a(torch.tanh(Wh + Ws)).squeeze(-1)
if mask is not None:
score = score.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(score, dim=-1)
context = torch.bmm(
attn_weights.unsqueeze(1), # (B, 1, S)
encoder_hidden, # (B, S, H) => (B, 1, H) => (B, H)
).squeeze(1)
return context, attn_weights
class DecoderWithAttention(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_size, pad_idx):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
self.rnn = nn.GRU(embed_dim + hidden_size, hidden_size, batch_first=True)
self.fc_out = nn.Linear(hidden_size, vocab_size)
self.attention = AdditiveAttention(
hidden_size_enc=hidden_size, hidden_size_dec=hidden_size, attn_dim=ATTN_DIM
)
def forward(self, trg_input, encoder_outputs, encoder_mask, hidden):
B, T_in = trg_input.size()
emb = self.embedding(trg_input)
outputs = []
attn_list = []
decoder_hidden = hidden
input_step = emb[:, 0, :] # (B, E)
for t in range(1, T_in):
context, attn_weights = self.attention( # context.shape(B, S)
encoder_outputs, decoder_hidden, encoder_mask
)
attn_list.append(attn_weights.unsqueeze(1))
rnn_input = torch.cat([input_step, context], dim=-1).unsqueeze(
1
) # (B, E), (B, H) => (B, E+H) => (B, 1, E+H)
output, new_hidden = self.rnn(
rnn_input, decoder_hidden.unsqueeze(0) # (B, 1, E+H) # (1, B, H)
) # output :(B, 1, H), new_hidden : (1, B, H)
decoder_hidden = new_hidden.squeeze(0)
logits = self.fc_out(output.squeeze(1)) # (B, V)
outputs.append(logits.unsqueeze(1)) # (B, 1, V)
input_step = emb[:, t, :]
logits_all = torch.cat(outputs, dim=1) # (B, T_in-1, V)
attn_weights_all = torch.cat(attn_list, dim=1) # (B, T_in-1, S)
return logits_all, attn_weights_all
class Seq2Seq(nn.Module):
def __init__(self, encoder: Encoder, decoder: DecoderWithAttention):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, src, trg_input, src_mask):
encoder_outputs, enc_hidden = self.encoder(src)
logits, attn_weights = self.decoder(
trg_input, encoder_outputs, src_mask, enc_hidden
)
return logits, attn_weights
def indices_to_tokens(indices: List[int]) -> List[str]:
return [VOCAB_TOKENS[i] for i in indices]
def print_example(src, trg_input, trg_output, pred_indices):
"""
src: (S,)
trg_input: (T,)
trg_output: (T,) # [x1, ..., xL, <eos>]
pred_indices: (T,) # 예측 토큰 인덱스
"""
src_tokens = indices_to_tokens(src)
trg_tokens = indices_to_tokens(trg_output)
pred_tokens = indices_to_tokens(pred_indices)
print("-------------------------------------------------")
print("SRC :", " ".join(src_tokens))
print("TRG (gold) :", " ".join(trg_tokens))
print("PRED :", " ".join(pred_tokens))
print("-------------------------------------------------")
def train():
# 1) 데이터셋 / 데이터로더
train_dataset = CopyDataset(
num_samples=NUM_TRAIN_SAMPLES,
min_len=MIN_SEQ_LEN,
max_len=MAX_SEQ_LEN,
vocab_start=3,
vocab_end=VOCAB_SIZE - 1,
sos_index=SOS_IDX,
eos_index=EOS_IDX,
)
valid_dataset = CopyDataset(
num_samples=NUM_VALID_SAMPLES,
min_len=MIN_SEQ_LEN,
max_len=MAX_SEQ_LEN,
vocab_start=3,
vocab_end=VOCAB_SIZE - 1,
sos_index=SOS_IDX,
eos_index=EOS_IDX,
)
train_loader = DataLoader(
train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)
valid_loader = DataLoader(
valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
)
# 2) 모델 준비
encoder = Encoder(
vocab_size=VOCAB_SIZE,
embed_dim=EMBED_DIM,
hidden_size=HIDDEN_SIZE,
pad_idx=PAD_IDX,
)
decoder = DecoderWithAttention(
vocab_size=VOCAB_SIZE,
embed_dim=EMBED_DIM,
hidden_size=HIDDEN_SIZE,
pad_idx=PAD_IDX,
)
model = Seq2Seq(encoder, decoder)
# 3) 옵티마이저, 손실함수
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
def train_one_epoch(dataloader, train_mode: bool):
if train_mode:
model.train()
else:
model.eval()
total_loss = 0.0
total_tokens = 0
with torch.set_grad_enabled(train_mode):
for src, trg_input, trg_output, src_mask in dataloader:
logits, _ = model(src, trg_input, src_mask)
# logits: (B, T-1, V)
B, Tm1, V = logits.size()
# trg_output도 T-1 길이에 맞춰 자르기 (padding 때문에 길이가 다를 수는 없지만 안전용)
trg_out_cut = trg_output[:, :Tm1]
# PyTorch nn.CrossEntropyLoss는 2D/1D 형태를 기대
# input : (N, C) → N개 샘플, 각 샘플마다 C개 클래스 점수
# target : (N,) → 각 샘플의 정답 클래스 인덱스
loss = criterion(logits.view(B * Tm1, V), trg_out_cut.reshape(-1))
# loss : “배치 안의 모든 문장, 그 안의 모든 타임스텝(단어)에 대해
# CrossEntropyLoss를 계산해서 평균낸 값”
if train_mode:
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 마스크 기반 토큰 수 집계
non_pad = (trg_out_cut != PAD_IDX).sum().item()
total_loss += loss.item() * non_pad
total_tokens += non_pad
if total_tokens == 0:
return 0.0
return total_loss / total_tokens
# 4) Epoch 루프
for epoch in range(1, NUM_EPOCHS + 1):
train_loss = train_one_epoch(train_loader, train_mode=True)
valid_loss = train_one_epoch(valid_loader, train_mode=False)
print(
f"[Epoch {epoch:02d}] "
f"train_loss={train_loss:.4f} valid_loss={valid_loss:.4f}"
)
# =======================================================================
# attn_weight 시각화
# =======================================================================
model.eval()
with torch.no_grad():
# 검증 데이터에서 한 배치 뽑기
src, trg_input, trg_output, src_mask = next(iter(valid_loader))
logits, attn_weights = model(src, trg_input, src_mask)
b = 0
src_b = src[b] # (S,)
trg_out_b = trg_output[b] # (T,)
attn_b = attn_weights[b] # (T-1, S)
# PAD 제거
src_no_pad = src_b[src_b != PAD_IDX]
trg_no_pad = trg_out_b[trg_out_b != PAD_IDX]
# trg_output = [x1, ..., xL, <eos>] 구조.
# 디코더 타임스텝(T-1) 개수에 맞춰 잘라주기 (길이 안 맞는 문제 방지)
Tm1 = attn_b.shape[0] # 디코더가 실제로 예측한 step 수
trg_no_pad = trg_no_pad[:Tm1]
# attn_b도 (T, S)로 자르기 (혹시 PAD 때문에 길이 차이가 나면 조정)
S = len(src_no_pad)
attn_mat = attn_b[: len(trg_no_pad), :S] # (T, S)
plot_attention_heatmap(
src_indices=src_no_pad.tolist(),
trg_indices=trg_no_pad.tolist(),
attn_matrix=attn_mat,
idx2token=VOCAB_TOKENS,
title="Sample 0 Attention",
)
def plot_attention_heatmap(
src_indices, trg_indices, attn_matrix, idx2token, title="Attention Heatmap"
):
"""
src_indices : (S,) 입력 문장 토큰 인덱스 (PAD 제거된 것)
trg_indices : (T,) 출력(혹은 타겟) 토큰 인덱스 (PAD 제거된 것, <sos> 뺀 상태)
attn_matrix: (T, S) numpy array, 해당 샘플의 attention 가중치
idx2token : 인덱스를 문자열 토큰으로 바꿔주는 리스트(VOCAB_TOKENS)
"""
# 인덱스를 실제 토큰 문자열로 변환
src_tokens = [idx2token[i] for i in src_indices]
trg_tokens = [idx2token[i] for i in trg_indices]
plt.figure(figsize=(len(src_tokens) * 0.7, len(trg_tokens) * 0.7))
plt.imshow(attn_matrix, aspect="auto", cmap="viridis")
# y축: 디코더 타임스텝(출력 토큰)
plt.yticks(ticks=np.arange(len(trg_tokens)), labels=trg_tokens, fontsize=10)
# x축: 인코더 타임스텝(입력 토큰)
plt.xticks(
ticks=np.arange(len(src_tokens)),
labels=src_tokens,
rotation=45,
ha="right",
fontsize=10,
)
plt.colorbar(label="Attention weight")
plt.xlabel("Source tokens (encoder)")
plt.ylabel("Target tokens (decoder)")
plt.title(title)
plt.tight_layout()
plt.show()
if __name__ == "__main__":
train()'데이터 분석 > 머신러닝, 딥러닝' 카테고리의 다른 글
| Transformer (0) | 2025.12.10 |
|---|---|
| 시퀀스 모델 (0) | 2025.12.05 |
| 단어 임베딩 (0) | 2025.12.03 |
| 딥러닝 텍스트 전처리 (0) | 2025.12.01 |
| K-means Document Clustering (0) | 2025.11.28 |