昨天休息
起晚了下午出去和朋友聚餐晚上也在外面玩了,江西菜太辣了,难受一晚上,故回家直接睡了,一点没学。
今天做了啥
lc:199、114、17
默写复习了一下MHA,把MQA GQA也看了一下,实际上大差不差。 只不过MQA Q仍然是独立的(每个头一个) 、V共享(只有一个头) self.W_k = nn.Linear(d_model, self.d_k) self.W_v = nn.Linear(d_model, self.d_k) 然后forward里面操作前增加一个维度 K = self.W_k(key).unsqueeze(1) V = self.W_v(value).unsqueeze(1) 不用管这个维度大小了,注意力计算 - K,V会自动广播到num_heads维度
GQA则是折中的, K、V分成num_groups组,每组共享 num_groups = 1时就是MQA,num_groups = num_heads时就是MHA
K、V按组共享
self.W_k = nn.Linear(d_model, self.num_groups * self.d_k) self.W_v = nn.Linear(d_model, self.num_groups * self.d_k) 在它的forward里面 kv 是(batch_size, seq_len_k, self.num_groups, self.d_k)
q的numhead维度改写为(batch_size, self.num_groups, self.heads_per_group, seq_len_q, self.d_k) 然后K = K.unsqueeze(2) # [batch, num_groups, 1, seq_len_k, d_k]
自己手写了一下GQA的代码,感觉已经理解了,可以随手写出来了。
另外有个蠢问题问了下ai为什么out = torch.matmul(attn, V),要分好几行来变形它,不能直接 out = out.contiguous().view(batch_size, seq_len_q, self.d_model)吗?
如果不transpose直接view,数据会按照错误的维度顺序排列,view 在没把 head 维移到正确位置时,会把每个 token 的特征向量按错误方式拼接,导致 token 对应的表示不对。不同 head(以及 GQA 的不同 group)产生的 Dh 通道会被错拼到一起。
明天再看看MLA的优化方式,和这些有什么不同。理解MLA(Multi-head Latent Attention)原理