转码 Day 19:

toc

昨天休息#

起晚了下午出去和朋友聚餐晚上也在外面玩了,江西菜太辣了,难受一晚上,故回家直接睡了,一点没学。

今天做了啥#

lc:199、114、17

默写复习了一下MHA,把MQA GQA也看了一下,实际上大差不差。 只不过MQA Q仍然是独立的(每个头一个) 、V共享(只有一个头)         self.W_k = nn.Linear(d_model, self.d_k)         self.W_v = nn.Linear(d_model, self.d_k) 然后forward里面操作前增加一个维度        K = self.W_k(key).unsqueeze(1)         V = self.W_v(value).unsqueeze(1) 不用管这个维度大小了,注意力计算 - K,V会自动广播到num_heads维度

GQA则是折中的,    K、V分成num_groups组,每组共享 num_groups = 1时就是MQA,num_groups = num_heads时就是MHA

K、V按组共享#

self.W_k = nn.Linear(d_model, self.num_groups * self.d_k) self.W_v = nn.Linear(d_model, self.num_groups * self.d_k) 在它的forward里面 kv 是(batch_size, seq_len_k, self.num_groups, self.d_k)

q的numhead维度改写为(batch_size, self.num_groups, self.heads_per_group, seq_len_q, self.d_k) 然后K = K.unsqueeze(2)  # [batch, num_groups, 1, seq_len_k, d_k]

自己手写了一下GQA的代码,感觉已经理解了,可以随手写出来了。 另外有个蠢问题问了下ai为什么out = torch.matmul(attn, V),要分好几行来变形它,不能直接 out = out.contiguous().view(batch_size, seq_len_q, self.d_model)吗? 如果不transpose直接view,数据会按照错误的维度顺序排列,view 在没把 head 维移到正确位置时,会把每个 token 的特征向量按错误方式拼接,导致 token 对应的表示不对。不同 head(以及 GQA 的不同 group)产生的 Dh 通道会被错拼到一起。

明天再看看MLA的优化方式,和这些有什么不同。理解MLA(Multi-head Latent Attention)原理