生存分析任务建模以及损失函数
-
Loss参考自项目:https://github.com/JJ-ZHOU-Code/RobustMultiModel/blob/00d8c10a4d3a14ef0c02584b991b139678e6845d/utils/utils.py
-
可以直接看这一行,理解模型输出:https://github.com/JJ-ZHOU-Code/RobustMultiModel/blob/00d8c10a4d3a14ef0c02584b991b139678e6845d/models/model_survpath.py#L224
建模输出
这个代码可以看到模型的输出是:harzards
这个harzards就是一个时序分类任务,如果我是5年生存分析,那么就可以建模成10个时序点,包括:[0,0.5year) , [0.5, 1year) … [4.5year, 5year) 这一共10个时序点。如果需要建模到更细,那么分类的点就会越多。一般是一年一个点或者半年一个点比较合理。这个项目使用的就是半年一个点。
注意两个点:
- 时间点的区间是左闭右开的。
-
5years是没有的,因为当前面10个点sigmoid之后的概率都很小,那么就说明患者生存概率很大,归类为 >5years这个点。
因此这个模型的最后是一个 classifier = Linear(embed_dim, 11)
,生成的harzards就是第i个点发生事件的概率(如果是死亡,那么就是说第i个时间段发生死亡的概率最大)。
基于harzards我们可以得到生存率
hazards = torch.sigmoid(logits)
Surv = torch.cumprod(1 - hazards, dim=1)
Loss函数
基于CrossEntropy多分类建模
通过多分类的方式来建模患者会在哪一个时间点发生事件。
可以看到,用于监督信号的label是Y:shape=(B, k)
的,是离散的时序点而不是连续的时间。
因此在使用这个loss的时候,需要将监督信号整理为多分类问题(多个类别选一个)。
好像没见到什么人用,建议使用NLLloss。
def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):batch_size = len(Y)Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,kc = c.view(batch_size, 1).float() #censorship status, 0 or 1if S is None:S = torch.cumprod(1 - hazards, dim=1) # surival is cumulative product of 1 - hazards# without padding, S(0) = S[0], h(0) = h[0]# after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0]#h[y] = h(1)#S[1] = S(1)S_padded = torch.cat([torch.ones_like(c), S], 1)reg = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y)+eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))ce_l = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (1 - c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))loss = (1-alpha) * ce_l + alpha * regloss = loss.mean()return lossclass CrossEntropySurvLoss(object):def __init__(self, alpha=0.15):self.alpha = alphadef __call__(self, hazards, S, Y, c, alpha=None): if alpha is None:return ce_loss(hazards, S, Y, c, alpha=self.alpha)else:return ce_loss(hazards, S, Y, c, alpha=alpha)
基于NLLloss时序分析
仍然是通过多分类的方式来建模患者会在哪一个时间点发生事件。
可以看到,用于监督信号的label是Y:shape=(B, k)
的,是离散的时序点而不是连续的时间。
跟CrossEntropyLoss一样,需要将监督信号整理为离散的时序点。
预测的时候使用log当成回归任务进行建模?
def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):batch_size = len(Y)Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,kc = c.view(batch_size, 1).float() #censorship status, 0 or 1if S is None:S = torch.cumprod(1 - hazards, dim=1) # surival is cumulative product of 1 - hazards# without padding, S(0) = S[0], h(0) = h[0]S_padded = torch.cat([torch.ones_like(c), S], 1) #S(-1) = 0, all patients are alive from (-inf, 0) by definition# after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0]#h[y] = h(1)#S[1] = S(1)uncensored_loss = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))censored_loss = - c * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))neg_l = censored_loss + uncensored_lossloss = (1-alpha) * neg_l + alpha * uncensored_lossloss = loss.mean()return lossclass NLLSurvLoss(object):def __init__(self, alpha=0.15):self.alpha = alphadef __call__(self, hazards, S, Y, c, alpha=None):if alpha is None:return nll_loss(hazards, S, Y, c, alpha=self.alpha)else:return nll_loss(hazards, S, Y, c, alpha=alpha)
基于Cox函数
CoxSurvLoss 的逻辑有所不同。它不使用离散的时间区间 Y,而是直接比较批次内患者的生存时间 S(在这里S代表的是一个时间点,而不是概率)和风险得分 hazards(在这里 hazards 通常是一个 (B, 1) 的风险分数值,而不是 (B, K) 的概率矩阵),来计算偏似然损失。
# 其中S是患者的生存时间
# c是 [1删失,0无删失]
class CoxSurvLoss(object):def __call__(hazards, S, c, **kwargs):# This calculation credit to Travers Ching https://github.com/traversc/cox-nnet# Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics datacurrent_batch_len = len(S)R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int)for i in range(current_batch_len):for j in range(current_batch_len):R_mat[i,j] = S[j] >= S[i]R_mat = torch.FloatTensor(R_mat).to(device)theta = hazards.reshape(-1)exp_theta = torch.exp(theta)loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * (1-c))return loss_cox