# ask-for-help
Aaron Pham:
Hi Rishab, back when I was implementing this, I believe I ran into a race condition when using jit, so I decided not to have jit enabled. For JAX, I believe we should enable jit; let me try to find the notes I took back then about some quirks I ran into.
Additionally, with jit compilation we need some support for a warmup hook so that the model is compiled before running inference (which I think we can now use `@svc.on_startup` for).
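A minimal sketch of what that warmup could look like (assuming the `@svc.on_startup` hook mentioned above and BentoML's 1.x service API; the model, shapes, and names are purely illustrative):

```python
import bentoml
import jax
import jax.numpy as jnp
from bentoml.io import NumpyNdarray
from flax import linen as nn

# Illustrative model/params; in a real service these would come from a saved Bento model.
model = nn.Dense(features=10)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))

@jax.jit
def predict_fn(params, x):
    return model.apply(params, x)

svc = bentoml.Service("flax_demo")

@svc.on_startup  # hook name as mentioned above; exact signature/availability may vary by version
def warmup(context):
    # One dummy call triggers XLA compilation up front, so the first
    # real request doesn't pay the jit compile cost.
    predict_fn(params, jnp.zeros((1, 784))).block_until_ready()

@svc.api(input=NumpyNdarray(), output=NumpyNdarray())
def predict(x):
    return jax.device_get(predict_fn(params, jnp.asarray(x)))
```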
I haven’t fully battle-tested JAX yet; most of the Flax use cases I have seen so far are related to transformers, so I decided not to go too deep into the details.
Rishab:
Hey @Aaron Pham, thanks for getting back to me! JAX inference doesn't make much sense without jit, so I'd love to read those notes on the kinks you ran into 😄 I'm working on a PR (code link below) for integrating Ivy (https://unify.ai/) into BentoML, which would let users get a ~3x latency boost by transpiling any of their torch (or tensorflow) models to jax simply by adding a flag, so I have to say usage wouldn't be limited to just the popular JAX transformers 😅
Code: minimal notebook, check out the "inference latency comparison" section - https://github.com/unifyai/BentoML/blob/dev_transpile/examples/pytorch_mnist/pytorch_mnist_demo.ipynb
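For reference, a rough sketch of the transpilation idea (the exact `ivy.transpile` keyword arguments are an assumption here and may differ between Ivy versions, and the torch model is just a placeholder; the linked notebook shows the actual call used in the PR):

```python
import ivy
import torch
import jax.numpy as jnp

# Placeholder torch model; any torch.nn.Module works the same way.
torch_model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 10),
)
torch_model.eval()

# Transpile the torch model into a JAX-backed callable. Keyword names
# (source/to/args) are assumptions and may vary between Ivy versions.
example_input = torch.zeros(1, 1, 28, 28)
jax_model = ivy.transpile(torch_model, source="torch", to="jax", args=(example_input,))

# The transpiled model now takes JAX arrays, so it can be jit-compiled
# and served from BentoML like any other JAX callable.
out = jax_model(jnp.zeros((1, 1, 28, 28)))
print(out.shape)  # expected: (1, 10)
```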