当前位置: 首页 > news >正文

在 JIT 编译的函数中调用非 JAX 函数

在 JAX 中,使用 @partial(jax.jit) 装饰器可以将函数编译为高效的 JIT(Just-In-Time)代码。然而,JAX 的 JIT 编译器对函数中的操作有严格的要求,尤其是对非 JAX 函数的支持有限。如果你需要在 JIT 编译的函数中调用非 JAX 函数,可以使用 jax.pure_callback 来实现。

import jax
from functools import partial
import get_text_embedding

def emb(id,di):
		id = max(id,di)
    if id == 0:
        text = "k"
    else:
        text = "s"
    emb = get_text_embedding(text)
    return emb
    
@partial(jax.jit)
def get_embedding(id,di):
    # 定义回调函数
    def callback(id,di):
        return emb(id,di)

    # 使用 jax.pure_callback 调用回调函数
    result = jax.pure_callback(
        callback,
        jnp.zeros((1,384)),  # 提供一个示例输出形状和类型
        id,di # 输入参数
    )
    return result

if __name__ == "__main__":
    id = 0
    di = 1
    emb = get_embedding(id,di)
    print(emb.shape)
http://www.dtcms.com/a/77390.html

相关文章:

  • OpenAI Agents SDK 使用自定义的 OpenAI-Compatible API
  • 将对象内的键值转换为响应式变量后,在setup函数中用这些属性的时候为什么不用像ref那样加value
  • 冯・诺依曼架构深度解析
  • WPF-实现按钮的动态变化
  • OMRON Corporation Programming Contest 2025 (AtCoder Beginner Contest 397)题解
  • 对接豆包大模型
  • SvelteKit 最新中文文档教程(6)—— 状态管理
  • 【微服务】基于Lambda ESM的预留模式调整Kafka ESM吞吐量的实战
  • 【海螺AI视频】蓝耘智算 | AI视频新浪潮:蓝耘MaaS与海螺AI视频创作体验
  • leetcode33.搜索旋转排序数组-medium
  • 【八股文】volatile关键字的底层原理是什么
  • 实现搜索功能:第一部分
  • 穿越是时空之门(java)
  • Ubuntu安装TensorFlow 2.13-GPU版全流程指南(anaconda)
  • golang中的接口
  • 【Java进阶学习 第九篇】常用API(Array、冒泡选择排序、二分查找、正则表达式)
  • 【C++进阶】指针:从基础到实践
  • Leetcode Hot 100 79.单词搜索
  • 【spring对bean Singleton和Prototype的管理流程】
  • 英伟达GTC 2025大会产品全景剖析与未来路线深度洞察分析
  • 小程序开发中的安全问题及防护措施
  • 蓝桥与力扣刷题(蓝桥 组队)
  • E1-相亲派对(组合)
  • 【AI News | 20250319】每日AI进展
  • @Resource和@Autowire
  • Java 中 LinkedList 的底层数据结构及相关分析
  • 【源码阅读】多个函数抽象为类(实现各种类型文件转为PDF)
  • UE4学习笔记 FPS游戏制作6 添加枪口特效
  • 详细解析GetOpenFileName()
  • Vue3 核心特性解析:Suspense 与 Teleport 原理深度剖析