Drug Discovery with Polars and PyTorch
A worked example with perming
Download the SELFIES molecular-descriptor dataset, which contains the one-hot encodings and their corresponding labels, then load it with polars:
import polars
import numpy

df = polars.read_csv("descriptor_selfies.csv", separator='\t')  # the file is tab-separated
df.head()
shape: (5, 3)
┌──────────────┬─────────────────────────────────┬─────────────────────────────────┐
│ id           ┆ label                           ┆ one_hot                         │
│ ---          ┆ ---                             ┆ ---                             │
│ str          ┆ str                             ┆ str                             │
╞══════════════╪═════════════════════════════════╪═════════════════════════════════╡
│ CHEMBL179549 ┆ [24, 24, 24, 30, 24, 7, 24, 12… ┆ [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0… │
│ CHEMBL360920 ┆ [24, 24, 24, 30, 24, 7, 24, 12… ┆ [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0… │
│ CHEMBL182052 ┆ [24, 24, 24, 33, 24, 17, 8, 24… ┆ [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0… │
│ CHEMBL179662 ┆ [33, 24, 24, 18, 36, 24, 30, 2… ┆ [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0… │
│ CHEMBL181688 ┆ [24, 24, 17, 24, 24, 17, 24, 2… ┆ [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0… │
└──────────────┴─────────────────────────────────┴─────────────────────────────────┘
Read the feature column X and the label column y into memory:
X, y = df["one_hot"], df["label"]
Both columns are stored as strings, so parse them into numeric numpy.ndarray objects:
y_num = []
for y_str in y:
    y_list = eval(y_str)    # parse the stringified list of token indices
    y_num.append(y_list)
y = numpy.array(y_num)      # multi-output labels
del y_str, y_list, y_num
X_num = []
for X_str in X:
    X_list = eval(X_str)    # parse the stringified nested one-hot lists
    X_flatten = [item for sublist in X_list for item in sublist]
    X_num.append(X_flatten)
X = numpy.array(X_num)      # flattened features
del X_str, X_list, X_flatten, X_num
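A note on the parsing step above: `eval` executes arbitrary Python, so for CSV content you do not fully control, `ast.literal_eval` is a safer drop-in that only accepts literals. A minimal sketch with toy stand-in strings (not the real dataset):

```python
import ast
import numpy

# Toy stand-ins for the stringified lists stored in the CSV columns.
label_str = "[24, 24, 30]"
one_hot_str = "[[0, 1, 0], [1, 0, 0]]"

# literal_eval parses only Python literals; malformed or malicious
# strings raise an exception instead of executing code.
label = ast.literal_eval(label_str)
one_hot = ast.literal_eval(one_hot_str)

# Flatten the nested one-hot rows exactly as in the loop above.
flat = [item for row in one_hot for item in row]

X_row = numpy.array(flat)
y_row = numpy.array(label)
```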
Check the array shapes of X and y:
X.shape, y.shape
((1325, 8862), (1325, 211))
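Note that 8862 factors as 42 × 211, which is consistent with each molecule being a length-211 token sequence over a 42-symbol SELFIES alphabet, one-hot encoded and then flattened (this interpretation is an inference from the shapes, not stated by the dataset). The round trip between index sequence and flattened one-hot can be sketched with toy sizes:

```python
import numpy

# Toy sequence of token indices over a small alphabet (sizes invented).
alphabet_size = 5
seq = numpy.array([2, 0, 4, 4, 1])

# One-hot encode: (seq_len, alphabet_size), then flatten as in the text.
one_hot = numpy.eye(alphabet_size, dtype=int)[seq]
flat = one_hot.reshape(-1)          # shape (seq_len * alphabet_size,)

# Recover the indices by un-flattening and taking argmax per position.
recovered = flat.reshape(-1, alphabet_size).argmax(axis=1)
```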
Import perming and configure the model: input feature dimension, output (label) dimension, number of hidden units, batch size, activation, loss criterion, optimizer, and initial learning rate:
import perming
main = perming.Box(8862, 211, (3000,), batch_size=64, activation='relu', inplace_on=True, criterion='MSELoss', solver='adam', learning_rate_init=0.01)
main.print_config()  # print the architecture and training configuration

MLP(
  (mlp): Sequential(
    (Linear0): Linear(in_features=8862, out_features=3000, bias=True)
    (Activation0): ReLU(inplace=True)
    (Linear1): Linear(in_features=3000, out_features=211, bias=True)
  )
)
OrderedDict([('torch -v', '1.7.1+cu101'),
             ('criterion', MSELoss()),
             ('batch_size', 64),
             ('solver', Adam (
                 Parameter Group 0
                     amsgrad: False
                     betas: (0.9, 0.99)
                     eps: 1e-08
                     lr: 0.01
                     weight_decay: 0
             )),
             ('lr_scheduler', None),
             ('device', device(type='cuda'))])
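For readers without perming installed, the printed architecture is a single-hidden-layer MLP: Linear(8862→3000) → ReLU → Linear(3000→211). A numpy sketch of the forward pass, with random weights and tiny stand-in dimensions purely for illustration:

```python
import numpy

rng = numpy.random.default_rng(0)

# Tiny stand-ins for the real sizes (8862, 3000, 211).
d_in, d_hidden, d_out = 8, 16, 4

W1 = rng.normal(size=(d_in, d_hidden)) * 0.1
b1 = numpy.zeros(d_hidden)
W2 = rng.normal(size=(d_hidden, d_out)) * 0.1
b2 = numpy.zeros(d_out)

def forward(x):
    # Linear0 -> ReLU -> Linear1, matching the printed MLP layout.
    h = numpy.maximum(x @ W1 + b1, 0.0)  # ReLU
    return h @ W2 + b2

batch = rng.normal(size=(3, d_in))
out = forward(batch)
```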
Because criterion='MSELoss', cast y to floating point before loading the data:
y = y.astype("float") # multi-outputs
main.data_loader(X, y, random_seed=0)
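`data_loader` shuffles the arrays with the given seed and splits them into train/validation/test loaders. The idea can be sketched manually with numpy; the split fractions below are illustrative assumptions, not perming's actual defaults:

```python
import numpy

def split(X, y, val_frac=0.1, test_frac=0.15, random_seed=0):
    """Shuffle and split into train/val/test. Fractions are illustrative."""
    rng = numpy.random.default_rng(random_seed)
    idx = rng.permutation(len(X))
    n_test = int(len(X) * test_frac)
    n_val = int(len(X) * val_frac)
    test, val, train = idx[:n_test], idx[n_test:n_test + n_val], idx[n_test + n_val:]
    return (X[train], y[train]), (X[val], y[val]), (X[test], y[test])

# Toy data: 10 samples, 2 features each.
X_toy = numpy.arange(20).reshape(10, 2)
y_toy = numpy.arange(10)
tr, va, te = split(X_toy, y_toy)
```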
Train the model for 60 epochs; the logged output is as follows:
Epoch [1/60], Step [10/17], Training Loss: 683.9639, Validation Loss: 158.5208
Epoch [2/60], Step [10/17], Training Loss: 74.4517, Validation Loss: 16.8641
Epoch [3/60], Step [10/17], Training Loss: 41.3552, Validation Loss: 13.4700
Epoch [4/60], Step [10/17], Training Loss: 39.0044, Validation Loss: 13.0534
Epoch [5/60], Step [10/17], Training Loss: 35.2509, Validation Loss: 11.4001
Epoch [6/60], Step [10/17], Training Loss: 31.8014, Validation Loss: 10.8461
Epoch [7/60], Step [10/17], Training Loss: 29.4561, Validation Loss: 9.9555
Epoch [8/60], Step [10/17], Training Loss: 29.6944, Validation Loss: 8.6690
Epoch [9/60], Step [10/17], Training Loss: 26.8422, Validation Loss: 9.0475
Epoch [10/60], Step [10/17], Training Loss: 23.5534, Validation Loss: 8.4559
Epoch [11/60], Step [10/17], Training Loss: 22.1152, Validation Loss: 7.9635
Epoch [12/60], Step [10/17], Training Loss: 20.2078, Validation Loss: 7.2285
Epoch [13/60], Step [10/17], Training Loss: 18.4414, Validation Loss: 6.6722
Epoch [14/60], Step [10/17], Training Loss: 16.1690, Validation Loss: 6.1117
Epoch [15/60], Step [10/17], Training Loss: 13.4766, Validation Loss: 5.5554
Epoch [16/60], Step [10/17], Training Loss: 11.8005, Validation Loss: 5.1315
Epoch [17/60], Step [10/17], Training Loss: 10.5361, Validation Loss: 4.5014
Epoch [18/60], Step [10/17], Training Loss: 9.9437, Validation Loss: 3.5443
Epoch [19/60], Step [10/17], Training Loss: 10.6306, Validation Loss: 3.9934
Epoch [20/60], Step [10/17], Training Loss: 6.8944, Validation Loss: 3.4801
Epoch [21/60], Step [10/17], Training Loss: 7.0011, Validation Loss: 3.2631
Epoch [22/60], Step [10/17], Training Loss: 5.5973, Validation Loss: 2.2316
Epoch [23/60], Step [10/17], Training Loss: 4.8216, Validation Loss: 2.8058
Epoch [24/60], Step [10/17], Training Loss: 5.5661, Validation Loss: 2.7639
Epoch [25/60], Step [10/17], Training Loss: 3.6590, Validation Loss: 1.7645
Epoch [26/60], Step [10/17], Training Loss: 5.0497, Validation Loss: 2.5980
Epoch [27/60], Step [10/17], Training Loss: 3.6462, Validation Loss: 2.3122
Epoch [28/60], Step [10/17], Training Loss: 2.6290, Validation Loss: 1.4721
Epoch [29/60], Step [10/17], Training Loss: 3.0749, Validation Loss: 1.4911
Epoch [30/60], Step [10/17], Training Loss: 4.1072, Validation Loss: 2.2163
Epoch [31/60], Step [10/17], Training Loss: 3.2650, Validation Loss: 1.3476
Epoch [32/60], Step [10/17], Training Loss: 1.6862, Validation Loss: 2.0011
Epoch [33/60], Step [10/17], Training Loss: 1.6189, Validation Loss: 1.9502
Epoch [34/60], Step [10/17], Training Loss: 1.2636, Validation Loss: 1.8931
Epoch [35/60], Step [10/17], Training Loss: 0.8869, Validation Loss: 1.6373
Epoch [36/60], Step [10/17], Training Loss: 0.9195, Validation Loss: 1.7993
Epoch [37/60], Step [10/17], Training Loss: 1.3679, Validation Loss: 1.0173
Epoch [38/60], Step [10/17], Training Loss: 1.6105, Validation Loss: 1.1197
Epoch [39/60], Step [10/17], Training Loss: 1.1611, Validation Loss: 1.1692
Epoch [40/60], Step [10/17], Training Loss: 1.5315, Validation Loss: 1.2795
Epoch [41/60], Step [10/17], Training Loss: 1.5300, Validation Loss: 1.8402
Epoch [42/60], Step [10/17], Training Loss: 0.7957, Validation Loss: 1.1387
Epoch [43/60], Step [10/17], Training Loss: 0.6859, Validation Loss: 1.8664
Epoch [44/60], Step [10/17], Training Loss: 0.7019, Validation Loss: 1.7125
Epoch [45/60], Step [10/17], Training Loss: 0.6334, Validation Loss: 1.4241
Epoch [46/60], Step [10/17], Training Loss: 0.6721, Validation Loss: 1.6616
Epoch [47/60], Step [10/17], Training Loss: 0.5243, Validation Loss: 0.9518
Epoch [48/60], Step [10/17], Training Loss: 1.3168, Validation Loss: 1.6350
Epoch [49/60], Step [10/17], Training Loss: 3.0425, Validation Loss: 1.6377
Epoch [50/60], Step [10/17], Training Loss: 0.7838, Validation Loss: 1.0091
Epoch [51/60], Step [10/17], Training Loss: 0.8132, Validation Loss: 0.9184
Epoch [52/60], Step [10/17], Training Loss: 0.8691, Validation Loss: 1.3815
Epoch [53/60], Step [10/17], Training Loss: 0.5132, Validation Loss: 1.3461
Epoch [54/60], Step [10/17], Training Loss: 0.7782, Validation Loss: 1.6065
Epoch [55/60], Step [10/17], Training Loss: 0.4455, Validation Loss: 1.3275
Epoch [56/60], Step [10/17], Training Loss: 0.4325, Validation Loss: 0.9537
Epoch [57/60], Step [10/17], Training Loss: 0.7526, Validation Loss: 1.2994
Epoch [58/60], Step [10/17], Training Loss: 0.9250, Validation Loss: 0.9559
Epoch [59/60], Step [10/17], Training Loss: 0.7086, Validation Loss: 0.8987
Epoch [60/60], Step [10/17], Training Loss: 1.2652, Validation Loss: 1.5949
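The validation loss reaches its minimum (0.8987, epoch 59) before the final epoch, so the last checkpoint is not necessarily the best one. A small helper to find the best epoch from log lines in the format above (shortened toy log; values for epochs 2–3 are made up for the example):

```python
import re

log = """Epoch [1/3], Step [10/17], Training Loss: 683.9639, Validation Loss: 158.5208
Epoch [2/3], Step [10/17], Training Loss: 74.4517, Validation Loss: 16.8641
Epoch [3/3], Step [10/17], Training Loss: 41.3552, Validation Loss: 18.4700"""

# Capture the epoch number and its validation loss from each line.
pattern = re.compile(r"Epoch \[(\d+)/\d+\].*Validation Loss: ([\d.]+)")
val_losses = {int(m.group(1)): float(m.group(2)) for m in pattern.finditer(log)}

# Epoch with the lowest validation loss.
best_epoch = min(val_losses, key=val_losses.get)
```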
The model's loss on the held-out test dataset:
loss of Box on the 192 test dataset: 5.462291717529297.
OrderedDict([('problem', 'multi-outputs'),
             ('loss', {'train': 1.0389052629470825,
                       'val': 1.822463035583496,
                       'test': 5.462291717529297})])
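The reported numbers are mean squared errors. As a reference point, the numpy equivalent of `MSELoss` with its default 'mean' reduction, which averages the squared error over every output of every sample:

```python
import numpy

def mse(pred, target):
    """Mean squared error, matching torch.nn.MSELoss(reduction='mean')."""
    pred = numpy.asarray(pred, dtype=float)
    target = numpy.asarray(target, dtype=float)
    return ((pred - target) ** 2).mean()

# Toy 2x2 prediction/target pair: one entry off by 2, so MSE = 4 / 4 = 1.
loss = mse([[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 2.0]])
```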
Save the trained model checkpoint:
main.save(con=False, dir='./models/tp_selfies.ckpt')