1 hour ago · Tech · 0 comments

There's nothing like writing your own code with a framework to clarify how things fit together! Continuing with my port of my PyTorch LLM code to JAX, I wanted to load up a large dataset: the 10,248,871,837 16-bit unsigned integers in the train split of gpjt/fineweb-gpt2-tokens. That's just over 19 GiB of data.

    from safetensors.flax import load_file
    ...
    full_dataset = load_file(dataset_dir / f"train.safetensors")["tokens"]

When I ran that, I got a CUDA out-of-memory error:

    jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 19.09GiB.

That makes sense! The allocation it was trying to do is exactly the size of the data I was trying to load. I have an RTX 3090 with 24 GiB, but some is already used up by the OS, various apps, and a model that the code creates earlier on.

But in PyTorch land, I was used to things being loaded into RAM by default, and only moved over to the GPU when I asked it to do that. JAX was clearly loading to the GPU by default.

How…
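As a quick sanity check on the numbers above, here is a minimal sketch of the arithmetic that ties the token count to the allocation size in the error message. The constant and function names are my own, chosen for illustration; the only inputs taken from the post are the token count and the 16-bit element size.

```python
# Size check for the token dataset described above.
# NUM_TOKENS and BYTES_PER_TOKEN come from the post; the names are illustrative.
NUM_TOKENS = 10_248_871_837  # tokens in the train split
BYTES_PER_TOKEN = 2          # 16-bit unsigned integers


def size_in_gib(num_items: int, bytes_per_item: int) -> float:
    """Return the total size in GiB, where 1 GiB = 2**30 bytes."""
    return num_items * bytes_per_item / 2**30


dataset_gib = size_in_gib(NUM_TOKENS, BYTES_PER_TOKEN)
print(f"{dataset_gib:.2f} GiB")  # prints "19.09 GiB"
```

That matches the 19.09GiB in the error, which is consistent with the whole array being placed on the GPU in a single allocation.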
