php网站后台入口北京seo优化诊断
在 JAX 中,使用 @partial(jax.jit) 装饰器可以将函数编译为高效的 JIT(Just-In-Time)代码。然而,JAX 的 JIT 编译器对函数中的操作有严格的要求,尤其是对非 JAX 函数的支持有限。如果你需要在 JIT 编译的函数中调用非 JAX 函数,可以使用 jax.pure_callback 来实现。
import jax
from functools import partial
import get_text_embeddingdef emb(id,di):id = max(id,di)if id == 0:text = "k"else:text = "s"emb = get_text_embedding(text)return emb@partial(jax.jit)
def get_embedding(id,di):# 定义回调函数def callback(id,di):return emb(id,di)# 使用 jax.pure_callback 调用回调函数result = jax.pure_callback(callback,jnp.zeros((1,384)), # 提供一个示例输出形状和类型id,di # 输入参数)return resultif __name__ == "__main__":id = 0di = 1emb = get_embedding(id,di)print(emb.shape)