老学庵

天行健,君子以自强不息;地势坤,君子以厚德载物!

0%

理解自注意力机制

  自Transformer架构的开山之作《Attention Is All You Need》发表以来,自注意力机制已成为深度学习模型的基石技术,尤其在自然语言处理领域展现革命性突破。鉴于该机制已渗透至各类先进模型架构,深入理解其运作原理变得至关重要。

  注意力机制的概念源起于神经机器翻译中对长序列处理的探索。传统逐字翻译的局限在于:不同语言特有的语法结构与语义逻辑(如中文"吃食堂"的动宾搭配)在直译中易被破坏,导致语义失真。


sample
逐字翻译(上)与注意力机制翻译(下)效果对比:后者能捕捉"making"与"more""difficult"的语义关联

  Transformer架构的革命性突破在于完全摒弃了循环神经网络的时序依赖。其提出的自注意力机制,通过动态计算序列元素间的关联权重,使模型能够: 1. 全局感知:每个位置都可以直接访问全部序列信息 2. 动态聚焦:根据上下文自动强化关键特征 3. 并行计算:突破RNN的序列计算瓶颈,支持并行计算


sample
论文中的注意力热力图直观呈现了单词"making"与上下文的多维关联(颜色深度表示注意力权重强度)

  从数学视角看,自注意力机制实质是构建动态特征增强网络:通过\((q,k,v)\)之间的矩阵运算,将原始词嵌入向量转换为包含上下文信息的增强表征向量。这种特性使其天然适配语言任务,例如"bank"在"river bank"与"bank loan"中通过注意力权重获得截然不同的语义编码。

  尽管后续研究涌现出稀疏注意力、局部注意力等改进变体,但原始缩放点积注意力仍是工业界首选,其优势在于:

  • 计算复杂度O(n²)在大规模分布式训练中可通过并行化缓解
  • 相比近似方法保留完整语义关联
  • 与硬件优化技术(如FlashAttention)高度适配

本文将以经典论文公式为框架展开解析: \[ Attention(Q,K,V)=softmax(\dfrac{QK^T}{\sqrt{d} + \epsilon})V \] 对该公式的逐项解构将揭示自注意力机制的本质特性。


扩展阅读推荐:
- 演进脉络:《高效Transformers综述
- 工程实践:《高效训练技术解析
- 最新突破:FlashAttention系列优化技术


Embedding an Input Sentence

  接下来以句子Life is short, eat dessert first为例,解释自注意力机制的原理。在处理文本时,由于计算机无法直接处理字符,一般会将字符映射到数域\(\phi\)中来处理。为了简单起见,在这个例子中数域\(\phi\)仅基于当前输入句子中的单词构建,在实际应用中,一般使用训练数据集中的所有单词来构建数域。

Code

1
2
3
4
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

输出

1
{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

接下来,我们使用这个字典为每个单词分配一个整数索引:

Code

1
2
3
4
5
6
import torch

sentence_int = torch.tensor(
[dc[s] for s in sentence.replace(',', '').split()]
)
print(sentence_int)

输出

1
tensor([0, 4, 5, 2, 1, 3])

  现在,得到输入句子向量形式,接下来使用嵌入层将输入编码为实向量嵌入。在这里,我们将使用一个简单的三维嵌入层,使得每个输入词都由一个三维向量表示。请注意,嵌入层的大小通常在几百到几千之间。例如,Llama2使用的嵌入大小为4096。为了方便说明在这里使用三维嵌入,这样我们就可以快速查看单个向量。例句由6个单词组成,输入通过嵌入层映射为\(6 \times 3\)的张量:

Code

1
2
3
4
5
6
7
vocab_size = 6
torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

输出

1
2
3
4
5
6
7
tensor([[ 0.3374, -0.1778, -0.3035],
[ 0.1794, 1.8951, 0.4954],
[ 0.2692, -0.0770, -1.0205],
[-0.2196, -0.3792, 0.7671],
[-0.5880, 0.3486, 0.6603],
[-1.1925, 0.6984, -1.4097]])
torch.Size([6, 3])
## Defining the Weight Matrices   完成了输入数据的高维映射后,我们将深入探讨Transformer架构的核心组件之一:缩放点积注意力(Scaled Dot-Product Attention)。该机制由三个可训练权重矩阵\(W_q、W_k、W_v\)构成,这些矩阵作为模型的参数在训练过程中进行优化。通过这三个矩阵,输入序列被分别投影到对应的查询(query)、键(key)和值(value)向量空间。

  具体而言,查询、键和值序列的生成可通过以下方式实现:将输入嵌入向量x分别与权重矩阵\(W_q、W_k、W_v\)进行矩阵乘法运算,公式表达为:

  • query向量

\[q^{(i)} = x^{(i)}\cdot W_q \quad i\in[1,T]\]

  • key向量

\[k^{(i)} = x^{(i)}\cdot W_k \quad i\in[1,T]\]

  • value向量

\[v^{(i)} = x^{(i)}\cdot W_v \quad i\in[1,T]\]

索引\(i\)指输入序列中的标记索引位置,其长度为\(T\)

sample
计算query、key和value

  其中,\(q^{(i)}\)\(k^{(i)}\)维度为\(d_k\)。投影矩阵\(W_q\)\(W_k\)的维度为\(d \times d_k\) ,而\(W_v\)维度为\(d \times d_v\)。需要注意的是,\(d\)代表每个词向量\(x\)的大小。由于我们计算的是查询向量和键向量之间的点积,因此这两个向量必须包含相同数量的元素 ( \(d_q=d_k\))。在LLM中,我们对值向量使用相同的大小,例如\(d_q=d_k=d_v\)然而,值向量\(v(i)\)中的元素数量(它决定了最终上下文向量的大小)可以是任意的。

因此,对于以下代码演练,我们将设置\(d_q=d_k=2\)并使用\(d_v = 4\),初始化投影矩阵如下:

Code

1
2
3
4
5
6
7
8
9
10
torch.manual_seed(123)

d = embedded_sentence.shape[1]
# 与之前的词嵌入向量类似,$d_q、d_k、d_v$通常会比较大,为了便于说明,我们在这里使用较小的数字。
d_q, d_k, d_v = 2, 2, 4

W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))

Computing the Unnormalized Attention Weights

\(x^{(2)}\)\(qkv\)计算为例:

sample
注意力计算

代码实现如下:

Code

1
2
3
4
5
6
7
8
9
x_2 = embedded_sentence[1]
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

输出

1
2
3
4
torch.Size([2])
torch.Size([2])
torch.Size([4])

  为了下一步计算非标准化注意力权重,计算所有输入的\(qkv\)

Code

1
2
3
4
5
6
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

输出

1
2
3
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 4])

得到了所有输入的\(qkv\),接下来开始计算非标准化注意力权重\(\omega\),如下图所示:

sample
计算非标准化注意力权重ω

如上图所示,我们计算\(\omega_{i,j}\)作为查询和键序列之间的点积:

\[\omega_{i,j} =q^{(i)}k^{(j)}\]

例如,我们可以按如下方式计算查询和第5个输入元素(对应于索引位置4)的非规范化注意力权重:

1
2
omega_24 = query_2.dot(keys[4])
print(omega_24)
1
2
tensor(1.2903)

由于我们稍后需要这些未标准化的注意力权重\(\omega\)来计算实际的注意力权重,因此让我们计算所有输入标记的\(\omega\)值,如上图所示:

1
2
3
omega_2 = query_2 @ keys.T
print(omega_2)

1
2
tensor([-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374])

Computing the Attention Weights

  自注意力机制的后续步骤是应用softmax函数对未归一化的注意力权重\(\omega\)进行归一化从而获得归一化的注意力权重\(\alpha\)。此外,在通过softmax函数进行归一化之前,先使用\(\dfrac{1}{\sqrt{d_k}}\)\(\omega\)进行缩放,如下所示:

sample
计算标准化注意力权重α

通过\(d_k\)进行缩放可确保权重向量的欧氏长度大致相同。这有助于防止注意力权重过小或过大,从而避免数值不稳定或影响模型在训练期间的收敛能力。

在代码中,我们可以按如下方式实现注意力权重的计算:

1
2
3
4
5
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

1
tensor([0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229])

最后一步是计算上下文向量\(z^{(2)}\) ,它是原始查询输入\(x_{(2)}\)的注意力加权版本,通过注意力权重将所有其他输入元素作为其上下文:

sample
注意力权重特定于某个输入元素。这里,我们选择了输入元素x(2)

在代码中,它看起来如下所示:

代码

1
2
3
4
5
context_vector_2 = attention_weights_2 @ values

print(context_vector_2.shape)
print(context_vector_2)

输出

1
2
3
torch.Size([4])
tensor([0.5313, 1.3607, 0.7891, 1.3110])

请注意,由于我们之前指定了\(d_v > d\) ,因此此输出向量具有比原始输入向量(\(d=3\))更多的维度(\(d_v=4\)),但是,嵌入大小选择\(d_v\)是任意的。

Self-Attention

现在,为了总结上面章节中自注意力机制的代码实现,我们可以将前面的代码总结在一个紧凑的SelfAttention类中:

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch.nn as nn

class SelfAttention(nn.Module):

def __init__(self, d_in, d_out_kq, d_out_v):
super().__init__()
self.d_out_kq = d_out_kq
self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

def forward(self, x):
keys = x @ self.W_key
queries = x @ self.W_query
values = x @ self.W_value

attn_scores = queries @ keys.T # unnormalized attention weights
attn_weights = torch.softmax(
attn_scores / self.d_out_kq**0.5, dim=-1
)

context_vec = attn_weights @ values
return context_vec

按照 PyTorch 的约定,SelfAttention上述类在方法中初始化自注意力参数__init__,并通过该方法计算所有输入的注意力权重和上下文向量forward。我们可以按如下方式使用此类:

代码

1
2
3
4
5
6
7
8
torch.manual_seed(123)

# reduce d_out_v from 4 to 1, because we have 4 heads
d_in, d_out_kq, d_out_v = 3, 2, 4

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

输出

1
2
3
4
5
6
7
tensor([[-0.1564,  0.1028, -0.0763, -0.0764],
[ 0.5313, 1.3607, 0.7891, 1.3110],
[-0.3542, -0.1234, -0.2627, -0.3706],
[ 0.0071, 0.3345, 0.0969, 0.1998],
[ 0.1008, 0.4780, 0.2021, 0.3674],
[-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)

如果您查看第二行,您会发现它与context_vector_2上一节中的值完全匹配:tensor([0.5313, 1.3607, 0.7891, 1.3110])

Multi-Head Attention

  Transformer使用了多头注意力的模块。这个“多头”注意力模块与我们上面讨论过的自注意力机制(尺度点积注意力)有何关系?在自注意力机制中,输入序列使用三个矩阵进行变换,分别代表查询、键和值。在多头注意力机制中,这三个矩阵可以被视为一个注意力头。下图总结了我们之前介绍和实现的这个注意力头:

sample
总结之前实现的自注意力机制

顾名思义,多头注意力机制包含多个这样的头,每个头由查询、键和值矩阵组成。这一概念类似于在卷积神经网络中使用多个核,从而生成具有多个输出通道的特征图。

sample
多头注意力

为了在代码中说明这一点,我们可以MultiHeadAttentionWrapper为之前的SelfAttention类编写一个类:

1
2
3
4
5
6
7
8
9
10
11
12
class MultiHeadAttentionWrapper(nn.Module):

def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
super().__init__()
self.heads = nn.ModuleList(
[SelfAttention(d_in, d_out_kq, d_out_v)
for _ in range(num_heads)]
)

def forward(self, x):
return torch.cat([head(x) for head in self.heads], dim=-1)

这些d_*参数与之前的类中相同SelfAttention——这里唯一的新输入参数是注意力头的数量:

  • d_in:输入特征向量的维数。

  • d_out_kq:query和key输出的维度。

  • d_out_v:value输出的维度。

  • num_heads:注意力头的数量。

  我们使用这些输入参数初始化类SelfAttention时间num_heads。并使用 PyTorchnn.ModuleList来存储这些多个SelfAttention实例。然后,forward过程涉及将每个SelfAttention头(存储在 中self.heads)独立地应用于输入x。然后,每个头的结果沿着最后一个维度(dim=-1)连接起来。让我们在下面看看它的实际效果!

首先,为了便于说明,假设我们有一个输出维度为 1 的自注意力头:

Code

1
2
3
4
5
6
7
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3, 2, 1

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

输出

1
2
3
4
5
6
tensor([[-0.0185],
[ 0.4003],
[-0.1103],
[ 0.0668],
[ 0.1180],
[-0.1827]], grad_fn=<MmBackward0>)

现在,让我们将其扩展到 4 个注意力头:

Code

1
2
3
4
5
6
7
8
9
10
11
12
torch.manual_seed(123)

block_size = embedded_sentence.shape[1]
mha = MultiHeadAttentionWrapper(
d_in, d_out_kq, d_out_v, num_heads=4
)

context_vecs = mha(embedded_sentence)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

输出

1
2
3
4
5
6
7
8
tensor([[-0.0185,  0.0170,  0.1999, -0.0860],
      [ 0.4003,  1.7137,  1.3981,  1.0497],
      [-0.1103, -0.1609,  0.0079, -0.2416],
      [ 0.0668,  0.3534,  0.2322,  0.1008],
      [ 0.1180,  0.6949,  0.3157,  0.2807],
      [-0.1827, -0.2060, -0.2393, -0.3167]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([6, 4])

  从上面的输出可以看出,之前创建的单头注意力机制的输出张量现在对应于多头注意力输出张量的第一列。

  需要注意的是,多头注意力的输出是一个\(6\times4\)的张量:因为有6个输入token和4个自注意力头,其中每个自注意力头返回一个一维输出。在前面的自注意力计算中,同样生成了一个\(6 \times 4\)的张量,那是因为在该实例中输出维度被设置为4而不是1。在实践中,如果可以在SelfAttention类中也可以调节输出维度的大小,那么使用多头注意力的意义是什么呢?

  单自注意力头增加输出维度和使用多头注意力之间的区别在于模型处理和学习数据的方式。虽然,二者都能增加模型表示数据不同特征的能力,但它们本质上是不同的方法。

  多头注意力中的每个注意力都有可能学会关注输入序列的不同分布,捕捉数据中各个维度的关系,这种表现的多样性是多头注意力机制成功的关键。

  多头注意力在并行计算方面也更加高效。每个头都可以单独处理,这使得它非常适合使用GPU或TPU等硬件加速器来加速计算。

  简而言之,多头注意力的使用不只是为了让模型的能力更强,更显著增强了其学习数据中各种特征和关系的能力。在7B的Llama2模型中,使用32个自注意力头。

技术原理的深层解读

​+ 特征解耦

  每个头的查询、键、值矩阵(Q/K/V)通过独立线性变换生成,使不同头能学习到输入序列的不同投影空间。例如,一个头可能聚焦局部词序依赖,另一个头可能捕捉长距离语义关联。

  • ​容错机制

  实验表明,即使某些头的注意力权重失效,其他头仍能提供有效特征,增强了模型鲁棒性。 + ​计算效率

  虽然总参数量与单头扩展维度相当,但多头拆分维度(如512维拆分为32头×16维)可降低单个矩阵乘法的计算复杂度。

  • 工程实践启示

  在Transformer架构中,头的数量需平衡模型容量与计算资源。例如,ViT模型通常采用12-16头,而百亿参数大模型可能使用128头。值得注意的是,头数超过输入嵌入维度时会出现维度碎片化问题,因此现代模型常采用分组查询注意力(GQA)进行优化。

Causal Self-Attention

  为了在gpt(解码器风格)的文本生成大模型中应用注意力机制,在本节中,我们将引入Causal Self-Attention,本质上是对前文讨论的自注意机制进行适应性改造,以贴合使用场景。在原始的Transformer结构中,它对应于“屏蔽多头注意”模块,为了简单起见,我们将在本节中查看单个注意头,但是相同的概念可以推广到多个注意头。

sample
The causal self-attention module in the original transformer architecture (via “Attention Is All You Need”, https://arxiv.org/abs/1706.03762)

  Causal Self-Attention确保序列中某个位置的预测输出仅依赖于先前位置的已知输出,跟未来位置的输出无关。简单来说,它确保对下一个单词的预测只依赖于前面的单词。为了在gpt这类大语言模型中实现这一点,在处理每个token时,屏蔽输入文本中位于当前token之后的token,我们将这种这种屏蔽策略称为causal mask

  下面的图示说明了causal mask如何作用于注意力权重中,以隐藏输出中的未来输入token。
sample

  为了方便说明和实现Causal Self-Attention力,我们使用前一节中的未加权注意力分数和注意力权重,首先,我们快速回顾一下前面Self-Attention部分的注意力分数的计算:

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3, 2, 4

W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
W_key   = nn.Parameter(torch.rand(d_in, d_out_kq))
W_value = nn.Parameter(torch.rand(d_in, d_out_v))

x = embedded_sentence

keys = x @ W_key
queries = x @ W_query
values = x @ W_value

# attn_scores are the "omegas",
# the unnormalized attention weights
attn_scores = queries @ keys.T

print(attn_scores)
print(attn_scores.shape)

输出

1
2
3
4
5
6
7
8
9
tensor([[ 0.0613, -0.3491,  0.1443, -0.0437, -0.1303,  0.1076],
      [-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374],
      [ 0.2432, -1.3934,  0.5869, -0.1851, -0.5191,  0.4730],
      [-0.0794,  0.4487, -0.1807,  0.0518,  0.1677, -0.1197],
      [-0.1510,  0.8626, -0.3597,  0.1112,  0.3216, -0.2787],
      [ 0.4344, -2.5037,  1.0740, -0.3509, -0.9315,  0.9265]],
      grad_fn=<MmBackward0>)
torch.Size([6, 6])

  与之前的Self-Attention部分类似,上面的输出是一个6×6张量,其中包含6个输入令牌的成对非标准化注意权重(也称为注意分数)。之前,我们通过softmax函数计算缩放后的点积注意力,如下所示:

Code

1
2
3
attn_weights = torch.softmax(attn_scores / d_out_kq**0.5, dim=1)
print(attn_weights)

输出

1
2
3
4
5
6
7
8
tensor([[0.1772, 0.1326, 0.1879, 0.1645, 0.1547, 0.1831],
      [0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229],
      [0.1965, 0.0618, 0.2506, 0.1452, 0.1146, 0.2312],
      [0.1505, 0.2187, 0.1401, 0.1651, 0.1793, 0.1463],
      [0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.1231],
      [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
      grad_fn=<SoftmaxBackward0>)

  上面的6×6输出表示注意权重。

  现在,在gpt这类的LLM模型中中,我们训练模型从左到右一次读取和生成一个令牌(或单词)。如果我们有一个训练文本样本,比如"Life is short eat desert first",我们有以下设置,其中箭头右侧单词的上下文向量应该只包含它自己和前面的单词:

  • "Life" → "is"

  • "Life is" → "short"

  • "Life is short" → "eat"

  • "Life is short eat" → "desert"

  • "Life is short eat desert" → "first"

  实现上述设置的最简单方法是通过对对角线上方的注意力权重矩阵应用掩码来屏蔽所有未来的标记,如下图所示。这样,在创建上下文向量时就不会包含“未来”词,上下文向量是作为输入的注意力加权和而创建的。

sample
Attention weights above the diagonal should be masked out

  在下面的代码中,我们可以通过PyTorch的tril函数来实现这一点,我们首先使用它来创建一个1和0的掩码:

Code

1
2
3
4
block_size = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(block_size, block_size))
print(mask_simple)

输出

1
2
3
4
5
6
7
tensor([[1., 0., 0., 0., 0., 0.],
      [1., 1., 0., 0., 0., 0.],
      [1., 1., 1., 0., 0., 0.],
      [1., 1., 1., 1., 0., 0.],
      [1., 1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 1., 1.]])

接下来,我们将注意力权重与这个蒙版相乘,将对角线上方的所有注意力权重归零:

Code

1
2
3
masked_simple = attn_weights*mask_simple
print(masked_simple)

输出

1
2
3
4
5
6
7
8
tensor([[0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0386, 0.6870, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.1965, 0.0618, 0.2506, 0.0000, 0.0000, 0.0000],
      [0.1505, 0.2187, 0.1401, 0.1651, 0.0000, 0.0000],
      [0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.0000],
      [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
      grad_fn=<MulBackward0>)

  虽然以上是一种屏蔽将来单词的方法,但请注意,每一行的注意力权重之和不再为1。为此,进行行归一化操作,使它们的权重总和为1,这是注意力权重的标准约定:

Code

1
2
3
4
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

输出

1
2
3
4
5
6
7
8
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
      [0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
      [0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
      [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
      grad_fn=<DivBackward0>)

  可以看到,现在每一行的注意力权重之和为1。

  在Transformer模型中,规范化的注意力权重比非规范化的权重更有优势,主要有两个原因。首先,归一化的注意力权重之和为1,类似于概率分布。这样就可以更容易地根据比例来解释模型对输入的各个部分的关注。其次,通过约束注意权值求和为1,这种归一化有助于控制权值和梯度的尺度,从而提高训练的动态性。

More efficient masking without renormalization

  在上面编码的Causal Self-Attention过程中,我们首先计算注意得分,然后计算注意权重,掩盖对角线上方的注意权重,最后重新规范化注意权重。如下图所示:

sample
The previously implemented causal self-attention procedure

  在实际实现时,还有一种更有效的方法可以达到同样的效果。在这种方法中,我们对角线上的值替换为负无穷,然后将这些值输入softmax函数以计算注意力权重。如下图所示:

sample

An alternative, more efficient approach to implementing causal self-attention

我们可以在PyTorch中按照如下方式编写这个过程,首先屏蔽对角线上方的注意力得分:

Code

1
2
3
4
mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

  在上面的代码中,首先创建了一个对角线以下都是0,对角线以上都是1的mask。这里,torch.triu(上三角)保留了矩阵主对角线上的元素,并将下面的元素置零,从而保留了上三角部分。相比之下,torch.tril(下三角)则保留了主对角线上的元素以及下面的元素。然后,使用masked_fill方法通过正的mask值(1s)将上三角的所有元素替换为-torch.inf,结果如下所示。

输出

1
2
3
4
5
6
7
8
tensor([[ 0.0613,    -inf,    -inf,    -inf,    -inf,    -inf],
      [-0.6004, 3.4707,   -inf,   -inf,   -inf,   -inf],
      [ 0.2432, -1.3934, 0.5869,   -inf,   -inf,   -inf],
      [-0.0794, 0.4487, -0.1807, 0.0518,   -inf,   -inf],
      [-0.1510, 0.8626, -0.3597, 0.1112, 0.3216,   -inf],
      [ 0.4344, -2.5037, 1.0740, -0.3509, -0.9315, 0.9265]],
      grad_fn=<MaskedFillBackward0>)

然后,我们所要做的就是像往常一样应用softmax函数来获得归一化和屏蔽的注意力权重:

Code

1
2
3
attn_weights = torch.softmax(masked / d_out_kq**0.5, dim=1)
print(attn_weights)

输出

1
2
3
4
5
6
7
8
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
      [0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
      [0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
      [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
      grad_fn=<SoftmaxBackward0>)

  为什么会这样呢?在最后一步中应用的softmax函数将输入值转换为概率分布。当输入中存在-inf时,因为e^(-inf)接近于0,因此这些位置对输出概率没有贡献。

Conclusion

  在这篇文章里,我们一步步通过编程的方式,探索了自注意力机制的内部运作。基于此,我们又研究了多头注意力,这是大型语言转换器的一个基本组成部分。

  然后我们还编了交叉注意力,这是自注意力的一种变体,当应用于两个不同的序列之间时特别有效。最后,我们还编了Causal Self-Attention力,这个概念对于在GPT和Llama这样的解码器风格LLM中生成连贯且上下文适当的序列至关重要。

  通过从头开始编写这些复杂的机制,你或许对转换器和LLM中使用的自注意力机制的内部运作有了很好的理解。

(请注意,本文中提供的代码仅用于说明目的。如果您计划在培训llm时实现自我关注,我建议考虑像[Flash Attention](https://arxiv.org/abs/2307.08691)这样的优化实现,它可以减少内存占用和计算负载。

Bonus Topic: Cross-Attention

  在上面的代码讲解里,咱们设置了Self-Attention和Causal-Attention这两部分,把_d_q和_d_k都设成了2,_d_v设成了4。也就是说,咱们让查询序列和键序列用了同样的维度。虽然通常情况下,值矩阵W_v都会选择和查询矩阵、键矩阵一样的维度(比如PyTorch里的MultiHeadAttention类就是这样),但其实咱们也可以给值维度选个任意的数字大小……

本文翻译自Understanding and Coding Self-Attention, Multi-Head Attention, Causal-Attention, and Cross-Attention in LLMs