Cross Attention 和微调

讨论了 Cross Attention 机制及其在微调中的应用,提供了具体实现和计算过程的示例。

  ·   1 min read

在 Cross Attention 机制中,Query(Q)、Key(K)和 Value(V)分别来自不同的输入源。具体而言,Q 来自 Decoder,而 K 和 V 通常来自 Encoder。这种设计使得模型能够将解码过程中的查询与编码过程中的上下文信息有效地结合起来。

具体实现

在以下的代码示例中,我们定义了 AttnProcessor2_0 类来处理 Cross Attention:

class AttnProcessor2_0:
    def __init__(self, query_dim, cross_attention_dim, inner_dim, inner_kv_dim, bias=True):
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, self.inner_kv_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, self.inner_kv_dim, bias=bias)

    def __call__(self, hidden_states, encoder_hidden_states=None):
        query = self.to_q(hidden_states)
        
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif self.norm_cross:
            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
        
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

参数说明

  • Q (Query): 维度为 $16 \times 16 \times 256$
  • K (Key): 维度为 $77 \times 256$
  • V (Value): 维度为 $77 \times 256$

计算过程

在计算过程中,注意以下几点:

  1. Attention 计算: $$ \text{Attention} = Q \cdot K^T = 16 \times 16 \times 77 $$

  2. 输出计算: $$ \text{Output} = \text{Attention} \cdot V = 16 \times 16 \times 256 $$

微调策略

在微调模型时,可以根据以下策略进行调整:

  • 微调 K 和 V 矩阵: 适用于调整文本等其他控制嵌入。
  • 微调 Q 矩阵: 适用于微调内容,例如 Tune-A-Video。

这种微调方式使得模型能够灵活地适应不同的上下文和控制条件,从而提升生成的多样性和质量。