Project Summary

For this project, we want a bot that plays the Snake game and achieves a reasonably high score.

Without ML, we could use a rule-based approach with a hard-coded, fixed set of rules. For example, an algorithm could always pick the shortest path to the food. However, as the snake grows, such an algorithm cannot keep the snake from biting itself. Of course, we could program a new set of rules to handle that specific scenario, but even so, this shows that rule-based approaches lack adaptability and scalability. For a game with a large state space like Snake, we cannot write rules for every possible scenario.

Reinforcement learning (RL) overcomes these limitations by learning effective strategies directly from experience, without requiring explicit programming of game-specific heuristics. Unfortunately, RL is not a silver bullet for this problem; in fact, our RL implementation did not yield fully satisfying results. Nevertheless, taking all of the above factors into consideration, we believe that RL is the right step toward a better solution to the Snake game.

Approach

1. General Overview

For this project, we employ Reinforcement Learning (RL) to train an AI agent to play the Snake game effectively. Given the large state-space of the game (numerous possible configurations of the snake and fruit on the grid), we opt for policy-based, model-free RL methods. Specifically, we consider two algorithms: Advantage Actor-Critic (A2C) and Proximal Policy Optimization (PPO), both implemented via the Stable-Baselines3 library. These algorithms directly learn a policy (a mapping from states to actions) without modeling the environment’s dynamics, making them suitable for our scenario.

2. Setting Up the Snake Game Environment

We use Gymnasium to create a custom RL environment for the Snake game. The observation is a 3D grid encoding the positions of the snake and the fruit, and the action space consists of the four movement directions (Up, Down, Left, Right). In the default setup, the agent receives +1 for each fruit eaten and a penalty when the episode ends (the snake hits a wall or bites itself).

The Gymnasium environment exposes the standard interface: reset(), which starts a new episode and returns the initial observation, and step(), which applies an action and returns the next observation, the reward, and the termination flag.

This setup provides a clean interface for the RL algorithms to interact with the game.
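
To make this concrete, below is a minimal sketch of what such an environment could look like. The class name, the channel layout, the per-cell pixel upscaling (so that Stable-Baselines3's default CNN, which expects images of at least 36×36, accepts the observation), and the exact penalty value are illustrative assumptions rather than our actual implementation.

```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces

class SnakeEnv(gym.Env):
    """Sketch of a Snake environment: 4 discrete actions, image-like 3D grid observation."""

    _MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}  # up, down, left, right
    _CELL = 8  # pixels per grid cell (assumption: upscaled so the default CNN accepts it)

    def __init__(self, grid_size=8):
        super().__init__()
        self.grid_size = grid_size
        self.action_space = spaces.Discrete(4)
        side = grid_size * self._CELL
        # 3 channels: snake body, snake head, fruit.
        self.observation_space = spaces.Box(0, 255, (3, side, side), dtype=np.uint8)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.snake = [(self.grid_size // 2, self.grid_size // 2)]  # head is snake[0]
        self._place_fruit()
        return self._obs(), {}

    def step(self, action):
        dr, dc = self._MOVES[int(action)]
        head = (self.snake[0][0] + dr, self.snake[0][1] + dc)
        hit_wall = not (0 <= head[0] < self.grid_size and 0 <= head[1] < self.grid_size)
        if hit_wall or head in self.snake:
            # Terminal state: the exact penalty value (-1 here) is an assumption.
            return self._obs(), -1.0, True, False, {}
        self.snake.insert(0, head)
        if head == self.fruit:
            reward = 1.0        # +1 for eating a fruit
            self._place_fruit()
        else:
            reward = 0.0
            self.snake.pop()    # the snake only grows when it eats a fruit
        return self._obs(), reward, False, False, {}

    def _place_fruit(self):
        free = [(r, c) for r in range(self.grid_size) for c in range(self.grid_size)
                if (r, c) not in self.snake]
        self.fruit = free[self.np_random.integers(len(free))]

    def _obs(self):
        grid = np.zeros((3, self.grid_size, self.grid_size), dtype=np.uint8)
        for r, c in self.snake[1:]:
            grid[0, r, c] = 255                              # body
        grid[1, self.snake[0][0], self.snake[0][1]] = 255    # head
        grid[2, self.fruit[0], self.fruit[1]] = 255          # fruit
        # Upscale each cell to an 8x8 pixel block so the grid looks like an image to SB3.
        return np.kron(grid, np.ones((1, self._CELL, self._CELL), dtype=np.uint8))
```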

3. Training Process

3.1. PHASE 1

We train two separate models using A2C and PPO from Stable-Baselines3 to compare their performance. Both algorithms process the observation space (the 3D grid) as an image input, using a convolutional neural network (CNN) as a feature extractor to convert the raw grid into a feature vector representing the current state s. This vector is then fed into a policy network to select an action a. The policy network architecture follows the default CNN setup in Stable-Baselines3, as shown below:

Source: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html

Below, we explain how A2C and PPO work and how we apply them to the Snake game.

3.1.1. Training with A2C

A2C is an actor-critic method that combines policy-based and value-based RL. It uses two networks: an actor, which outputs the policy π(a|s; θ) (a probability distribution over actions), and a critic, which estimates the state-value function V(s).

The algorithm updates the actor using the advantage function, defined as

A(s, a) = Q(s, a) − V(s)

where Q(s, a) is the action-value function (approximated via the reward received and bootstrapped future rewards) and V(s) is the state-value function from the critic. The advantage measures how much better an action is compared to the average action in that state.

The actor’s policy is updated by maximizing the objective

J(θ) = E[ log π(a|s; θ) · A(s, a) ]

using gradient ascent. Meanwhile, the critic minimizes the loss

L(w) = E[ (r + γ · V(s') − V(s))^2 ]

where r is the reward, γ is the discount factor (0.99 by default), and s' is the next state.
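
To make the update concrete, here is a small, self-contained PyTorch illustration of these quantities for a single transition (s, a, r, s'). The numbers are made up, and this is not how Stable-Baselines3 organizes its code internally.

```python
import torch

gamma = 0.9
log_prob_a = torch.tensor(-1.2, requires_grad=True)   # log pi(a|s; theta) from the actor
value_s = torch.tensor(0.3, requires_grad=True)       # V(s) from the critic
value_next = torch.tensor(0.5)                        # V(s'), treated as a constant target
reward = 1.0                                          # e.g. the snake just ate a fruit

td_target = reward + gamma * value_next               # r + gamma * V(s')
advantage = (td_target - value_s).detach()            # A(s, a) ~ r + gamma * V(s') - V(s)

actor_loss = -(log_prob_a * advantage)                # minimizing this = gradient ascent on J(theta)
critic_loss = (td_target - value_s) ** 2              # squared TD error L(w)
```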

For the Snake game, A2C takes the 3D observation grid, extracts features with the CNN, and outputs action probabilities (e.g., 25% Up, 25% Down, etc.). We train for 10 million timesteps using Stable-Baselines3’s defaults of learning rate = 0.0007 and n_steps = 5 (steps per update), with our gamma set to 0.9. The default values are sourced from the library’s documentation, and we did not tune them further due to computational constraints.
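
For concreteness, a training call along these lines could look like the sketch below; SnakeEnv is the hypothetical environment sketched earlier, and only the hyper-parameter values come from the text above.

```python
from stable_baselines3 import A2C

env = SnakeEnv()                 # hypothetical environment from the earlier sketch
model = A2C(
    "CnnPolicy",                 # default CNN feature extractor + policy/value heads
    env,
    learning_rate=0.0007,        # SB3 default for A2C
    n_steps=5,                   # steps collected per update (SB3 default)
    gamma=0.9,                   # our discount factor (the SB3 default is 0.99)
    seed=42,
    verbose=1,
)
model.learn(total_timesteps=10_000_000)
model.save("a2c_snake")
```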

3.1.2. Training with PPO

PPO is a more stable policy gradient method that improves on A2C by constraining policy updates to avoid large, destabilizing changes. Like A2C, it uses an actor-critic framework but introduces a clipped objective function to limit how much the policy can change in each update.

The PPO objective is

J(θ) = E[ min( r(θ) · A(s, a), clip(r(θ), 1 − ε, 1 + ε) · A(s, a) ) ]

where r(θ) = π(a|s; θ) / π(a|s; θ_old) is the probability ratio between the new and the old policy, A(s, a) is the advantage (as in A2C), and ε is the clipping parameter (0.2 in our setup).

The clipping ensures that the policy doesn’t deviate too far from the previous version, improving training stability. The critic’s loss is the same as in A2C.
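
As an illustration, the clipped objective can be computed in a few lines of PyTorch; the numbers below are made up, and the snippet is not taken from Stable-Baselines3’s internals.

```python
import torch

eps = 0.2                                        # clip_range (epsilon)
log_prob_new = torch.tensor([-1.1, -0.7], requires_grad=True)   # log pi_new(a|s)
log_prob_old = torch.tensor([-1.3, -0.5])        # log pi_old(a|s), from the data-collecting policy
advantages = torch.tensor([0.8, -0.4])

ratio = torch.exp(log_prob_new - log_prob_old)   # r(theta) = pi_new / pi_old
unclipped = ratio * advantages
clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
policy_loss = -torch.min(unclipped, clipped).mean()   # maximize J(theta) <=> minimize its negative
```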

For the Snake game, PPO processes the observation grid similarly to A2C, using the same CNN feature extractor. We train for 10 million timesteps with Stable-Baselines3’s defaults of n_steps = 2048 and clip_range = 0.2, while setting gamma = 0.9 and learning rate = 0.00025. The defaults are well-documented and widely used, so we kept them unchanged.
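
Under the same assumptions (the hypothetical SnakeEnv sketch), the corresponding PPO training call differs only in the constructor arguments:

```python
from stable_baselines3 import PPO

model = PPO(
    "CnnPolicy",
    SnakeEnv(),
    learning_rate=0.00025,
    n_steps=2048,        # rollout length before each update (SB3 default)
    clip_range=0.2,      # PPO clipping parameter epsilon (SB3 default)
    gamma=0.9,
    seed=42,
    verbose=1,
)
model.learn(total_timesteps=10_000_000)
model.save("ppo_snake")
```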

3.1.3. Comparison and Limitations

3.2. PHASE 2

Using the same A2C algorithm, we train different models for different scenarios to see how they perform. However, we use our custom CNN architecture instead of the default one from SB3. We also made a few adjustments to our training environment due to technical limitations. Details regarding these adjustments can be found in the “Hyper-parameters and Reproducibility” section below.

Here are the scenarios that we trained our models in:

  1. Arena default: an empty 8x8 arena.
  2. Arena with an extra fruit (see the sketch after this list):
    • Every time a normal fruit spawns, there is a 50% chance of also spawning an extra fruit.
    • If no extra fruit is spawned, the reward for the normal fruit is +1.
    • If an extra fruit is spawned, the reward for the normal fruit is +0.2, and the reward for the extra fruit is +0.8.
    • If the normal fruit is eaten before the extra fruit, the extra fruit disappears.
  3. Arena with a rectangular obstacle in the middle.
    • For this scenario, we found that removing the penalty for the terminal state helps the model learn better.
    • As such, we trained this model using the default setup but with no penalty for the terminal state.
  4. Arena split in half by a wall with a hole in the middle.
    • As in the previous scenario, removing the penalty for the terminal state helps the model learn better, so we again trained with no terminal-state penalty.
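
As referenced in scenario 2 above, here is a small sketch of how that reward scheme could be expressed; the function names and the dictionary layout are hypothetical, and only the 50% spawn chance and the +1 / +0.2 / +0.8 rewards come from the list.

```python
import random

def spawn_fruit_rewards(rng: random.Random) -> dict:
    """Decide fruit rewards when a new normal fruit spawns (scenario 2 sketch)."""
    if rng.random() < 0.5:
        # An extra fruit spawns alongside the normal one; the +1 is split 0.2 / 0.8.
        return {"normal": 0.2, "extra": 0.8}
    return {"normal": 1.0}       # no extra fruit: the normal fruit is worth the full +1

def eat_normal_fruit(rewards: dict) -> float:
    """Eating the normal fruit first also removes any remaining extra fruit."""
    rewards.pop("extra", None)   # the extra fruit disappears
    return rewards.pop("normal")
```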

3.2.1. Custom CNN Architecture

For Phase 2, we developed a custom CNN architecture inspired by both ResNet and the Llama model architecture, incorporating modern deep learning techniques to improve performance. Our custom architecture features several key components:

  1. SwiGLU Activation Function: Instead of traditional ReLU activations, we implemented SwiGLU (Swish-Gated Linear Unit) as described in the “GLU Variants Improve Transformer” paper. SwiGLU combines the benefits of gating mechanisms with smooth activation functions, allowing for better gradient flow during training.

  2. Residual Connections: Following ResNet’s design philosophy, we incorporated residual connections that enable the network to learn identity mappings more easily, helping to address the vanishing gradient problem in deeper networks. These connections allow information to flow directly from earlier layers to later ones, facilitating both faster training and better performance.

  3. Enhanced Feature Extraction:

    • A larger initial convolutional layer with a 7×7 kernel to capture more spatial information
    • Multiple residual blocks to maintain and refine feature representations
    • Strategic max pooling to reduce spatial dimensions while preserving important features
    • A final SwiGLU-activated linear layer to produce the feature vector

The architecture processes the game state (represented as a grid) through these components to extract meaningful patterns that inform the agent’s policy and value functions.
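
To show how such an extractor plugs into Stable-Baselines3, below is a sketch written as a custom BaseFeaturesExtractor. The channel counts, the convolutional SwiGLU variant inside the residual blocks, and the adaptive pooling at the end are illustrative assumptions; only the 7×7 stem, the residual connections, the max pooling, and the final SwiGLU-activated linear layer mirror the description above. SnakeEnv is the hypothetical environment sketched earlier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from gymnasium import spaces
from stable_baselines3 import A2C
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class SwiGLU(nn.Module):
    """SwiGLU: SiLU(W1 x) * (W2 x), as in "GLU Variants Improve Transformer"."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.gate = nn.Linear(dim_in, dim_out)
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        return F.silu(self.gate(x)) * self.proj(x)


class ConvSwiGLU(nn.Module):
    """Convolutional SwiGLU used inside the residual blocks (an assumption of this sketch)."""

    def __init__(self, channels: int):
        super().__init__()
        self.gate = nn.Conv2d(channels, channels, 3, padding=1)
        self.proj = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return F.silu(self.gate(x)) * self.proj(x)


class ResidualBlock(nn.Module):
    """Residual block in the spirit of ResNet: the input is added back to the block output."""

    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(ConvSwiGLU(channels),
                                   nn.Conv2d(channels, channels, 3, padding=1))

    def forward(self, x):
        return x + self.block(x)


class SnakeCNN(BaseFeaturesExtractor):
    """7x7 stem, residual blocks, pooling, and a SwiGLU head producing the feature vector."""

    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        n_input = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input, 64, kernel_size=7, stride=2, padding=3),  # large 7x7 initial layer
            ResidualBlock(64),
            nn.MaxPool2d(2),             # reduce spatial dimensions
            ResidualBlock(64),
            nn.AdaptiveMaxPool2d(1),     # collapse the remaining spatial map (sketch-only choice)
            nn.Flatten(),
        )
        self.head = SwiGLU(64, features_dim)  # final SwiGLU-activated projection

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.head(self.cnn(observations))


# Plugging the extractor into SB3 (hyper-parameter values mirror Phase 1):
policy_kwargs = dict(features_extractor_class=SnakeCNN,
                     features_extractor_kwargs=dict(features_dim=256))
model = A2C("CnnPolicy", SnakeEnv(), policy_kwargs=policy_kwargs, gamma=0.9, seed=42, verbose=1)
```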

In our experiments, this custom architecture demonstrated significantly faster learning compared to the default SB3 implementation used in Phase 1. The model achieved approximately 7 points on average in just 2.5 million training steps, compared to approximately 5 points after 10 million steps with the original architecture.

This dramatic improvement in early training efficiency suggests that the residual connections help the network learn basic game patterns more quickly, while the SwiGLU activations may provide better gradient flow that facilitates more effective policy updates. While the final performance after extended training may converge to similar levels, the custom architecture’s ability to reach good performance with 75% fewer training steps represents a substantial improvement in computational efficiency.

4. Hyper-parameters and Reproducibility

4.1. PHASE 1

The models are trained using A2C and PPO with the default CNN architecture from Stable-Baselines3 (three convolutional layers followed by a fully connected layer). Our hyper-parameters for the A2C and PPO models are:

  • A2C: learning rate = 0.0007, n_steps = 5, gamma = 0.9
  • PPO: learning rate = 0.00025, n_steps = 2048, clip_range = 0.2, gamma = 0.9

We train each model for 10 million timesteps, which corresponds to roughly 10 million interactions with the environment (though episodes terminate earlier upon failure). The training data is generated on-the-fly by the Gymnasium environment, so no external dataset is required. For reproducibility, we set a random seed of 42 in both the environment and the algorithms.
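
For example, seeding both the library and the environment could look as follows; set_random_seed is Stable-Baselines3’s helper, and SnakeEnv is the hypothetical sketch from earlier.

```python
from stable_baselines3.common.utils import set_random_seed

# Seed Python, NumPy, and PyTorch RNGs, plus the environment, so runs are repeatable.
set_random_seed(42)
env = SnakeEnv()
env.reset(seed=42)
```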

4.2. PHASE 2

Similar to Phase 1, except that only A2C is used, the default SB3 feature extractor is replaced by the custom CNN architecture described above, and a few adjustments are made to the training environment to work around technical limitations.

Evaluation

1. PHASE 1

We evaluate performance by comparing the average score (total fruit eaten per episode) across 100 test episodes for each model after training.
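
A minimal sketch of such an evaluation loop, assuming the SnakeEnv sketch and reward scheme described earlier (a +1 reward corresponds to one fruit eaten):

```python
import numpy as np

def average_score(model, env, episodes=100):
    """Play full games with the trained model and return the mean score (fruits eaten)."""
    scores = []
    for _ in range(episodes):
        obs, _ = env.reset()
        done, score = False, 0
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _ = env.step(action)
            if reward == 1.0:        # in this reward scheme, +1 means one fruit was eaten
                score += 1
            done = terminated or truncated
        scores.append(score)
    return float(np.mean(scores))
```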

1.1. A2C model
1.2. PPO model
1.3. Observation
Figure 5. Mean episode length over 10 million training steps for the PPO model (grey) vs. the A2C model (green)

2. PHASE 2

We evaluate the performance of our models using the average score (total fruit eaten per episode) across 400 test episodes, as well as statistics recorded during the training process.

2.1. Arena default:
2.2. Arena with rectangular obstacle:
2.3. Arena split in half by a wall with a hole:
2.4. Arena with extra fruit:

Resources Used:

AI Tool Usage: