
Commit ad234a2
feat: add init evo rainbow
EdanToledo committed Jan 4, 2025
1 parent 0dcf410 commit ad234a2
Showing 7 changed files with 708 additions and 16 deletions.
1 change: 0 additions & 1 deletion requirements/requirements.txt
@@ -3,7 +3,6 @@ chex
 colorama
 craftax
 distrax @ git+https://github.com/google-deepmind/distrax # distrax release doesn't support jax > 0.4.13
-envpool
 flashbax @ git+https://github.com/instadeepai/flashbax
 flax
 gymnasium
6 changes: 3 additions & 3 deletions stoix/configs/arch/anakin.yaml
@@ -3,8 +3,8 @@ architecture_name: anakin
 # --- Training ---
 seed: 42 # RNG seed.
 update_batch_size: 1 # Number of vectorised gradient updates per device.
-total_num_envs: 1024 # Total number of vectorised environments across all devices and batched updates. Needs to be divisible by n_devices*update_batch_size.
-total_timesteps: 1e7 # Set the total environment steps.
+total_num_envs: 2 # Total number of vectorised environments across all devices and batched updates. Needs to be divisible by n_devices*update_batch_size.
+total_timesteps: 1e3 # Set the total environment steps.
 # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
 num_updates: ~ # Number of updates.

@@ -13,6 +13,6 @@ evaluation_greedy: False # Evaluate the policy greedily. If True the policy will
 # an action which corresponds to the greatest logit. If false, the policy will sample
 # from the logits.
 num_eval_episodes: 128 # Number of episodes to evaluate per evaluation.
-num_evaluation: 50 # Number of evenly spaced evaluations to perform during training.
+num_evaluation: 10 # Number of evenly spaced evaluations to perform during training.
 absolute_metric: True # Whether the absolute metric should be computed. For more details
 # on the absolute metric please see: https://arxiv.org/abs/2209.10485
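The divisibility note on total_num_envs is the core constraint of the Anakin layout: the environment batch is sharded first across devices and then across update batches. A minimal sketch of the resulting arithmetic, assuming the layout described in the config comments (the variable names here are illustrative, not Stoix internals):

```python
# Hypothetical illustration of the total_num_envs divisibility constraint.
import jax

n_devices = jax.device_count()
update_batch_size = 1  # from arch/anakin.yaml
total_num_envs = 2     # from arch/anakin.yaml (1024 before this commit)

assert total_num_envs % (n_devices * update_batch_size) == 0, (
    "total_num_envs must be divisible by n_devices * update_batch_size"
)
envs_per_update_batch = total_num_envs // (n_devices * update_batch_size)
print(f"Each of {n_devices} device(s) steps {envs_per_update_batch} env(s) per update batch.")
```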
4 changes: 2 additions & 2 deletions stoix/configs/network/mlp_noisy_dueling_c51.yaml
@@ -2,13 +2,13 @@
 actor_network:
   pre_torso:
     _target_: stoix.networks.torso.NoisyMLPTorso
-    layer_sizes: [256, 256]
+    layer_sizes: [16, 16]
     use_layer_norm: False
     activation: silu
     sigma_zero: ${system.sigma_zero}
   action_head:
     _target_: stoix.networks.dueling.NoisyDistributionalDuelingQNetwork
-    layer_sizes: [512]
+    layer_sizes: [16]
     use_layer_norm: False
     activation: silu
     vmin: ${system.vmin}
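Both the torso and the dueling head in this config draw their exploration noise from NoisyNet-style layers, with sigma_zero setting the initial noise scale. A minimal sketch of a factorised-Gaussian noisy linear layer in the spirit of Fortunato et al. (2017), written in Flax; this is illustrative only, not Stoix's NoisyMLPTorso implementation:

```python
# Sketch of a factorised-Gaussian noisy linear layer (NoisyNets),
# the mechanism that sigma_zero parameterises. Not Stoix code.
import jax
import jax.numpy as jnp
import flax.linen as nn


def scaled_noise(key, shape):
    # f(x) = sign(x) * sqrt(|x|), applied to standard Gaussian samples.
    x = jax.random.normal(key, shape)
    return jnp.sign(x) * jnp.sqrt(jnp.abs(x))


class NoisyLinear(nn.Module):
    features: int
    sigma_zero: float = 0.25  # matches system.sigma_zero in the config

    @nn.compact
    def __call__(self, x, noise_key):
        p = x.shape[-1]

        def mu_init(key, shape):
            bound = 1.0 / jnp.sqrt(p)
            return jax.random.uniform(key, shape, minval=-bound, maxval=bound)

        # Noise scales start at sigma_zero / sqrt(fan_in), per Fortunato et al.
        sigma_init = nn.initializers.constant(self.sigma_zero / jnp.sqrt(p))
        mu_w = self.param("mu_w", mu_init, (p, self.features))
        sigma_w = self.param("sigma_w", sigma_init, (p, self.features))
        mu_b = self.param("mu_b", mu_init, (self.features,))
        sigma_b = self.param("sigma_b", sigma_init, (self.features,))

        # Factorised noise: one noise vector per input, one per output.
        key_in, key_out = jax.random.split(noise_key)
        eps_in = scaled_noise(key_in, (p, 1))
        eps_out = scaled_noise(key_out, (1, self.features))

        w = mu_w + sigma_w * (eps_in @ eps_out)
        b = mu_b + sigma_b * eps_out.squeeze(0)
        return x @ w + b
```

Because exploration comes from these learned noise parameters, ff_rainbow.yaml below can set both training_epsilon and evaluation_epsilon to 0.0.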
8 changes: 4 additions & 4 deletions stoix/configs/system/q_learning/ff_rainbow.yaml
@@ -4,10 +4,10 @@ system_name: ff_rainbow # Name of the system.

 # --- RL hyperparameters ---
 rollout_length: 4 # Number of environment steps per vectorised environment.
-epochs: 128 # Number of sgd steps per rollout.
+epochs: 16 # Number of sgd steps per rollout.
 warmup_steps: 16 # Number of steps to collect before training.
-total_buffer_size: 1_000_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
-total_batch_size: 512 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
+total_buffer_size: 5_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
+total_batch_size: 32 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
 priority_exponent: 0.5 # exponent for the prioritised experience replay
 importance_sampling_exponent: 0.4 # exponent for the importance sampling weights
 n_step: 5 # how many steps in the transition to use for the n-step return

@@ -19,7 +19,7 @@ decay_learning_rates: False # Whether learning rates should be linearly decayed
 training_epsilon: 0.0 # epsilon for the epsilon-greedy policy during training
 evaluation_epsilon: 0.0 # epsilon for the epsilon-greedy policy during evaluation
 max_abs_reward: 1000.0 # maximum absolute reward value
-num_atoms: 51 # number of atoms in the distributional Q network
+num_atoms: 2 # number of atoms in the distributional Q network
 vmin: 0.0 # minimum value of the support
 vmax: 500.0 # maximum value of the support
 sigma_zero: 0.25 # initialization value for noisy variance terms
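Several of these knobs parameterise standard Rainbow components. A rough sketch of how num_atoms, vmin and vmax define the categorical (C51) value support, and how the two replay exponents shape prioritised sampling, under generic C51/PER assumptions with dummy data — not Stoix's implementation:

```python
import jax.numpy as jnp

# --- Distributional (C51) support ---
num_atoms, vmin, vmax = 51, 0.0, 500.0        # pre-change values from this file
atoms = jnp.linspace(vmin, vmax, num_atoms)   # fixed support z_1 .. z_N

probs = jnp.full((4, num_atoms), 1.0 / num_atoms)  # dummy (num_actions, num_atoms) head output
q_values = (probs * atoms).sum(axis=-1)            # Q(s, a) = sum_i p_i(s, a) * z_i
greedy_action = jnp.argmax(q_values)

# --- Prioritised replay weights ---
alpha, beta = 0.5, 0.4                  # priority_exponent, importance_sampling_exponent
td_errors = jnp.array([0.1, 2.0, 0.5])  # dummy absolute TD errors
priorities = jnp.abs(td_errors) ** alpha
sample_probs = priorities / priorities.sum()     # P(i)
n = sample_probs.shape[0]
is_weights = (1.0 / (n * sample_probs)) ** beta  # w_i = (1 / (N * P(i)))^beta
is_weights = is_weights / is_weights.max()       # normalise for stability
```

Note that dropping num_atoms from 51 to 2 collapses the value distribution to just the two endpoints of the support.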
12 changes: 6 additions & 6 deletions stoix/evaluator.py
@@ -344,15 +344,15 @@ def evaluator_setup(
         10,
     )

-    evaluator = jax.pmap(evaluator, axis_name="device")
-    absolute_metric_evaluator = jax.pmap(absolute_metric_evaluator, axis_name="device")
+    # evaluator = jax.pmap(evaluator, axis_name="device")
+    # absolute_metric_evaluator = jax.pmap(absolute_metric_evaluator, axis_name="device")

     # Broadcast trained params to cores and split keys for each core.
-    trained_params = unreplicate_batch_dim(params)
-    key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
-    eval_keys = jnp.stack(eval_keys).reshape(n_devices, -1)
+    # trained_params = unreplicate_batch_dim(params)
+    # key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
+    # eval_keys = jnp.stack(eval_keys).reshape(n_devices, -1)

-    return evaluator, absolute_metric_evaluator, (trained_params, eval_keys)
+    return evaluator, absolute_metric_evaluator


 def get_sebulba_eval_fn(
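With the jax.pmap wrappers and the params/key replication commented out, evaluator_setup now returns plain single-device callables and no longer hands back (trained_params, eval_keys), so any caller still unpacking three return values must be updated to match. A hedged sketch of the general pattern being toggled here — illustrative names, not the actual Stoix call site:

```python
# jax.pmap replicates a function across devices and requires a leading
# device axis on every argument; dropping it removes that requirement.
import jax
import jax.numpy as jnp

def evaluate(params, key):
    # stand-in for a rollout-based evaluation function
    return {"episode_return": params["w"].sum() + jax.random.normal(key)}

n_devices = jax.local_device_count()
params = {"w": jnp.ones((3,))}

# pmapped variant: params and keys carry a leading (n_devices, ...) axis.
pmapped_eval = jax.pmap(evaluate, axis_name="device")
rep_params = jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_devices), params)
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)
print(pmapped_eval(rep_params, keys))

# un-pmapped variant (what this commit switches to): plain single-device call.
print(evaluate(params, jax.random.PRNGKey(0)))
```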
[The remaining 2 of the 7 changed files were not rendered in this view.]
