Seq2Seq - CrossEntropyLoss细节讨论
在 PyTorch 中,损失函数 CrossEntropyLoss
的输入参数通常需要满足特定的形状要求。对于 CrossEntropyLoss
,输入参数的形状要求如下:
-
input
:模型的输出,形状为[N, C]
,其中:-
N
是样本数量(或展平后的序列长度)。 -
C
是类别数量(目标词汇表的大小)。
-
-
target
:目标标签,形状为[N]
,其中每个元素是一个类别索引(整数)。
在上一节的代码中:
loss = loss_fn(similarities.view(-1, len(cn_vocab)), batch_labels.view(-1))
similarities
和 batch_labels
需要被调整为上述形状,以便符合 CrossEntropyLoss
的输入要求。以下详细解释为什么要这样写:
1. similarities.view(-1, len(cn_vocab))
-
similarities
的原始形状:[batch_size, seq_len, output_dim]
,其中:-
batch_size
是批次大小。 -
seq_len
是序列长度。 -
output_dim
是目标词汇表的大小(len(cn_vocab)
)。
-
-
目标形状:
[N, C]
,其中:-
N
是展平后的序列长度,即batch_size * seq_len
。 -
C
是目标词汇表的大小,即len(cn_vocab)
。
-
-
操作:
-
使用
view(-1, len(cn_vocab))
将similarities
展平为二维张量,形状为[batch_size * seq_len, len(cn_vocab)]
。 -
这样,每个时间步的输出都被展平为一个二维张量,每一行表示一个时间步的预测概率分布。
-
2. batch_labels.view(-1)
-
batch_labels
的原始形状:[batch_size, seq_len]
,其中:-
batch_size
是批次大小。 -
seq_len
是序列长度。
-
-
目标形状:
[N]
,其中:-
N
是展平后的序列长度,即batch_size * seq_len
。
-
-
操作:
-
使用
view(-1)
将batch_labels
展平为一维张量,形状为[batch_size * seq_len]
。 -
这样,每个时间步的目标标签都被展平为一个一维张量,每个元素是一个类别索引。
-
3. 为什么这样写
-
符合
CrossEntropyLoss
的输入要求:-
CrossEntropyLoss
要求输入的预测概率分布是一个二维张量[N, C]
,其中每一行表示一个样本的预测概率分布。 -
目标标签是一个一维张量
[N]
,其中每个元素是一个类别索引。
-
-
处理序列数据:
-
在序列到序列的任务中,每个时间步都有一个预测和一个目标标签。
-
通过展平操作,可以将所有时间步的预测和目标标签合并为一个批次,从而一次性计算整个批次的损失。
-
示例
假设:
-
batch_size = 2
-
seq_len = 3
-
output_dim = 5
(目标词汇表大小)
原始数据:
similarities: [2, 3, 5] # [batch_size, seq_len, output_dim]
batch_labels: [2, 3] # [batch_size, seq_len]
经过 view
操作后:
similarities.view(-1, 5): [6, 5] # [batch_size * seq_len, output_dim]
batch_labels.view(-1): [6] # [batch_size * seq_len]
这样,similarities
的每一行表示一个时间步的预测概率分布,batch_labels
的每个元素是一个类别索引,完全符合 CrossEntropyLoss
的输入要求。