【Closure-Hayd】
RNA序列本身存在结构上的物理信息,因此可以利用文献提供的相关方法来对RNA序列的物理特征进行更加细致的提取。
-
几何向量编码(GVP模块)借鉴Rhodesign模型中的GVP(Geometric Vector Perceptron)模块,将每个核苷酸的原子坐标分解为标量特征(如原子间距离、二面角)和矢量特征(如C4'-C4'链方向向量)。例如:
-
标量特征:计算磷酸骨架(P-O5'-C5'-C4'-C3'-O3')的二面角、键长等几何参数
-
矢量特征:提取相邻核苷酸C4'原子的空间向量,编码局部骨架方向。
-
侧链特征:对N1/N9原子与骨架的几何关系进行编码,区分嘧啶和嘌呤碱基。
-
-
缺失值处理对NaN填充的原子坐标,采用掩码机制(masked attention)或插值补全(基于已知原子的空间分布预测缺失坐标),避免噪声干扰。
Rhodesign模型的github链接:https://github.com/ml4bio/RhoDesign
模型的文章链接:https://www.nature.com/articles/s43588-024-00720-6
RDesign/model/module.py at master · A4Bio/RDesign
git clone https://github.com/A4Bio/RDesign.giteval "$(/mnt/workspace/miniconda3/bin/conda shell.bash hook)"cd RDesign
conda env create -f environment.yml
conda activate RDesign
class TransformerLayer(nn.Module):def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.0):super(TransformerLayer, self).__init__()self.num_heads = num_headsself.num_hidden = num_hiddenself.num_in = num_inself.dropout = nn.Dropout(dropout)self.norm = nn.ModuleList([nn.BatchNorm1d(num_hidden) for _ in range(2)])self.attention = NeighborAttention(num_hidden, num_hidden + num_in, num_heads)self.dense = nn.Sequential(nn.Linear(num_hidden, num_hidden*4),nn.ReLU(),nn.Linear(num_hidden*4, num_hidden))def forward(self, h_V, h_E, edge_idx, batch_id=None):center_id = edge_idx[0]dh = self.attention(h_V, h_E, center_id, batch_id)h_V = self.norm[0](h_V + self.dropout(dh))dh = self.dense(h_V)h_V = self.norm[1](h_V + self.dropout(dh))return h_V
class NeighborAttention(nn.Module):def __init__(self, num_hidden, num_in, num_heads=4):super(NeighborAttention, self).__init__()self.num_heads = num_headsself.num_hidden = num_hiddenself.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)self.W_K = nn.Linear(num_in, num_hidden, bias=False)self.W_V = nn.Linear(num_in, num_hidden, bias=False)self.Bias = nn.Sequential(nn.Linear(num_hidden*3, num_hidden),nn.ReLU(),nn.Linear(num_hidden,num_hidden),nn.ReLU(),nn.Linear(num_hidden,num_heads))self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)def forward(self, h_V, h_E, center_id, batch_id):N = h_V.shape[0]E = h_E.shape[0]n_heads = self.num_headsd = int(self.num_hidden / n_heads)Q = self.W_Q(h_V).view(N, n_heads, 1, d)[center_id]K = self.W_K(h_E).view(E, n_heads, d, 1)attend_logits = torch.matmul(Q, K).view(E, n_heads, 1)attend_logits = attend_logits / np.sqrt(d)V = self.W_V(h_E).view(-1, n_heads, d) attend = scatter_softmax(attend_logits, index=center_id, dim=0)h_V = scatter_sum(attend*V, center_id, dim=0).view([N, self.num_hidden])h_V_update = self.W_O(h_V)return h_V_update