I'm porting my PyTorch LLM code to JAX, using Flax as the neural network layer. For various reasons I wanted to use Safetensors to store checkpoints of the model. It took a little while to get it working; here's the trick I learned.

If you look at the Safetensors docs, you'll see that it doesn't mention a JAX implementation -- indeed, searching for "safetensors jax" at the time I'm writing this gives you a link to this GitHub repo by Alvaro Bartolome -- which was last updated in 2023.

However, if you look more closely at the docs, they do have a link to the Flax API. I feel this is somewhat misnamed, as it is actually a JAX API. There's no reference (again, as of the time of writing) to Flax in the source -- it's all just JAX code. And in fact Bartolome's library uses it under the hood.

There is one problem, though. The API works with simple single-level dictionaries, with strings mapping directly to JAX arrays. For example, the save_file function has this signature: def save_file(…
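
To make the single-level constraint concrete, here is a minimal sketch of one way a nested parameter dictionary could be squashed into string keys and rebuilt afterwards. It is an illustration under stated assumptions, not necessarily the approach this post goes on to use: the helper names `flatten_params` and `unflatten_params`, the `"."` separator, and the example parameter tree are all hypothetical, and plain Python lists stand in for JAX arrays so the snippet runs without any dependencies.

```python
from typing import Any, Dict


def flatten_params(nested: Dict[str, Any], prefix: str = "", sep: str = ".") -> Dict[str, Any]:
    """Collapse a nested dict into a single-level dict with joined string keys."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        path = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into sub-dictionaries, carrying the path so far.
            flat.update(flatten_params(value, path, sep))
        else:
            # Anything that is not a dict is treated as a leaf (an array, in practice).
            flat[path] = value
    return flat


def unflatten_params(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Rebuild the nested dict from the joined string keys."""
    nested: Dict[str, Any] = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# Hypothetical parameter tree; lists stand in for JAX arrays.
params = {
    "embed": {"kernel": [1.0, 2.0]},
    "block_0": {"attn": {"bias": [0.0]}},
}

flat = flatten_params(params)
restored = unflatten_params(flat)
```

A flat dictionary shaped like `flat` (with real JAX arrays as the values) is the kind of thing a single-level string-to-array API can accept. Note that this sketch assumes no key already contains the separator. Flax also provides `flax.traverse_util.flatten_dict`, which takes a `sep` argument and may be a better fit than hand-rolled helpers.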