Overview

The DOOM Neuron project uses Proximal Policy Optimization (PPO) to train biological neurons through reinforcement learning. The encoder learns to generate optimal stimulation patterns, while the decoder maps neural responses (spikes) to game actions.

PPO Algorithm

Core Concept

PPO is an on-policy reinforcement learning algorithm that:
  1. Collects experience rollouts by interacting with the environment
  2. Estimates advantages using Generalized Advantage Estimation (GAE)
  3. Updates the policy with a clipped surrogate objective
  4. Constrains updates to prevent catastrophic policy collapse

Training Loop (training_server.py)

for episode in range(max_episodes):
    # Phase 1: Collect rollout (2048 steps)
    for step in range(steps_per_update):
        # Encoder: obs → stimulation parameters
        frequencies, amplitudes, enc_logprob, enc_entropy = policy.sample_encoder(obs)
        
        # Apply stimulation to neurons
        policy.apply_stimulation(stim_socket, frequencies, amplitudes)
        
        # Collect neural responses
        spike_counts = policy.collect_spikes(spike_socket)
        
        # Decoder: spikes → actions
        actions, dec_logprob, dec_entropy = policy.decode_spikes_to_action(spike_counts)
        
        # Execute action in environment
        next_obs, reward, done, info = env.step(actions)
        
        # Store transition (value: the critic's estimate V(obs); pseudocode name)
        value = policy.value_estimate(obs)
        buffer.store(obs, actions, reward, dec_logprob, value, spike_counts, frequencies, amplitudes)
        obs = next_obs
    
    # Phase 2: Compute advantages using GAE
    advantages, returns = compute_gae_returns(buffer.rewards, buffer.values, gamma=0.997, gae_lambda=0.95)
    
    # Phase 3: PPO update (4 epochs)
    for epoch in range(num_epochs):
        for batch in minibatches(batch_size=256):
            # Re-evaluate actions under current policy
            logprobs, values, entropy, enc_logprob, enc_entropy = policy.evaluate_actions(
                batch.spike_features,
                batch.actions,
                batch.obs,
                batch.stim_frequencies,
                batch.stim_amplitudes
            )
            
            # PPO clipped loss
            policy_loss = compute_policy_loss(logprobs, batch.old_logprobs, batch.advantages)
            value_loss = compute_value_loss(values, batch.returns)
            
            # Entropy term: encoder_entropy_coef is negative, so subtracting
            # this term penalizes encoder entropy (see "Entropy Bonus" below)
            entropy_term = entropy_coef * entropy.mean() + encoder_entropy_coef * enc_entropy.mean()
            
            loss = policy_loss + value_loss_coef * value_loss - entropy_term
            
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
            optimizer.step()

Policy Gradients

Encoder Gradients

The encoder network learns to generate stimulation parameters via policy gradients through Beta distributions:
# From EncoderNetwork.sample() in ppo_doom.py:495-515
import math

import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta

class EncoderNetwork(nn.Module):
    def sample(self, obs, deterministic=False):
        # Compute Beta distribution parameters
        features = self._combine_features(obs)  # Process obs with CNN + MLP
        
        freq_alpha = F.softplus(self.freq_alpha_head(features)) + 1.0
        freq_beta = F.softplus(self.freq_beta_head(features)) + 1.0
        amp_alpha = F.softplus(self.amp_alpha_head(features)) + 1.0
        amp_beta = F.softplus(self.amp_beta_head(features)) + 1.0
        
        # Create Beta distributions
        freq_dist = Beta(freq_alpha, freq_beta)
        amp_dist = Beta(amp_alpha, amp_beta)
        
        # Sample and scale to physical ranges
        freq_u = freq_dist.rsample()  # Reparameterization trick
        amp_u = amp_dist.rsample()
        
        freq_range = max_frequency - min_frequency
        amp_range = max_amplitude - min_amplitude
        frequencies = min_frequency + freq_u * freq_range  # 4-40 Hz
        amplitudes = min_amplitude + amp_u * amp_range     # 1.0-2.5 μA
        
        # Log probabilities with the Jacobian correction for the affine rescaling:
        # x = min + u * range, so log p(x) = log p(u) - log(range)
        freq_log_prob = freq_dist.log_prob(freq_u) - math.log(freq_range)
        amp_log_prob = amp_dist.log_prob(amp_u) - math.log(amp_range)
        
        # Entropies of the underlying Beta distributions (used in the entropy term)
        freq_entropy = freq_dist.entropy()
        amp_entropy = amp_dist.entropy()
        
        return frequencies, amplitudes, freq_log_prob, amp_log_prob, freq_entropy, amp_entropy
The encoder uses Beta distributions instead of Gaussian because stimulation parameters are naturally bounded (frequency: 4-40 Hz, amplitude: 1.0-2.5 μA). Beta distributions provide better sample efficiency in bounded spaces.
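As a quick illustration (not project code; the distribution parameters are arbitrary), a clipped Gaussian piles probability mass exactly at the bounds, while a scaled Beta sample stays strictly inside them:
import torch
from torch.distributions import Beta, Normal

min_freq, max_freq = 4.0, 40.0

# Beta samples lie strictly in (0, 1); scaling keeps them inside (4, 40) Hz
beta_u = Beta(torch.tensor(2.0), torch.tensor(2.0)).sample((10_000,))
beta_freqs = min_freq + beta_u * (max_freq - min_freq)

# A Gaussian must be clipped, piling probability mass at the bounds
gauss_freqs = Normal(22.0, 15.0).sample((10_000,)).clamp(min_freq, max_freq)

print((gauss_freqs == min_freq).float().mean())  # nonzero mass stuck at 4 Hz
print((beta_freqs <= min_freq).float().mean())   # ~0: Beta never hits the bound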

Decoder Gradients

The decoder receives gradients through its action logits:
# From DecoderNetwork.forward() in training_server.py:627-629
class DecoderNetwork(nn.Module):
    def forward(self, spike_features):
        # Linear readout: spikes → action logits
        head_input = self.shared(spike_features) if self.use_mlp else spike_features
        joint_logits = self.joint_head(head_input)  # (batch, 54) for combinatorial actions
        return joint_logits
Action probabilities flow through the decoder:
spike_counts → decoder → action_logits → Categorical → sampled_action
                   ↑                           ↓
                   └───── policy gradient ──────┘
By setting decoder_zero_bias=True (default), the decoder is forced to rely entirely on neural spike activity rather than learning a bias-driven policy. This ensures biological neurons are genuinely controlling the agent.
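The sketch below (hypothetical channel count and placeholder advantages, not the project's exact classes) shows how the policy gradient reaches a bias-free linear readout through Categorical.log_prob:
import torch
from torch.distributions import Categorical

n_channels, n_actions = 32, 54
decoder = torch.nn.Linear(n_channels, n_actions, bias=False)  # bias-free linear readout

spike_counts = torch.rand(8, n_channels)  # hypothetical batch of spike features
dist = Categorical(logits=decoder(spike_counts))
actions = dist.sample()

# The (unclipped) policy-gradient surrogate: -log π(a|s) · A; backprop
# flows through log_prob into the readout weights
advantages = torch.randn(8)  # placeholder advantages
loss = -(dist.log_prob(actions) * advantages).mean()
loss.backward()
print(decoder.weight.grad.shape)  # torch.Size([54, 32])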

Generalized Advantage Estimation (GAE)

Advantage Computation

GAE computes advantage estimates that balance bias and variance:
# Pseudocode from PPO implementation
def compute_gae(rewards, values, gamma=0.997, gae_lambda=0.95):
    """
    Compute GAE advantages and returns.
    
    Args:
        rewards: List of rewards [r_0, r_1, ..., r_{T-1}]
        values: Value estimates [V(s_0), ..., V(s_T)], one longer than
            rewards; the last entry is the bootstrap value (0 if terminal)
        gamma: Discount factor (0.997)
        gae_lambda: GAE lambda parameter (0.95)
    """
    advantages = []
    gae = 0
    
    # Backward pass through episode
    for t in reversed(range(len(rewards))):
        # TD error: δ_t = r_t + γ V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t+1] - values[t]
        
        # GAE: A_t = δ_t + (γλ) δ_{t+1} + (γλ)² δ_{t+2} + ...
        gae = delta + gamma * gae_lambda * gae
        advantages.insert(0, gae)
    
    # Returns: R_t = A_t + V(s_t)
    returns = [adv + val for adv, val in zip(advantages, values)]
    
    return advantages, returns

Hyperparameters

  • γ (gamma) = 0.997: Discount factor for long-term rewards
    • Higher than typical 0.99 to encourage survival in DOOM
    • Balances immediate kills vs. staying alive
  • λ (lambda) = 0.95: GAE smoothing parameter
    • Controls bias-variance tradeoff
    • λ=0: high bias, low variance (TD learning)
    • λ=1: low bias, high variance (Monte Carlo)
Why GAE?

Standard advantage estimation faces a dilemma:
  • TD errors (1-step): Low variance but biased by value estimates
  • Monte Carlo returns (full episode): Unbiased but high variance
GAE creates a smooth interpolation between these extremes. The k-step advantage estimators are:
A^(1) = δ_t                                  # 1-step (high bias, low variance)
A^(2) = δ_t + γ δ_{t+1}                      # 2-step
A^(3) = δ_t + γ δ_{t+1} + γ² δ_{t+2}         # 3-step
...
and GAE(γ, λ) is their exponentially weighted average:
A_t = Σ_k (γλ)^k δ_{t+k}                     # λ → 1 recovers the Monte Carlo return (low bias)
λ = 0.95 provides a sweet spot for DOOM, where:
  • Value function learns to predict episode outcomes
  • Advantages capture meaningful state-action quality differences
  • Updates remain stable despite sparse rewards
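As a sanity check that the recursion and the sum define the same quantity, the snippet below (illustrative only, with synthetic TD errors) compares the backward recursion from compute_gae against the direct sum Σ (γλ)^k δ_{t+k}:
import numpy as np

gamma, lam = 0.997, 0.95
rng = np.random.default_rng(0)
deltas = rng.normal(size=16)  # synthetic TD errors δ_t

# Backward recursion, as in compute_gae above
gae, recursive = 0.0, []
for d in reversed(deltas):
    gae = d + gamma * lam * gae
    recursive.insert(0, gae)

# Direct sum: A_t = Σ_k (γλ)^k δ_{t+k}
direct = [sum((gamma * lam) ** k * deltas[t + k] for k in range(len(deltas) - t))
          for t in range(len(deltas))]

assert np.allclose(recursive, direct)  # both formulations agree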

Loss Functions

Policy Loss (Clipped Surrogate Objective)

def compute_policy_loss(logprobs, old_logprobs, advantages, clip_epsilon=0.2):
    """
    PPO clipped surrogate objective.
    
    Args:
        logprobs: Log probabilities under current policy π_θ
        old_logprobs: Log probabilities under old policy π_{θ_old}
        advantages: GAE advantage estimates
        clip_epsilon: Clipping range (0.2)
    """
    # Importance sampling ratio: r_t(θ) = π_θ(a_t|s_t) / π_{θ_old}(a_t|s_t)
    ratio = torch.exp(logprobs - old_logprobs)
    
    # Unclipped objective: r_t(θ) * A_t
    unclipped_obj = ratio * advantages
    
    # Clipped objective: clip(r_t(θ), 1-ε, 1+ε) * A_t  
    clipped_obj = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    
    # Take minimum (pessimistic bound)
    policy_loss = -torch.min(unclipped_obj, clipped_obj).mean()
    
    return policy_loss
The clipping mechanism prevents large policy updates that could break the neural interface. With biological neurons, catastrophic policy changes could overstimulate channels and require retraining from scratch.
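A toy numeric example (values chosen purely for illustration) makes the mechanism concrete: with ε = 0.2 and a positive advantage, any ratio above 1.2 contributes exactly 1.2 · A to the objective, so the incentive to push the policy further vanishes:
import torch

clip_epsilon = 0.2
ratio = torch.tensor([0.5, 0.9, 1.1, 1.5])  # illustrative importance ratios
advantage = torch.tensor(1.0)               # a positive advantage

unclipped = ratio * advantage
clipped = ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon) * advantage
objective = torch.min(unclipped, clipped)
print(objective)  # tensor([0.5000, 0.9000, 1.1000, 1.2000]), capped at 1 + ε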

Value Loss (MSE with optional normalization)

def compute_value_loss(values, returns, normalize_returns=False):
    """
    Mean squared error between value predictions and returns.
    
    Args:
        values: V(s) predictions from critic
        returns: GAE returns (advantages + values)
        normalize_returns: Whether to normalize returns (default: False)
    """
    if normalize_returns:
        # Normalize for stable training
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    
    value_loss = F.mse_loss(values, returns)
    return value_loss
Configuration: normalize_returns=False (disabled per Doom Initial Report)

Entropy Bonus

Entropy regularization encourages exploration:
# From PPOConfig
entropy_coef = 0.02               # Decoder entropy bonus
encoder_entropy_coef = -0.10      # Encoder entropy penalty

# Total entropy term
entropy_term = (
    entropy_coef * decoder_entropy.mean() +  # Encourage decoder exploration
    encoder_entropy_coef * encoder_entropy.mean()  # Penalize encoder entropy
)

total_loss = policy_loss + value_loss_coef * value_loss - entropy_term
The negative encoder entropy coefficient (-0.10) penalizes high entropy in stimulation parameters. This encourages the encoder to be more deterministic once it finds effective stimulation patterns, reducing noise in the neural interface.

Reward Shaping

Base Rewards from VizDoom

# From VizDoomEnv configuration
game.set_kill_reward(200)  # +200 for each enemy killed

Event-Based Rewards

The environment tracks game events and computes shaped rewards:
# From VizDoomEnv.step() reward computation
reward_components = {
    'kill_reward': info['event_enemy_kill'] * 200,
    'armor_bonus': info['event_armor_pickup'] * 50,
    'damage_penalty': info['event_took_damage'] * (-10),
    'ammo_waste_penalty': info['event_ammo_waste'] * (-5),
    'approach_bonus': info['event_move_closer'] * 1.0,
    'retreat_penalty': info['event_move_farther'] * (-0.5),
}

total_reward = sum(reward_components.values())

Simplified Reward Mode

With simplified_reward=True (default), only core events contribute:
if config.simplified_reward:
    # Focus on kills and survival
    reward = (
        kill_reward +
        damage_penalty +  
        armor_bonus
    )
else:
    # Full shaped rewards including aim alignment and velocity
    reward = (
        kill_reward +
        damage_penalty +
        armor_bonus +
        aim_alignment_reward +  # Bonus for facing enemies
        velocity_reward +        # Encourage movement
        ammo_waste_penalty +
        approach_bonus +
        retreat_penalty
    )
Simplified reward was enabled based on the “Doom Initial Report” findings. The full shaped reward system can provide denser learning signals but may introduce confounding factors when analyzing biological learning.

Training Hyperparameters

PPO Configuration

# From PPOConfig in training_server.py:120-128
learning_rate = 3e-4          # Adam optimizer learning rate
gamma = 0.997                 # Discount factor (higher for survival)
gae_lambda = 0.95            # GAE lambda parameter
clip_epsilon = 0.2           # PPO clipping range
value_loss_coef = 0.3        # Value loss weight
entropy_coef = 0.02          # Decoder entropy bonus
max_grad_norm = 3.0          # Gradient clipping threshold
normalize_returns = False    # Return normalization disabled

Rollout Parameters

num_envs = 1                 # Single environment (no parallelization)
steps_per_update = 2048      # Steps collected before each update
batch_size = 256             # Minibatch size for SGD
num_epochs = 4               # PPO epochs per rollout
max_episodes = 2000          # Total training episodes
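At these settings, each 2048-step rollout splits into 2048 / 256 = 8 minibatches, so the 4 PPO epochs perform 32 gradient steps per policy update.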

Network Architecture

hidden_size = 128                   # MLP hidden layer size
encoder_cnn_channels = 64          # CNN base channels (increased from 16)
encoder_trainable = True           # Learn encoder via policy gradients
encoder_entropy_coef = -0.10       # Penalize encoder entropy

decoder_use_mlp = False            # Use linear readout (not MLP)
decoder_zero_bias = True           # Zero decoder biases
decoder_enforce_nonnegative = False  # Allow negative weights
decoder_mlp_hidden = 256           # MLP hidden size (if enabled)

Gradient Clipping

# Prevent gradient explosion
nn.utils.clip_grad_norm_(policy.parameters(), max_norm=max_grad_norm)
  • max_grad_norm = 3.0: Conservative clipping to protect neural interface
  • Biological neurons may be sensitive to extreme stimulation changes
  • Lower values (0.5-1.0) tested but found too restrictive

Combinatorial Action Space

The decoder outputs logits over 54 discrete joint actions:
# From PPOPolicy.__init__() in training_server.py:752-769
forward_options = ['none', 'forward', 'backward']    # 3 options
strafe_options = ['none', 'left', 'right']           # 3 options  
camera_options = ['none', 'turn_left', 'turn_right'] # 3 options
attack_options = ['idle', 'attack']                  # 2 options
speed_options = ['off']                               # 1 option (removed)

# Total: 3 × 3 × 3 × 2 × 1 = 54 actions
combinatorial_action_defs = []
for fwd in forward_options:
    for strafe in strafe_options:
        for turn in camera_options:
            for attack in attack_options:
                # speed is fixed to 'off', so it is not enumerated
                combinatorial_action_defs.append({
                    'forward': fwd, 'strafe': strafe,
                    'turn': turn, 'attack': attack
                })
Earlier versions used factored action spaces (separate distributions for movement/turning/attack), but this caused issues:
  1. Independence assumption violated: Movement and attack should be coordinated
  2. Credit assignment difficulty: Hard to tell which action component caused reward
  3. Exploration challenges: Random independent actions rarely produce coherent behavior
The combinatorial action space:
  • Models realistic action combinations (e.g., “strafe left while shooting”)
  • Simplifies credit assignment to single action selection
  • Reduces decoder to single softmax head (54 logits)
  • Allows natural exploration through single categorical distribution
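Given that enumeration order (forward outermost, attack innermost), a sampled joint index can be decoded back into its components with mixed-radix arithmetic. The helper below is a sketch assuming the loop nesting shown above:
def decode_joint_action(index: int) -> dict:
    """Invert the nested-loop enumeration (forward outermost, attack innermost)."""
    index, attack_i = divmod(index, 2)
    index, turn_i = divmod(index, 3)
    fwd_i, strafe_i = divmod(index, 3)
    return {
        'forward': ['none', 'forward', 'backward'][fwd_i],
        'strafe': ['none', 'left', 'right'][strafe_i],
        'turn': ['none', 'turn_left', 'turn_right'][turn_i],
        'attack': ['idle', 'attack'][attack_i],
    }

print(decode_joint_action(0))   # everything 'none'/'idle'
print(decode_joint_action(53))  # backward + strafe right + turn_right + attack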

Training Dynamics

Exploration Strategy

# Sample from action distribution during training
joint_dist = Categorical(logits=joint_logits)
joint_action = joint_dist.sample()  # Stochastic sampling

# Deterministic evaluation
joint_action = joint_logits.argmax(dim=-1)  # Greedy selection
Exploration is managed by:
  1. Entropy bonus (0.02): Encourages trying diverse actions
  2. Stochastic sampling: Naturally explores via softmax distribution
  3. Encoder exploration: Beta distributions add stimulation variability

Value Function Bootstrap

The value network learns to predict episode returns:
# Value targets from GAE
returns = advantages + values  # R_t = A_t + V(s_t)

# Value loss (mean squared error)
value_loss = (value_predictions - returns).pow(2).mean()
This helps with:
  • Sparse rewards: Predict long-term value even without immediate reward
  • Credit assignment: Understand which states lead to eventual success
  • Advantage estimation: Reduce variance in policy gradients

Biological Learning Considerations

Encoder Adaptation

The encoder must learn stimulation patterns that:
  1. Evoke informative spike patterns from biological neurons
  2. Differentiate game states through neural activity
  3. Remain within safe stimulation bounds (1-2.5 μA, 4-40 Hz)
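Beta samples are bounded by construction, but a defensive clamp immediately before stimulation is cheap insurance against upstream bugs. The helper below is a hypothetical sketch, with bounds taken from the configuration above:
import torch

# Safety bounds from the stimulation configuration above
MIN_FREQ_HZ, MAX_FREQ_HZ = 4.0, 40.0
MIN_AMP_UA, MAX_AMP_UA = 1.0, 2.5

def clamp_stimulation(frequencies: torch.Tensor, amplitudes: torch.Tensor):
    """Defensively clip stimulation parameters to hardware-safe ranges."""
    return (frequencies.clamp(MIN_FREQ_HZ, MAX_FREQ_HZ),
            amplitudes.clamp(MIN_AMP_UA, MAX_AMP_UA))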

Decoder Constraints

# Linear readout without bias
if decoder_zero_bias:
    decoder.bias.data.zero_()
    decoder.bias.requires_grad = False

# Optional non-negative weights
if decoder_enforce_nonnegative:
    weights = F.softplus(decoder.weight)  # Ensure w ≥ 0
These constraints ensure:
  • Actions depend on spike counts, not learned offsets
  • Decoder interpretability (positive weights = excitatory)
  • Biological plausibility of the read-out mechanism
Monitor the ratio of decoder weight magnitude to bias magnitude during training (relevant when decoder_zero_bias is disabled and the bias is trainable). If the bias dominates, the biological neurons aren't contributing meaningfully to decisions.
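One way to compute that diagnostic (a hypothetical helper, applicable when the decoder is a torch.nn.Linear whose bias is left trainable):
import torch

def weight_bias_ratio(decoder: torch.nn.Linear) -> float:
    """L2 magnitude of spike-driven weights relative to the bias."""
    weight_norm = decoder.weight.norm().item()
    bias_norm = decoder.bias.norm().item() if decoder.bias is not None else 0.0
    # Ratio >> 1: decisions driven by spikes; near 0: bias dominates
    return weight_norm / (bias_norm + 1e-8)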

Logging and Metrics

Key metrics tracked during training:
metrics = {
    'Policy/loss': policy_loss.item(),
    'Policy/entropy': entropy.mean().item(),
    'Value/loss': value_loss.item(),
    'Value/mean': values.mean().item(),
    'Encoder/entropy': encoder_entropy.mean().item(),
    'Encoder/freq_mean': frequencies.mean().item(),
    'Encoder/amp_mean': amplitudes.mean().item(),
    'Decoder/weight_l2': decoder.weight.pow(2).sum().item(),
    'Episode/reward': episode_reward,
    'Episode/length': episode_length,
    'Episode/kills': kill_count,
}
These are logged to TensorBoard for monitoring convergence and diagnosing issues.
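A minimal sketch of how such a dictionary maps onto TensorBoard scalars, using the standard torch.utils.tensorboard API (the log directory name is illustrative):
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/doom_neuron')  # illustrative path

def log_metrics(metrics: dict, global_step: int) -> None:
    """Write each metric under its namespaced tag, e.g. 'Policy/loss'."""
    for tag, value in metrics.items():
        writer.add_scalar(tag, value, global_step)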