16  Multi-Head Attention e Projeções Lineares

Expansão do mecanismo de atenção para múltiplas cabeças (heads), permitindo que o modelo foque em diferentes partes da sequência simultaneamente.

17 Multi-Head Attention e Projeções Lineares

17.1 1. Introdução

No capítulo anterior, exploramos o mecanismo de Self-Attention (Autoatenção), que permite ao modelo ponderar a importância de diferentes palavras em uma sequência. No entanto, uma única operação de atenção limita a capacidade do modelo de focar em diferentes tipos de relacionamentos simultaneamente (por exemplo, concordância gramatical versus contexto semântico).

O Multi-Head Attention (MHA) resolve essa limitação executando múltiplos mecanismos de autoatenção em paralelo. Cada “cabeça” (head) aprende a focar em diferentes subespaços de representação, permitindo que o modelo capture nuances complexas e variadas da linguagem ao mesmo tempo.

17.2 2. O Conceito de Projeções Lineares

Antes de dividir o processamento em múltiplas cabeças, os vetores de entrada (Embeddings) passam por Projeções Lineares.

Matematicamente, uma projeção linear é uma multiplicação de matrizes que transforma o vetor de entrada original em um novo espaço vetorial. No contexto do Transformer, para cada cabeça \(i\), treinamos três matrizes de pesos independentes:

  1. \(W_i^Q\) (Pesos para Query)
  2. \(W_i^K\) (Pesos para Key)
  3. \(W_i^V\) (Pesos para Value)

Se a dimensão do modelo é \(d_{model}\) e queremos \(h\) cabeças, a dimensão de cada cabeça será tipicamente \(d_k = d_{model} / h\). As projeções reduzem a dimensionalidade para cada cabeça, mantendo o custo computacional total similar ao de uma única cabeça com dimensão total.

17.3 3. Arquitetura do Multi-Head Attention

O processo ocorre em quatro etapas principais:

  1. Projeção: Os vetores de entrada (Queries, Keys, Values) são projetados linearmente \(h\) vezes com matrizes de pesos diferentes e aprendidas.
  2. Scaled Dot-Product Attention: O mecanismo de atenção é aplicado em paralelo a cada uma dessas projeções.
  3. Concatenação: As saídas de todas as cabeças são concatenadas.
  4. Projeção Final: A concatenação passa por uma última camada linear (\(W^O\)) para restaurar a dimensão original e misturar as informações aprendidas pelas diferentes cabeças.

17.3.1 Formulação Matemática

Dado um conjunto de matrizes de entrada \(Q, K, V\):

\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O \]

Onde cada cabeça é calculada como:

\[ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \]

E a função de Atenção é a padrão:

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

17.4 4. Diagrama de Fluxo de Dados

O diagrama abaixo ilustra como o fluxo de dados é dividido, processado e reunificado.

graph TD
    subgraph Inputs
    X[Input Embeddings]
    end

    subgraph "Linear Projections (Split)"
    LP_Q[Linear Proj Q]
    LP_K[Linear Proj K]
    LP_V[Linear Proj V]
    end

    subgraph "Heads (Parallel Processing)"
    H1[Head 1: Scaled Dot-Product]
    H2[Head 2: Scaled Dot-Product]
    H_dots[...]
    Hn[Head h: Scaled Dot-Product]
    end

    subgraph "Aggregation"
    Concat[Concatenation]
    LinearOut[Linear Proj Output (Wo)]
    end

    X --> LP_Q
    X --> LP_K
    X --> LP_V

    LP_Q -- Split --> H1 & H2 & H_dots & Hn
    LP_K -- Split --> H1 & H2 & H_dots & Hn
    LP_V -- Split --> H1 & H2 & H_dots & Hn

    H1 --> Concat
    H2 --> Concat
    H_dots --> Concat
    Hn --> Concat

    Concat --> LinearOut
    LinearOut --> Output[Final Context Vectors]

    style Inputs fill:#f9f,stroke:#333,stroke-width:2px
    style H1 fill:#bbf,stroke:#333
    style H2 fill:#bbf,stroke:#333
    style Hn fill:#bbf,stroke:#333
    style LinearOut fill:#dfd,stroke:#333

17.5 5. Implementação Técnica (PyTorch)

Na prática, não criamos \(h\) camadas lineares separadas por ineficiência. Em vez disso, projetamos para a dimensão total (\(d_{model}\)) e usamos operações de reshape e transpose para separar as cabeças logicamente.

Abaixo, uma implementação robusta e anotada:

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model deve ser divisível por num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Projeções Lineares (W_q, W_k, W_v)
        # Note que fazemos uma única matriz grande para eficiência
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        
        # Projeção Final (W_o)
        self.w_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V shape: [batch_size, num_heads, seq_len, d_k]
        
        # 1. Matmul Q e K transposto
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # 2. Aplicar máscara (opcional, usado no Decoder)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 3. Softmax
        attn_weights = torch.softmax(scores, dim=-1)
        
        # 4. Multiplicar pelos Values
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # 1. Projeções Lineares e Reshape para separar cabeças
        # Transformação: [batch, seq_len, d_model] -> [batch, seq_len, num_heads, d_k]
        # Transpose: -> [batch, num_heads, seq_len, d_k] para facilitar matmul
        Q = self.w_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 2. Aplicar Atenção em todas as cabeças simultaneamente
        attn_output, _ = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # 3. Concatenação
        # Transpose reverso: [batch, num_heads, seq_len, d_k] -> [batch, seq_len, num_heads, d_k]
        # Contiguous + View: -> [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # 4. Projeção Linear Final
        output = self.w_o(attn_output)
        
        return output

17.6 6. Por que Múltiplas Cabeças? (Intuição)

Imagine a frase: “O banco negou o empréstimo porque ele estava sem fundos.”

Para entender a palavra “ele”, o modelo precisa relacioná-la a outras palavras. * Cabeça 1 (Sintática): Pode focar na estrutura gramatical, ligando “ele” ao sujeito da oração anterior. * Cabeça 2 (Semântica): Pode focar no contexto financeiro (“banco”, “empréstimo”, “fundos”) para desambiguar se “ele” se refere ao banco ou ao solicitante (neste caso, o contexto sugere o banco).

Se tivéssemos apenas uma cabeça de atenção, o modelo teria que fazer uma média ponderada de todos esses relacionamentos, potencialmente diluindo informações cruciais. Com o Multi-Head Attention, cada cabeça se especializa em um aspecto diferente da linguagem, resultando em uma representação muito mais rica e robusta.

17.7 7. Resumo

  • Multi-Head Attention permite o processamento paralelo de diferentes subespaços de representação.
  • Projeções Lineares (\(W^Q, W^K, W^V\)) são usadas para transformar a entrada em dimensões específicas para cada cabeça.
  • A saída é uma combinação linear (concatenação + \(W^O\)) das visões de todas as cabeças, fornecendo ao restante da rede uma visão contextual completa da sequência.