1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with three separate
    linear layers, split into ``num_heads`` heads of size
    ``embed_dim // num_heads``, attended per head, concatenated, and passed
    through a final output projection.

    Args:
        embed_dim: Size of the last dimension of the input and the output.
        num_heads: Number of attention heads; must be positive and divide
            ``embed_dim`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Validate before dividing so that num_heads == 0 yields a clear
        # error instead of ZeroDivisionError, and raise explicitly because
        # ``assert`` statements are stripped under ``python -O``.
        if num_heads <= 0:
            raise ValueError("num_heads must be a positive integer")
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear projections for queries, keys and values.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Output projection applied after the heads are concatenated.
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Scores are divided by sqrt(head_dim) before the softmax.
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, tensor, batch_size, seq_len):
        """Reshape (batch, seq_len, embed_dim) to (batch, num_heads, seq_len, head_dim)."""
        return tensor.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, embed_dim).
            mask: Optional attention mask. It must be broadcastable to
                (batch_size, num_heads, seq_len, seq_len); positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            (batch_size, seq_len, embed_dim) and ``attention_weights`` has
            shape (batch_size, num_heads, seq_len, seq_len). For a query whose
            keys are all masked out, the attention weights are all zero (the
            row does not sum to 1), so its context vector is zero.
        """
        batch_size, seq_len, _ = x.shape

        # Project, then split into heads:
        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        Q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        K = self._split_heads(self.k_proj(x), batch_size, seq_len)
        V = self._split_heads(self.v_proj(x), batch_size, seq_len)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        blocked = None
        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        attention_weights = torch.softmax(scores, dim=-1)

        if blocked is not None:
            # A row whose scores are all -inf makes softmax return NaN, which
            # would poison the output and the gradients. Treat such a query as
            # attending to nothing by zeroing its weights.
            fully_blocked = blocked.all(dim=-1, keepdim=True)
            attention_weights = attention_weights.masked_fill(fully_blocked, 0.0)

        # Weighted sum of values: (batch, num_heads, seq_len, head_dim)
        context = torch.matmul(attention_weights, V)

        # Concatenate heads back to (batch, seq_len, embed_dim); contiguous()
        # is required because view() cannot operate on the transposed layout.
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.embed_dim)

        output = self.out_proj(context)
        return output, attention_weights
|