
I was debugging an issue with a JAX/Flax NNX training loop the other day, and found a neat little trick to help debug it. Specifically, I wanted to see if the issue was with my model, my loss function, my optimiser settings, or the "plumbing" of the training loop itself -- were gradients actually coming through and being applied to the parameters? I could print out the loss and the gradients, but printing out the parameters to see if they were changing was unhelpful -- any given update might only change a small number of parameters, or might change them by such a small amount that I'd not notice -- especially given that the model had 77 million of them! Let's take a look.

The world's worst LLM

I am building an LLM from scratch in JAX and Flax NNX, and at this stage I'm trying to get the training loop right. As a simple test, I've just implemented the "shell" of the LLM -- the token embeddings on the input side, and the final linear layer for an output head, wired directly together. My…
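As an illustration of the kind of check involved, here is a minimal sketch of one possible approach, not necessarily the trick described in this post. It uses plain JAX rather than Flax NNX. The model is the same "shell" (token embeddings wired straight into a linear output head), and instead of printing parameters it snapshots them before an update and reports, per parameter, how many elements changed at all and the largest absolute change. Everything here is a hypothetical stand-in: the names `init_params`, `forward`, `loss_fn`, `sgd_step` and `param_change_report`, the tiny sizes (50-token vocabulary, 16-dimensional embeddings), and the hand-rolled SGD step that replaces a real optimiser. In Flax NNX the model would presumably be an `nnx.Embed` feeding an `nnx.Linear`, and the snapshot could come from something like `nnx.state(model, nnx.Param)`. If you do that, make sure the snapshot holds copies of the arrays rather than references to the live variables.

```python
import jax
import jax.numpy as jnp


def init_params(key, vocab_size, emb_dim):
    # Hypothetical "shell" LLM: an embedding table plus a linear output head.
    k_emb, k_head = jax.random.split(key)
    return {
        "embed": jax.random.normal(k_emb, (vocab_size, emb_dim)) * 0.02,
        "head_w": jax.random.normal(k_head, (emb_dim, vocab_size)) * 0.02,
        "head_b": jnp.zeros((vocab_size,)),
    }


def forward(params, tokens):
    # tokens: int array [batch, seq] -> logits [batch, seq, vocab]
    x = params["embed"][tokens]
    return x @ params["head_w"] + params["head_b"]


def loss_fn(params, inputs, targets):
    # Mean next-token cross-entropy.
    log_probs = jax.nn.log_softmax(forward(params, inputs), axis=-1)
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -jnp.mean(picked)


def sgd_step(params, inputs, targets, lr=0.1):
    # Plain SGD, standing in for a real optimiser.
    loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss, grads


def param_change_report(old, new):
    # For each parameter: (elements that changed at all, total elements, max |delta|).
    changed = jax.tree_util.tree_map(lambda a, b: int(jnp.sum(a != b)), old, new)
    sizes = jax.tree_util.tree_map(lambda a: a.size, old)
    max_delta = jax.tree_util.tree_map(
        lambda a, b: float(jnp.max(jnp.abs(a - b))), old, new
    )
    return {name: (changed[name], sizes[name], max_delta[name]) for name in old}


vocab_size, emb_dim = 50, 16
params = init_params(jax.random.PRNGKey(0), vocab_size, emb_dim)

# Random token sequences; predict each next token from the current one.
tokens = jax.random.randint(jax.random.PRNGKey(1), (4, 9), 0, vocab_size)
inputs, targets = tokens[:, :-1], tokens[:, 1:]

new_params, loss, grads = sgd_step(params, inputs, targets)
report = param_change_report(params, new_params)
for name, (n_changed, n_total, max_delta) in report.items():
    print(f"{name}: {n_changed}/{n_total} elements changed, max |delta| = {max_delta:.3g}")
```

Only the embedding rows for tokens that actually appear in the batch receive a gradient, so the report shows most of `embed` unchanged after a step. That is expected, and it is exactly the kind of partial update that eyeballing printed parameters would miss. A head whose changed-count stays at zero step after step, on the other hand, would suggest that updates are not reaching the parameters.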
