当前位置: 首页 > news >正文

TensorFlow充分并行化使用CPU

关键字:TensorFlow 并行化、TensorFlow CPU多线程

场景:在没有GPU或者GPU性能一般、环境不可用的机器上,对于多核CPU,有时TensorFlow或上层的Keras默认并没有完全利用机器的计算能力(CPU占用没有接近100%),因此想让它通过多线程、并行化充分利用计算资源,提升效率。

1.‌get_inter_op_parallelism_threads(...)‌ 获取用于独立操作之间并行执行的线程数。

  • 此方法用于查询当前配置中,可并行执行多个独立操作(如无依赖关系的运算符)的线程池大小。独立操作间的并行性通过线程池调度实现,适用于计算图中无数据依赖的分支操作‌。

‌2.get_intra_op_parallelism_threads(...)‌ 获取单个操作内部用于并行执行的线程数。

  • 此方法返回单个运算符(如矩阵乘法、卷积等)内部并行计算时使用的线程数。某些复杂运算符可通过多线程加速计算,例如利用多核 CPU 并行处理子任务‌。

‌3.set_inter_op_parallelism_threads(...)‌ 设置用于独立操作之间并行执行的线程数。

  • 通过此方法调整线程池大小,控制独立操作间的并行度。例如,在多个无依赖关系的运算符同时运行时,提高此值可提升整体吞吐量,但需避免过度占用资源导致竞争‌。

‌4.set_intra_op_parallelism_threads(...)‌设置单个操作内部用于并行执行的线程数。

  • 针对支持内部并行的运算符(如 matmul、reduce_sum),此方法设置其内部子任务的最大并行线程数。合理调整此值可优化计算密集型操作的性能,但需考虑 CPU 核心数和实际负载‌。

参考链接: https://www.tensorflow.org/api_docs/python/tf/config/threading

完整写法:tf.config.threading.set_inter_op_parallelism_threads(num_threads)

注意事项‌:线程数设置需在会话初始化前完成,且某些环境变量(如 OMP_NUM_THREADS)可能影响最终效果‌。

import os
# 注意:环境变量需在导入TensorFlow之前设置才能确保生效
os.environ["OMP_NUM_THREADS"] = "1"       # 禁用OpenMP的多线程(由TensorFlow自己管理)
os.environ["KMP_BLOCKTIME"] = "0"         # 设置线程在空闲后立即回收

import tensorflow as tf

def configure_cpu_parallelism(intra_threads=8, inter_threads=2):
    """
    参数说明:
    intra_threads - 控制单个操作内部并行度(如矩阵乘法),建议设为物理CPU核心数
    inter_threads - 控制多个操作间的并行度,建议根据任务类型调整(计算密集/IO密集)
    
    推荐设置:
    对于计算密集型任务,inter_threads建议设为CPU的NUMA节点数或较小数值
    总线程数不应超过CPU逻辑核心数(可通过os.cpu_count()查看)
    """
    try:
        # 设置操作内并行线程数(针对单个操作的多核并行)
        tf.config.threading.set_intra_op_parallelism_threads(intra_threads)

        # 设置操作间并行线程数(针对计算图多个操作的流水线并行)
        tf.config.threading.set_inter_op_parallelism_threads(inter_threads)

    except RuntimeError as e:
        # TensorFlow运行时一旦初始化后无法修改配置
        print(f"配置失败:{str(e)}(请确保在创建任何TensorFlow对象前调用本函数)")

# 示例配置(假设8核CPU)
configure_cpu_parallelism(intra_threads=8, inter_threads=2)

# 验证配置
print("\n验证当前线程配置:")
print(f"Intra-op threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
print(f"Inter-op threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
print(f"物理CPU核心数: {os.cpu_count()}")
print(f"OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS', '未设置')}")

相关文章:

  • 国产Linux系统统信安装redis教程步骤
  • 基于SpringBoot的动物救助中心系统(源码+数据库)
  • Zen 5白色装机优选,华硕X870 AYW GAMING WIFI W主板来了!
  • 网工基础 | 常见英文术语注解
  • 新闻推荐系统(springboot+vue+mysql)含万字文档+运行说明文档
  • 参照Spring Boot后端框架实现序列化工具类
  • [特殊字符] 第十一讲 | 空间回归模型实战:SAR / SEM / GWR逐个击破
  • python办公自动化------word文件的操作
  • xlinx GT传输器学习
  • Spring Cloud 远程调用
  • 初识SpringAI(接入硅基流动deepseek)
  • <C#>在 C# .NET 中,使用 LoggerExtensions方法创建日志
  • 修改Todesk软件显示的设备码的办法
  • 前端请求设置credentials: ‘include‘导致的cors问题
  • 网络安全1
  • Git中git rebase 和 git merge使用及区别
  • Python小程序 - 文件处理3:正则表达式
  • 珠江桥牌试吃活动 一酱承粤味谷雨话新炊
  • Elasticsearch 系列专题 - 第四篇:聚合分析
  • LangGraph 使用指南
  • 建设委员会官方网站/制作网页代码大全
  • 学院网站建设项目的成本计划书/最近三天的新闻大事摘抄
  • 扬中做网站的公司/百度网址
  • 做一元购网站/深圳网络营销和推广方案
  • 西安企业网站搭建/百度指数怎么查
  • 视频网站用什么cms/优化seo