This article is participating in Python Theme Month. See the link to the event for more details

Description: This program is based on reinforcement learning; it trains a snake to eat the food present in its environment.

A sample GIF is given below to give you an idea of what we’re going to build.

[GIF: AI-powered snake]

To see how we manually built this 2D snake game simulation with PyGame, follow the link below:

How to create a snake game in Python | Python Theme month

Having built the basic snake game, we will now focus on how reinforcement learning can be applied to it.

We need to create three modules for this project:

  1. Environment (the game we just built)
  2. Model (a neural network that predicts the snake's next move)
  3. Agent (the intermediary between the environment and the model)


Algorithm:

The snake and the food are placed randomly on the board.

  • Use 11 values to describe the state of the snake. Each value is set to 1 if its condition is true, and 0 otherwise.

[Figure: how the 11 state values are defined]

The agent calculates the 11 state values based on the current head location, as described above.
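
For reference, the 11-value state vector has the layout below (a sketch with illustrative values; the actual construction is shown in the get_state method later in this article):

state = [
	0,  # danger straight ahead
	0,  # danger to the right
	0,  # danger to the left
	0,  # currently moving left
	1,  # currently moving right
	0,  # currently moving up
	0,  # currently moving down
	0,  # food is to the left of the head
	1,  # food is to the right of the head
	1,  # food is above the head
	0,  # food is below the head
]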

  • Once these state values are obtained, the agent passes them to the model, which predicts the next move.

  • The reward is calculated after the action is executed. Rewards are defined as follows (a sketch of how the environment might assign them appears after this list):

    • Eat food: +10
    • Game over: -10
    • Other: 0
  • Update the Q value (discussed later) and train the model.

  • Having analyzed the algorithm, we can now move on to implementing it in code.
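
Before moving on, here is a minimal sketch of how the environment could assign these rewards inside its step function. The names play_step, is_collision, _move, _place_food, head, food, and score are assumptions based on the game built earlier, not the exact implementation.

# Hedged sketch: reward assignment inside the game (environment) class
def play_step(self, action):
	self._move(action)              # apply the chosen action
	reward = 0
	game_over = False
	if self.is_collision():         # the snake hit a wall or its own body
		game_over = True
		reward = -10
	elif self.head == self.food:    # the snake ate the food
		reward = 10
		self.score += 1
		self._place_food()
	return reward, game_over, self.score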

The model:

[Figure: neural network model]

The model is built with PyTorch, but you can also use TensorFlow if you are more comfortable with it.

We use a dense neural network with an input layer of size 11, one hidden layer of 256 neurons, and an output layer of 3 neurons. You can adjust these hyperparameters to get better results.

How does the model work?

  • The game starts, and the Q values are randomly initialized.
  • The system obtains the current state s.
  • It chooses an action based on s, either randomly or via its neural network. In the early phase of training, the system mostly selects random moves to maximize exploration. Later on, it relies more and more on its neural network.
  • When the AI selects and performs an action, the environment gives it a reward. The agent then reaches the new state and updates its Q value according to the Bellman equation.

Bellman equation (the update used here): Q(s, a) = reward + gamma * max(Q(s', a'))

  • In addition, for each step it stores the original state, the action taken, the state reached after that action, the reward earned, and whether the game ended. This data is later sampled to train the neural network; this mechanism is called replay memory.
  • Repeat the last two steps until a terminal condition (e.g., game over) is met. The whole loop is sketched below.
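
To make this loop concrete, here is a minimal sketch of the agent's main training loop. It assumes the environment exposes play_step and reset methods, and that the agent provides the helpers described later in this article (get_state, get_action, train_short_memory, remember, train_long_memory); treat the exact names as illustrative.

def train():
	agent = Agent()
	game = SnakeGameAI()
	while True:
		state_old = agent.get_state(game)                  # current state s
		final_move = agent.get_action(state_old)           # epsilon-greedy action
		reward, done, score = game.play_step(final_move)   # act in the environment
		state_new = agent.get_state(game)                  # resulting state s'

		# Learn from this single transition (short memory)
		agent.train_short_memory(state_old, final_move, reward, state_new, done)
		# Store the transition in replay memory
		agent.remember(state_old, final_move, reward, state_new, done)

		if done:
			# Game over: reset the environment and replay a batch of memories
			game.reset()
			agent.n_game += 1
			agent.train_long_memory()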

The core of the project is the model you will train, because the correctness of the snake's actions depends entirely on the quality of that model. So let me walk you through this part of the code.

The first part

  1. Create a class called Linear_QNet to initialize the linear neural network.
  2. The forward function takes the input (the 11-value state vector), passes it through the network with the ReLU activation function, and returns a 1 × 3 output vector representing the next move. In short, this is the prediction function the agent will call.
  3. The save function saves the trained model for later use.
import os

import torch
import torch.nn as nn
import torch.nn.functional as F


class Linear_QNet(nn.Module):
	def __init__(self, input_size, hidden_size, output_size):
		super().__init__()
		# Two fully connected layers: 11 inputs -> 256 hidden -> 3 outputs
		self.linear1 = nn.Linear(input_size, hidden_size)
		self.linear2 = nn.Linear(hidden_size, output_size)

	def forward(self, x):
		# ReLU on the hidden layer, raw Q values on the output layer
		x = F.relu(self.linear1(x))
		x = self.linear2(x)
		return x

	def save(self, file_name='model_name.pth'):
		model_folder_path = 'Path'  # folder where the trained model is stored
		if not os.path.exists(model_folder_path):
			os.makedirs(model_folder_path)
		file_name = os.path.join(model_folder_path, file_name)
		torch.save(self.state_dict(), file_name)
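
As a quick sanity check, the model can be instantiated with the sizes discussed above and queried with a dummy state (illustrative values only):

model = Linear_QNet(11, 256, 3)     # 11 inputs -> 256 hidden -> 3 outputs

# Feed a dummy 11-value state and read off the predicted move
dummy_state = torch.tensor([0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0], dtype=torch.float)
q_values = model(dummy_state)       # tensor with 3 Q values
move = torch.argmax(q_values).item()
print(q_values, move)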

The second part

1. Initialize the QTrainer class and set the learning rate for the optimizer.

  • The gamma value is the discount rate used in the Bellman equation.
  • Initialize the Adam optimizer, which updates the weights and biases.
  • The criterion is the mean squared error (MSE) loss function.

2. The train_step function

  • As you know, PyTorch works on tensors, so we convert all the inputs to tensors.
  • As mentioned above, short-memory training passes in only a single state, action, reward, and next state, so we need to add a batch dimension to each of them; this is done with torch.unsqueeze.

  • Get the predicted Q values from the model for the current state and calculate the new Q value using the following formula:
Q_new = reward + gamma * max(next_predicted_Q_value)
  • Calculate the mean squared error between the new Q value and the previous Q value, and backpropagate that loss to update the weights.

import torch
import torch.nn as nn
import torch.optim as optim


class QTrainer:
	def __init__(self, model, lr, gamma):
		# Learning rate for the optimizer
		self.lr = lr
		# Discount rate
		self.gamma = gamma
		# Linear NN defined above
		self.model = model
		# Optimizer for weight and bias updates
		self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
		# Mean squared error loss function
		self.criterion = nn.MSELoss()

	def train_step(self, state, action, reward, next_state, done):
		state = torch.tensor(state, dtype=torch.float)
		next_state = torch.tensor(next_state, dtype=torch.float)
		action = torch.tensor(action, dtype=torch.long)
		reward = torch.tensor(reward, dtype=torch.float)

		# Only one sample to train on, hence convert to a batch of shape (1, x)
		if len(state.shape) == 1:
			state = torch.unsqueeze(state, 0)
			next_state = torch.unsqueeze(next_state, 0)
			action = torch.unsqueeze(action, 0)
			reward = torch.unsqueeze(reward, 0)
			done = (done, )

		# 1. Predicted Q values for the current state
		pred = self.model(state)

		# 2. Q_new = reward + gamma * max(next predicted Q value)
		target = pred.clone()
		for idx in range(len(done)):
			Q_new = reward[idx]
			if not done[idx]:
				Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))
			target[idx][torch.argmax(action[idx]).item()] = Q_new

		# 3. Compute the loss between target and prediction and backpropagate it
		self.optimizer.zero_grad()
		loss = self.criterion(target, pred)
		loss.backward()  # backward propagation of loss
		self.optimizer.step()
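
A quick usage sketch of the trainer on a single transition; the learning rate 0.001 and gamma 0.9 are illustrative values, not necessarily the settings used in the article's source code:

model = Linear_QNet(11, 256, 3)
trainer = QTrainer(model, lr=0.001, gamma=0.9)

# One (state, action, reward, next_state, done) transition
state = [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0]
action = [1, 0, 0]      # one-hot: keep going straight
reward = 0
next_state = [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0]
done = False
trainer.train_step(state, action, reward, next_state, done)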

The agent

  • Gets the current state of the snake from the environment.
# Method of the Agent class. Point, Direction and BLOCK_SIZE come from the
# game (environment) module; np is numpy.
def get_state(self, game):
	head = game.snake[0]
	point_l = Point(head.x - BLOCK_SIZE, head.y)
	point_r = Point(head.x + BLOCK_SIZE, head.y)
	point_u = Point(head.x, head.y - BLOCK_SIZE)
	point_d = Point(head.x, head.y + BLOCK_SIZE)

	dir_l = game.direction == Direction.LEFT
	dir_r = game.direction == Direction.RIGHT
	dir_u = game.direction == Direction.UP
	dir_d = game.direction == Direction.DOWN

	state = [
		# Danger straight ahead
		(dir_r and game.is_collision(point_r)) or
		(dir_l and game.is_collision(point_l)) or
		(dir_u and game.is_collision(point_u)) or
		(dir_d and game.is_collision(point_d)),

		# Danger to the right
		(dir_u and game.is_collision(point_r)) or
		(dir_d and game.is_collision(point_l)) or
		(dir_l and game.is_collision(point_u)) or
		(dir_r and game.is_collision(point_d)),

		# Danger to the left
		(dir_d and game.is_collision(point_r)) or
		(dir_u and game.is_collision(point_l)) or
		(dir_r and game.is_collision(point_u)) or
		(dir_l and game.is_collision(point_d)),

		# Move direction
		dir_l,
		dir_r,
		dir_u,
		dir_d,

		# Food location
		game.food.x < game.head.x,  # food is to the left
		game.food.x > game.head.x,  # food is to the right
		game.food.y < game.head.y,  # food is up
		game.food.y > game.head.y   # food is down
	]
	return np.array(state, dtype=int)
  • Call the model to get the snake's next action.
def get_action(self, state):
	# Tradeoff between exploration and exploitation:
	# random moves early on, model predictions as n_game grows
	self.epsilon = 80 - self.n_game
	final_move = [0, 0, 0]
	if random.randint(0, 200) < self.epsilon:
		# Exploration: pick a random move
		move = random.randint(0, 2)
		final_move[move] = 1
	else:
		# Exploitation: let the model predict the best move
		# (optionally move the model and tensors to the GPU for faster inference)
		state0 = torch.tensor(state, dtype=torch.float)
		prediction = self.model(state0)  # prediction by the model
		move = torch.argmax(prediction).item()
		final_move[move] = 1
	return final_move

Note: There is a tradeoff between exploitation and exploration. Exploitation means acting on what currently appears best based on the data observed so far. Exploration means making random decisions without taking previous action and reward pairs into account. Relying too heavily on exploitation can leave the agent unable to explore the whole environment, while exploration alone does not always yield a better reward.

  • Play the move predicted by the model in the environment.
  • Store the current state, the move performed, and the reward received (a sketch of this storage step follows the code below).
  • Train the model on the move just performed and the resulting reward from the environment (short-memory training).
def train_short_memory(self, state, action, reward, next_state, done):
	# Train on a single transition right after it happens
	self.trainer.train_step(state, action, reward, next_state, done)
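
The storage step mentioned above can be a simple replay memory based on a deque; here is a minimal sketch (the memory attribute, the remember method name, and the MAX_MEMORY value are assumptions for illustration):

from collections import deque

MAX_MEMORY = 100_000  # assumed capacity of the replay memory

# In Agent.__init__: self.memory = deque(maxlen=MAX_MEMORY)
# (old transitions are discarded automatically once the deque is full)

def remember(self, state, action, reward, next_state, done):
	# Store one transition for later sampling in train_long_memory
	self.memory.append((state, action, reward, next_state, done))
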
  • If the game ends (the snake hits a wall or its own body), train the model on all the moves performed so far and reset the environment (long-memory training). Training uses a batch size of 1000.
def train_long_memory(self):
	# Sample a random batch from replay memory (or use all of it if it is small)
	if len(self.memory) > BATCH_SIZE:
		mini_sample = random.sample(self.memory, BATCH_SIZE)
	else:
		mini_sample = self.memory
	states, actions, rewards, next_states, dones = zip(*mini_sample)
	self.trainer.train_step(states, actions, rewards, next_states, dones)

Training takes about 100 games for the model to reach noticeably better performance. Check out my training progress below.
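
The training-progress chart can be drawn with a small helper; here is a minimal sketch using matplotlib (the library choice and the function name plot_progress are assumptions, not the article's exact code):

import matplotlib.pyplot as plt

def plot_progress(scores, mean_scores):
	# Redraw the score curves after every game
	plt.clf()
	plt.title('Training progress')
	plt.xlabel('Number of games')
	plt.ylabel('Score')
	plt.plot(scores, label='score')
	plt.plot(mean_scores, label='mean score')
	plt.legend()
	plt.pause(0.1)  # short pause so the window refreshes during training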

Output:

  • To run this game, first create an environment (at the Anaconda prompt or on any other platform). Then install the necessary modules, such as PyTorch (for the deep Q-learning model), Pygame (for the game's visuals), and the other basic modules.
  • Then start training by running the agent.py file in the environment you just created. You will see two GUIs: one for the training progress and one for the AI-driven snake game.
  • After a certain score is reached, you can exit the game, and the newly trained model is saved to the path defined in the save function of models.py.

Later on, you can reuse this trained model simply by changing the code in the agent.py file, as follows:

self.model.load_state_dict(torch.load('PATH'))

Note: Comment out all training function calls.
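
For inference-only runs, here is a minimal sketch of loading the saved weights; 'PATH' is a placeholder for wherever the save function wrote the .pth file, and calling model.eval() is a standard PyTorch step rather than something specific to this project:

# Recreate the network with the same sizes, then load the trained weights
model = Linear_QNet(11, 256, 3)
model.load_state_dict(torch.load('PATH'))
model.eval()  # switch to evaluation mode: no training, predictions only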

Training progress:

[Chart: first-generation version]

[Chart: second-generation version]

Source code: SnakeGameAI

Application:

The goal of the project is to give an idea of how reinforcement learning can be applied and how it is used in real-world applications such as self-driving cars (e.g., AWS DeepRacer), training robots on assembly lines, and so on.

Tips:

  • Use a separate environment and install all the required modules there. (You can use an Anaconda environment.)
  • To train the model faster, you can use a GPU.

Quick summary: an AI-driven snake game using deep Q-learning

I hope this tutorial series has been helpful to you. I am still learning, so please point out anything I got wrong. If you enjoyed this article and want to see more like it, you can check out my GitHub/Gitee, which collects all my original articles and source code. Follow me for more.

More on this topic 🧵

  • Python exception handling | Python theme month
  • Python Multithreading tutorial | Python theme month
  • Python Socket programming essentials | Python theme month
  • 30 Python tutorials and tips | Python theme month
  • Python statements, expressions, and indentation | Python theme month
  • Python keywords, identifiers, and variables | Python theme month
  • How to write comments and multi-line comments in Python | Python Theme Month
  • Learn about Python numbers and type conversions with examples | Python theme month
  • Python data types — Basic to advanced learning | Python topic month
  • Object-oriented programming in Python – classes, Objects and Members | Python theme month

๐Ÿฐ Past excellent articles recommended:

  • 20 Python Tips Everyone must Know | Python Theme month
  • 100 Basic Python Interview Questions Part 1 (1-20) | Python topic month
  • 100 Basic Python Interview Questions Part 2 (21-40) | Python topic month
  • 100 Basic Python Interview Questions Part 3 (41-60) | Python topic month
  • 100 Basic Python Interview Questions Part 4 (61-80) | Python topic month
  • 100 Basic Python Interview Questions Part 5 (81-100) | Python topic month

If you learned something new from this post, like it, bookmark it, and share it with your friends. 🤗 Finally, don't forget to leave a ❤ or 📑 for support.