transformer(4): FFN and the Encoder Block
Table of Contents
- FFN
- 1. Purpose
- 2. Code
- Encoder Block
FFN
1. Purpose
The attention mechanism captures the dependencies between different positions in a sequence, but it does not add non-linear expressive power on its own, so an FFN (position-wise feed-forward network) is added to strengthen the model's non-linearity.
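For reference, the position-wise FFN in the original Transformer paper applies the same two-layer network independently at every position:

FFN(x) = max(0, x·W_1 + b_1)·W_2 + b_2

where the inner layer is wider than the model dimension (2048 vs. 512 in the paper). The code below follows the same shape but uses bias-free linear layers and an inner width of 1024.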
2. Code
import torch
from torch import nn


class Pos_FFN(nn.Module):
    """Position-wise feed-forward network: two linear layers with a ReLU in between."""
    def __init__(self, num_hiddens, ffn_hiddens=1024):
        super().__init__()
        self.lin_1 = nn.Linear(num_hiddens, ffn_hiddens, bias=False)
        self.relu1 = nn.ReLU()
        self.lin_2 = nn.Linear(ffn_hiddens, num_hiddens, bias=False)
        self.relu2 = nn.ReLU()

    def forward(self, X):
        X = self.lin_1(X)   # (batch, seq_len, num_hiddens) -> (batch, seq_len, ffn_hiddens)
        X = self.relu1(X)
        X = self.lin_2(X)   # project back to num_hiddens
        X = self.relu2(X)   # optional; the original Transformer has no activation after the second linear layer
        return X
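A quick shape check of the FFN (a minimal sketch; num_hiddens = 512 is an assumed model dimension, not a value fixed by the article):

num_hiddens = 512                       # assumed model dimension for this sketch
ffn = Pos_FFN(num_hiddens)
X = torch.randn(2, 10, num_hiddens)     # (batch, seq_len, num_hiddens)
print(ffn(X).shape)                     # torch.Size([2, 10, 512]) -- the FFN preserves the model dimension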
Encoder Block
class Encoder_block(nn.Module):
    """One Transformer encoder block: self-attention -> Add&Norm -> FFN -> Add&Norm."""
    def __init__(self, num_hiddens):
        super().__init__()
        # Attention_block and AddNorm are defined in the earlier parts of this series
        self.attention = Attention_block()
        self.add_norm_1 = AddNorm()
        self.FFN = Pos_FFN(num_hiddens)
        self.add_norm_2 = AddNorm()

    def forward(self, X, I_m):
        # I_m: mask over key positions (e.g. a padding mask of shape (batch, seq_len));
        # unsqueeze so it broadcasts across the query dimension
        I_m = I_m.unsqueeze(-2)
        X_1 = self.attention(X, I_m)   # self-attention sub-layer
        X = self.add_norm_1(X, X_1)    # residual connection + layer norm
        X_1 = self.FFN(X)              # position-wise feed-forward sub-layer
        X = self.add_norm_2(X, X_1)    # residual connection + layer norm
        return X
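To run the block end-to-end, here is a hedged usage sketch. Attention_block and AddNorm come from the earlier parts of this series; the versions below are simplified stand-ins (single-head scaled dot-product self-attention, and residual connection + LayerNorm) written only so this example is self-contained, and the mask convention (1 = real token, 0 = padding) is an assumption, not taken from the article.

import math
import torch
from torch import nn

class Attention_block(nn.Module):
    # stand-in: single-head self-attention with a key-position mask
    def __init__(self, num_hiddens=512):
        super().__init__()
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=False)

    def forward(self, X, mask):
        Q, K, V = self.W_q(X), self.W_k(X), self.W_v(X)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
        scores = scores.masked_fill(mask == 0, float('-inf'))   # hide padded key positions
        return torch.softmax(scores, dim=-1) @ V

class AddNorm(nn.Module):
    # stand-in: residual connection followed by layer normalization
    def __init__(self, num_hiddens=512):
        super().__init__()
        self.ln = nn.LayerNorm(num_hiddens)

    def forward(self, X, Y):
        return self.ln(X + Y)

num_hiddens = 512                      # must match the stand-ins' default dimension
X = torch.randn(2, 10, num_hiddens)    # (batch, seq_len, num_hiddens)
I_m = torch.ones(2, 10)                # all positions are real tokens in this sketch
block = Encoder_block(num_hiddens)
print(block(X, I_m).shape)             # torch.Size([2, 10, 512]) -- shape is preserved through the block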