Imagine you have JAX code like this, and run it on a machine with CUDA set up:

    import jax

    key = jax.random.key(42)
    cpu0 = jax.devices("cpu")[0]
    with jax.default_device(cpu0):
        array = jax.random.randint(
            key, (530640, 6, 1024), 0, 50_000, dtype=jax.numpy.uint16
        )
        array.block_until_ready()
    item = array[0]
    item.block_until_ready()

We're creating a big array, blocking until it's ready (JAX is asynchronous, so this makes sure that it has actually finished creating it), then getting the first item and, as a belt-and-braces measure, making sure that that is ready too.

How long do you think those last two lines -- a simple retrieval of a 6 x 1024 array from a larger one -- will take? Some tiny fraction of a second would seem reasonable. But running it on my machine just now, the answer is a bit of a surprise: just over 5 seconds.

And if you try to get array[1] immediately afterwards, it still takes about 1.2s. Further lookups into array consistently take more than a second -- so while the larger initial…
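One way to reproduce timings like these is to wrap each lookup in a wall-clock timer. The sketch below is a minimal version of that idea, not the exact measurement used in this post: the `time_lookup` helper is hypothetical, and the array's leading dimension is assumed to be much smaller than 530640 so that it fits in memory on most machines, which means the numbers it prints will not match the ones quoted above.

```python
import time

import jax
import jax.numpy as jnp


def time_lookup(array, index):
    """Hypothetical helper: return (item, seconds) for a blocking array[index]."""
    start = time.perf_counter()
    item = array[index]
    # JAX dispatches work asynchronously, so wait for the result before
    # stopping the clock.
    item.block_until_ready()
    return item, time.perf_counter() - start


key = jax.random.key(42)
cpu0 = jax.devices("cpu")[0]

# Assumption: a far smaller leading dimension than the post's 530640.
shape = (1024, 6, 1024)

with jax.default_device(cpu0):
    array = jax.random.randint(key, shape, 0, 50_000, dtype=jnp.uint16)
    array.block_until_ready()

# Time a few consecutive lookups, outside the default_device block.
timings = []
for index in range(3):
    item, seconds = time_lookup(array, index)
    timings.append(seconds)
    print(f"array[{index}]: {seconds:.4f}s")
```

Timing several lookups in a row, rather than just one, makes it easier to separate a one-off cost on the first access from a cost that recurs on every access.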