网站建设图片尺寸太原网站建设
文章目录
- FFN
- 1.目的
- 2.代码
- 编码器块
FFN
1.目的
注意力机制捕捉了序列里面不同位置的相关关系,并没有加强非线性表达能力,所以添加FFN用于增强非线性表达能力。
2.代码
class Pos_FFN(nn.Module):
    """Position-wise feed-forward network (FFN) used inside a Transformer block.

    Applies the same two-layer MLP independently at every position:
    ``Linear(d_model -> ffn_dim) -> ReLU -> Linear(ffn_dim -> d_model) -> ReLU``.
    The attention layer only mixes information across positions; this FFN adds
    the per-position non-linear capacity.

    Parameters
    ----------
    d_model : int, optional
        Input/output feature width. Defaults to the module-level ``num_hiddens``
        global, matching the original hard-coded behavior.
    ffn_dim : int, optional
        Hidden width of the inner layer (default 1024, as in the original).
    """

    def __init__(self, d_model=None, ffn_dim=1024) -> None:
        # Do NOT forward ctor args to nn.Module.__init__ — it accepts none.
        super().__init__()
        if d_model is None:
            # Fall back to the module-level constant, as the original code did.
            d_model = num_hiddens
        self.lin_1 = nn.Linear(d_model, ffn_dim, bias=False)
        self.relu1 = nn.ReLU()
        self.lin_2 = nn.Linear(ffn_dim, d_model, bias=False)
        self.relu2 = nn.ReLU()

    def forward(self, X):
        """Return the FFN output, same shape as ``X`` (last dim = d_model)."""
        X = self.lin_1(X)
        X = self.relu1(X)
        X = self.lin_2(X)
        # Original note: this trailing ReLU is optional; kept for parity with
        # the original implementation (output is therefore non-negative).
        X = self.relu2(X)
        return X
编码器块
class Encoder_block(nn.Module):
    """One Transformer encoder block.

    Structure: self-attention -> Add&Norm -> position-wise FFN -> Add&Norm,
    with the usual residual connections supplied by the ``AddNorm`` modules.

    ``Attention_block``, ``AddNorm`` and ``Pos_FFN`` are defined elsewhere in
    this file and read their sizes from module-level globals.
    """

    def __init__(self, *args, **kwargs) -> None:
        # nn.Module.__init__ takes no arguments — forwarding *args/**kwargs to
        # it (as the original did) raises TypeError on current torch if any
        # argument is passed. Signature kept for backward compatibility.
        super().__init__()
        self.attention = Attention_block()
        self.add_norm_1 = AddNorm()
        self.FFN = Pos_FFN()
        self.add_norm_2 = AddNorm()

    def forward(self, X, I_m):
        """Run one encoder block.

        Parameters
        ----------
        X : torch.Tensor
            Input representations, shape ``(batch, seq, d_model)`` —
            presumably; TODO confirm against the caller.
        I_m : torch.Tensor
            Padding/attention mask; ``unsqueeze(-2)`` broadcasts it over the
            query dimension (e.g. ``(batch, seq) -> (batch, 1, seq)``).
        """
        I_m = I_m.unsqueeze(-2)
        # Sub-layer 1: self-attention with residual Add&Norm.
        attn_out = self.attention(X, I_m)
        X = self.add_norm_1(X, attn_out)
        # Sub-layer 2: position-wise FFN with residual Add&Norm.
        ffn_out = self.FFN(X)
        X = self.add_norm_2(X, ffn_out)
        return X