当前位置: 首页 > news >正文

Sum-rate计算

1.ZF

import torchdef calc_sum_rate_corrected(H: torch.Tensor, soft_mask: torch.Tensor, p: float = 80.0, sigma2: float = 1.0) -> torch.Tensor:"""基于 Top-k 软掩码选择的可导 sum-rate 计算 (修正版)。此函数修正了原实现中的矩阵维度和SINR计算错误。Args:H: 信道矩阵, 形状为 (N, K),N=天线数, K=用户数。soft_mask: 每个用户的选择概率掩码, 形状为 (K,),值在 [0, 1] 之间。p: 总发射功率。sigma2: 噪声功率。Returns:sum_rate: 可导的和速率标量。"""N, K = H.shapedevice = H.device# 确保掩码和信道矩阵的数据类型一致if soft_mask.dtype != H.dtype:soft_mask = soft_mask.to(H.dtype)# 1. 根据soft_mask选出Top-k个最可能的用户topk = min(N, K)# 使用.real以防soft_mask是复数类型_, indices = torch.topk(soft_mask.real, topk)H_sel = H[:, indices]  # 形状: (N, topk)soft_mask_sel = soft_mask[indices]  # 形状: (topk,)# 2. 为了可导性,使用soft_mask对选出的信道进行加权# 这是您原始设计中有意为之的部分,予以保留H_weighted = H_sel * soft_mask_sel.unsqueeze(0)  # 形状: (N, topk)try:# 3. [修正] 计算ZF权重矩阵 WH_H = H_weighted.conj().T  # 形状: (topk, N)# [修正点 #2] Gram矩阵应为 (topk, topk)gram = H_H @ H_weighted# 为提高数值稳定性,加入微小的正则化项reg_eps = 1e-6gram_reg = gram + reg_eps * torch.eye(topk, device=device, dtype=gram.dtype)# [修正点 #3] 正确求解 W,其形状应为 (topk, N)# W 的每一行 w_k 是分配给用户k的波束成形向量W = torch.linalg.solve(gram_reg, H_H)# 4. [修正] 使用向量化方式计算所有用户的SINR# 分子 (Signal Power)# p_u 是分配给每个用户的功率p_u = p / topk# G是均衡后的等效信道,对角线元素是每个用户的信号增益G = W @ H_weighted  # 形状: (topk, topk)signal_gains = torch.diag(G)num = p_u * torch.abs(signal_gains) ** 2# 分母 (Noise Power)# [修正点 #5] 按行(dim=1)求和,计算每个用户权重向量w_k的模长平方noise_power = sigma2 * torch.sum(torch.abs(W) ** 2, dim=1)den = noise_power# 为防止除以零,在分母上增加一个极小值sinr = num / (den + 1e-12)# 5. 计算和速率rate = torch.log2(1 + sinr)return torch.sum(rate)except torch.linalg.LinAlgError:# 如果矩阵奇异无法求解,返回0return torch.tensor(0.0, device=device)

http://www.dtcms.com/a/303419.html

相关文章:

  • 【代码解读】通义万相最新视频生成模型 Wan 2.2 实现解析
  • 同态滤波算法详解:基于频域变换的光照不均匀校正
  • 栈算法之【用栈实现队列】
  • 凸优化:凸函数的一些常用性质
  • OpenLayers 综合案例-量测工具
  • 【Zustand】从复杂到简洁:Zustand 状态管理简化实战指南
  • 图解系统的学习笔记--硬件结构
  • 告别繁琐 Mapper!Stream-Query 正式入驻 GitCode 平台
  • GPFS文件系统更换磁盘
  • 企业级JWT验证最佳方案:StringUtils.hasText()
  • AD中放置过孔阵列
  • Python 异常 (Exception) 深度解析
  • 如何获取我当前的IP地址
  • 掌握 ArkTS 复杂数据绑定:从双向输入到多组件状态同步
  • AWS MemoryDB 可观测最佳实践
  • Python Pandas.merge_ordered函数解析与实战教程
  • 全球首个1米高精度特大城市开放空间数据集(Tif)
  • 力扣刷题977——有序数组的平方
  • 热门JavaScript库“is“等软件包遭npm供应链攻击植入后门
  • “菜鸟的java代码日记“ DAY3——跳跃游戏(中等)
  • DBAPI的SQL实现模糊查询的3种方案
  • [论文阅读] 人工智能 | 机器学习工作流的“救星”:数据虚拟化服务如何解决数据管理难题?
  • 数据结构面经
  • 《中国棒球》cba球队有哪些球队·棒球1号位
  • MySQL 查询重复数据的方式总结
  • 历史版本vscode的下载地址
  • 从黑客松出发,AI + Web3 项目怎么打磨成产品?
  • vue2中实现leader-line-vue连线文章对应字符
  • 事务实现的底层原理
  • SwinTransformer改进(14):集成MLCA注意力机制的Swin Transformer模型