[:, :, 1]和[:, :, 0] 的区别; `prompt_vector` 和 `embedding_matrix`的作用
prompt_vector = torch.sum(prompt_embedding * attention_weights.unsqueeze(-1), dim=1) # [1, hidden_dim]
prompt_vector = torch.sum(prompt_embedding * attention_weights.unsqueeze(-1), dim=1)
主要作用是通过将 prompt_embedding
与 attention_weights
相乘后再按指定维度求和,得到一个新的张量 prompt_vector
。
代码解释
prompt_embedding
:这是一个包含提示词嵌入向量的张量,通常形状为[batch_size, seq_len, hidden_dim]
,表示批次大小、序列长度和隐藏层维度。attention_weights
:这是一个注意力权重张量,形状通常为[batch_size, seq_len]
,表示每个位置的注意力权重。