当前位置: 首页 > news >正文

生存分析任务建模以及损失函数

  • Loss参考自项目:https://github.com/JJ-ZHOU-Code/RobustMultiModel/blob/00d8c10a4d3a14ef0c02584b991b139678e6845d/utils/utils.py

  • 可以直接看这一行,理解模型输出:https://github.com/JJ-ZHOU-Code/RobustMultiModel/blob/00d8c10a4d3a14ef0c02584b991b139678e6845d/models/model_survpath.py#L224

建模输出

这个代码可以看到模型的输出是:harzards

这个harzards就是一个时序分类任务,如果我是5年生存分析,那么就可以建模成10个时序点,包括:[0,0.5year) , [0.5, 1year) … [4.5year, 5year) 这一共10个时序点。如果需要建模到更细,那么分类的点就会越多。一般是一年一个点或者半年一个点比较合理。这个项目使用的就是半年一个点。

注意两个点:

  1. 时间点的区间是左闭右开的。
  2. 5years是没有的,因为当前面10个点sigmoid之后的概率都很小,那么就说明患者生存概率很大,归类为 >5years这个点。

因此这个模型的最后是一个 classifier = Linear(embed_dim, 11),生成的harzards就是第i个点发生事件的概率(如果是死亡,那么就是说第i个时间段发生死亡的概率最大)。

基于harzards我们可以得到生存率

hazards = torch.sigmoid(logits)
Surv = torch.cumprod(1 - hazards, dim=1)

Loss函数

基于CrossEntropy多分类建模

通过多分类的方式来建模患者会在哪一个时间点发生事件。
可以看到,用于监督信号的label是Y:shape=(B, k)的,是离散的时序点而不是连续的时间。
因此在使用这个loss的时候,需要将监督信号整理为多分类问题(多个类别选一个)。
好像没见到什么人用,建议使用NLLloss。

def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):batch_size = len(Y)Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,kc = c.view(batch_size, 1).float() #censorship status, 0 or 1if S is None:S = torch.cumprod(1 - hazards, dim=1) # surival is cumulative product of 1 - hazards# without padding, S(0) = S[0], h(0) = h[0]# after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0]#h[y] = h(1)#S[1] = S(1)S_padded = torch.cat([torch.ones_like(c), S], 1)reg = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y)+eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))ce_l = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (1 - c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))loss = (1-alpha) * ce_l + alpha * regloss = loss.mean()return lossclass CrossEntropySurvLoss(object):def __init__(self, alpha=0.15):self.alpha = alphadef __call__(self, hazards, S, Y, c, alpha=None): if alpha is None:return ce_loss(hazards, S, Y, c, alpha=self.alpha)else:return ce_loss(hazards, S, Y, c, alpha=alpha)

基于NLLloss时序分析

仍然是通过多分类的方式来建模患者会在哪一个时间点发生事件。
可以看到,用于监督信号的label是Y:shape=(B, k)的,是离散的时序点而不是连续的时间。
跟CrossEntropyLoss一样,需要将监督信号整理为离散的时序点。
预测的时候使用log当成回归任务进行建模?

def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):batch_size = len(Y)Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,kc = c.view(batch_size, 1).float() #censorship status, 0 or 1if S is None:S = torch.cumprod(1 - hazards, dim=1) # surival is cumulative product of 1 - hazards# without padding, S(0) = S[0], h(0) = h[0]S_padded = torch.cat([torch.ones_like(c), S], 1) #S(-1) = 0, all patients are alive from (-inf, 0) by definition# after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0]#h[y] = h(1)#S[1] = S(1)uncensored_loss = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))censored_loss = - c * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))neg_l = censored_loss + uncensored_lossloss = (1-alpha) * neg_l + alpha * uncensored_lossloss = loss.mean()return lossclass NLLSurvLoss(object):def __init__(self, alpha=0.15):self.alpha = alphadef __call__(self, hazards, S, Y, c, alpha=None):if alpha is None:return nll_loss(hazards, S, Y, c, alpha=self.alpha)else:return nll_loss(hazards, S, Y, c, alpha=alpha)

基于Cox函数

CoxSurvLoss 的逻辑有所不同。它不使用离散的时间区间 Y,而是直接比较批次内患者的生存时间 S(在这里S代表的是一个时间点,而不是概率)和风险得分 hazards(在这里 hazards 通常是一个 (B, 1) 的风险分数值,而不是 (B, K) 的概率矩阵),来计算偏似然损失。

# 其中S是患者的生存时间
# c是 [1删失,0无删失]
class CoxSurvLoss(object):def __call__(hazards, S, c, **kwargs):# This calculation credit to Travers Ching https://github.com/traversc/cox-nnet# Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics datacurrent_batch_len = len(S)R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int)for i in range(current_batch_len):for j in range(current_batch_len):R_mat[i,j] = S[j] >= S[i]R_mat = torch.FloatTensor(R_mat).to(device)theta = hazards.reshape(-1)exp_theta = torch.exp(theta)loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * (1-c))return loss_cox
http://www.dtcms.com/a/496670.html

相关文章:

  • 中国正规的加盟网站网站设计的风格有哪些
  • 怎么修改网站图标小企业公司网站建设
  • docker学习(4)容器的生命周期与资源控制
  • 网站建设开发网站案例项目费用电子商务网站建设读书笔记
  • 做推广用那个网站吗室内装修设计软件免费版下载破解版
  • 做网站必须学php吗wordpress改插件难吗
  • SAP MM采购订单推送OA分享
  • 如何线下宣传网站深圳网站建设那家好
  • 豆包谈追星
  • 手机网站开发公司哪家好惠州营销网站建设
  • 选择做华为网站的目的和意义博客登陆wordpress
  • 洛谷 P5718:找最小值 ← if + while
  • 网站美食建设图片素材故事式软文范例500字
  • 装饰网站建设的背景贵阳网页设计培训
  • Vue3 中的 watch 和 watchEffect:如何优雅地监听数据变化
  • 深度学习模型训练的一些常见指标
  • 购物网站建设情况汇报更合公司网站建设
  • 前端+AI:HTML5语义标签(一)
  • 微端边缘设备部署大模型简单笔记
  • wordpress的网站无法发布文章创造一个平台要多少钱
  • 搜索本地存储逻辑
  • 域名解析在线seo网站培训班
  • ASTM C1693-11蒸压加气混凝土检测
  • RAG(检索增强生成)详解:让大模型更“博学”更“靠谱”
  • 我有域名怎么建网站鱼滑怎么制作教程
  • 萧县做网站的公司商城app官方下载
  • 网站弹广告是什么样做的辽阳做网站
  • 德州市德城区城乡建设局网站电子商务网页设计与制作课后作业
  • 网站中的滑动栏怎么做的asp网站路径
  • 深圳专业网站建设公网站建设系统怎么样