[Deep Learning] Attention Summary

Posted by ShawnD on May 7, 2021

Nonparametric Attention
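For reference, what this section computes is classic (nonparametric) Nadaraya-Watson kernel regression with a Gaussian kernel, written in softmax form:

$$f(x) = \sum_{i=1}^{n} \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i$$

where $(x_i, y_i)$ are the training pairs and $x$ is a query drawn from x_test.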

```python
import torch

n_train, n_test = 50, 50

# f is the target function being regressed
# (e.g. f(x) = 2 * torch.sin(x) + x**0.8 in the d2l version of this example).
x_train, _ = torch.sort(torch.rand(n_train) * 5)            # sorted training inputs in [0, 5)
x_test = torch.arange(0, 5, 0.1)                            # 50 evenly spaced test inputs
y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))   # noisy training targets
y_truth = f(x_test)                                         # noise-free ground truth

x_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
```

Each sample in x_test is broadcast row-wise to the shape of x_train and used as a query.

Q: x_repeat: (n_test, n_train)

K: x_train: (n_train)

V: y_train: (n_train)

Attention Weight: (n_test, n_train)
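A minimal sketch of the pooling these shapes describe, reusing x_repeat, x_train, and y_train from above:

```python
# Gaussian-kernel attention: each test query weights every training label.
attention_weights = torch.softmax(-(x_repeat - x_train)**2 / 2, dim=1)  # (n_test, n_train)
y_hat = torch.matmul(attention_weights, y_train)                        # (n_test,)
```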

Parametric Attention

Nadaraya-Watson Kernel Regression
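The parametric variant simply inserts a learnable scalar $w$ into the Gaussian kernel, which is exactly what the NWKernelRegression module below implements:

$$f(x) = \sum_{i=1}^{n} \mathrm{softmax}\left(-\frac{1}{2}\bigl((x - x_i)\,w\bigr)^2\right) y_i$$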

The same synthetic training and test data as in the nonparametric section are reused here (n_train = n_test = 50).

```python
import torch
from torch import nn

class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # Shape of the output `queries` and `attention_weights`:
        # (no. of queries, no. of key-value pairs)
        # dim 0 indexes the queries (batch), dim 1 the key-value pairs.
        queries = queries.repeat_interleave(keys.shape[1]).reshape(
            (-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # Shape of `values`: (no. of queries, no. of key-value pairs)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)
```

```python
X_tile = x_train.repeat((n_train, 1))   # (n_train, n_train): every row is a copy of x_train
Y_tile = y_train.repeat((n_train, 1))   # (n_train, n_train): every row is a copy of y_train
```

The training inputs are tiled row-wise; the identity mask below then drops each sample's pairing with itself, so every training query attends to the other n_train - 1 key-value pairs.

```python
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
```

Training

Q:x_train -> queries : (n_train) -> (n_train, n_train - 1)

K:keys: (n_train, n_train - 1)

V:values: (n_train, n_train - 1)

Attention Weight: (n_train, n_train - 1)
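A minimal training sketch under these shapes (the squared-error loss, SGD learning rate, and epoch count are illustrative choices, not prescribed by the shape analysis above):

```python
net = NWKernelRegression()
loss_fn = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)

for epoch in range(5):
    trainer.zero_grad()
    # queries: x_train (n_train,); keys/values: (n_train, n_train - 1)
    l = loss_fn(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
```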

Testing:

```python
keys = x_train.repeat((n_test, 1))     # (n_test, n_train)
values = y_train.repeat((n_test, 1))   # (n_test, n_train)
```

Q:x_test -> queries : (n_test) -> (n_test, n_train)

K:keys: (n_test, n_train)

V:values: (n_test, n_train)

Attention Weight: (n_test, n_train)
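Prediction is then a single forward pass, using the `net`, `keys`, and `values` defined above:

```python
# One query per test point, each attending to all n_train training pairs.
y_hat = net(x_test, keys, values)     # (n_test,)
weights = net.attention_weights       # (n_test, n_train)
```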

Additive Attention

```python
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
```

Q: queries: (2, 1, 20) -> (2, 1, num_hiddens) -> (2, 1, 1, num_hiddens)

K: keys: (2, 10, 2) -> (2, 10, num_hiddens) -> (2, 1, 10, num_hiddens)

V: values: (2, 10, 4)

Attention Weight: features (2, 1, 10, num_hiddens = 8) -> scores (2, 1, 10, 1) -> weights (2, 1, 10)
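A minimal additive-attention sketch matching this trace, with num_hiddens = 8; dropout and valid_lens masking are omitted for brevity, so this is a simplified stand-in rather than the full d2l AdditiveAttention class:

```python
import torch
from torch import nn

class SimpleAdditiveAttention(nn.Module):
    """Additive attention: score(q, k) = w_v^T tanh(W_q q + W_k k)."""
    def __init__(self, key_size, query_size, num_hiddens):
        super().__init__()
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)

    def forward(self, queries, keys, values):
        q, k = self.W_q(queries), self.W_k(keys)
        # Broadcast to (batch, no. of queries, no. of kv pairs, num_hiddens).
        features = torch.tanh(q.unsqueeze(2) + k.unsqueeze(1))
        scores = self.w_v(features).squeeze(-1)          # (batch, queries, kv pairs)
        attention_weights = torch.softmax(scores, dim=-1)
        return torch.bmm(attention_weights, values)      # (batch, queries, value dim)

attention = SimpleAdditiveAttention(key_size=2, query_size=20, num_hiddens=8)
out = attention(queries, keys, values)                   # (2, 1, 4)
```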

Scaled Dot-Product Attention

```python
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
```

Q: queries: (2, 1, 2)

K: keys: (2, 10, 2) -> (2, 2, 10)

V: values: (2, 10, 4)

Attention Weight: (2, 1, 10)
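The computation is short enough to write out directly (dropout and valid_lens masking again omitted; d is the query/key feature dimension, here 2):

```python
import math

# Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V
d = queries.shape[-1]
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)  # (2, 1, 10)
attention_weights = torch.softmax(scores, dim=-1)                 # (2, 1, 10)
out = torch.bmm(attention_weights, values)                        # (2, 1, 4)
```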

Multi-Head Attention

Each head of multi-head attention uses scaled dot-product attention. Let $p_q = p_k = p_v = p_o/h$; with $p_o = 100$ and $h = 5$ heads below, each head therefore works with 20-dimensional queries, keys, and values (the head-splitting reshape is sketched after the shape trace).

```python
num_hiddens, num_heads = 100, 5
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
```

Q:X:(2, 4, 100) -> (2, 4, 100) -> (2*5, 4, 20)

K:Y:(2, 6, 100) -> (2, 6, 100) -> (2*5, 6, 20)

V:Y:(2, 6, 100) -> (2, 6, 100) -> (2*5, 6, 20)

Attention Weight: (2*5, 4, 6)
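The (2, 4, 100) -> (2*5, 4, 20) step is just a reshape-and-transpose that folds the head dimension into the batch dimension, mirroring d2l's transpose_qkv helper. A sketch of that transformation:

```python
def split_heads(X, num_heads):
    """(batch, seq, num_hiddens) -> (batch * num_heads, seq, num_hiddens // num_heads)."""
    batch, seq, num_hiddens = X.shape
    X = X.reshape(batch, seq, num_heads, num_hiddens // num_heads)
    X = X.permute(0, 2, 1, 3)                       # (batch, heads, seq, head_dim)
    return X.reshape(batch * num_heads, seq, -1)    # (batch*heads, seq, head_dim)

split_heads(X, num_heads).shape   # torch.Size([10, 4, 20])
```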

Self-Attention

```python
num_hiddens, num_heads = 100, 5
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
```
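In self-attention the same sequence supplies the queries, keys, and values, so the multi-head module from the previous section is simply called with X three times. A sketch, where `mha` is a hypothetical name for an already-instantiated multi-head attention module whose forward takes (queries, keys, values, valid_lens):

```python
# Q = K = V = X: every position attends to every position in the same sequence.
output = mha(X, X, X, valid_lens)   # (batch_size, num_queries, num_hiddens) == (2, 4, 100)
```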