在 JIT 编译的函数中调用非 JAX 函数
在 JAX 中,使用 @partial(jax.jit) 装饰器可以将函数编译为高效的 JIT(Just-In-Time)代码。然而,JAX 的 JIT 编译器对函数中的操作有严格的要求,尤其是对非 JAX 函数的支持有限。如果你需要在 JIT 编译的函数中调用非 JAX 函数,可以使用 jax.pure_callback 来实现。
import jax
from functools import partial
import get_text_embedding
def emb(id,di):
id = max(id,di)
if id == 0:
text = "k"
else:
text = "s"
emb = get_text_embedding(text)
return emb
@partial(jax.jit)
def get_embedding(id,di):
# 定义回调函数
def callback(id,di):
return emb(id,di)
# 使用 jax.pure_callback 调用回调函数
result = jax.pure_callback(
callback,
jnp.zeros((1,384)), # 提供一个示例输出形状和类型
id,di # 输入参数
)
return result
if __name__ == "__main__":
id = 0
di = 1
emb = get_embedding(id,di)
print(emb.shape)