In this article we’ll learn how to train a neural network to play Tic-Tac-Toe using reinforcement learning in Jax. The article aims to be pedagogical, so the code we end up with won’t be heavily optimized, but it will be fast enough to train a model to perfect play in about 15 seconds on a laptop. Code from this page can be found in this GitHub repo as well as in a Colab notebook (although the Colab notebook runs considerably more slowly).

Playing Tic-Tac-Toe in Jax

Before we get to the fancy neural networks and reinforcement learning, we’ll first look at how a Tic-Tac-Toe game can be represented in Jax. For this we’ll use the PGX library, which implements a number of games in pure Jax. PGX represents a game’s state with a dataclass called State. This dataclass has several fields:

- current_player: This is simply a 0 or a 1 and alternates on every turn. What is perhaps confusing is that there is no fixed relationship between player 0 and an X or an O. Player 0 is…