
Commit

feat: change ppo systems
EdanToledo committed Jan 2, 2025
1 parent 58d49c8 commit 1196049
Showing 6 changed files with 125 additions and 86 deletions.
42 changes: 25 additions & 17 deletions stoix/systems/ppo/anakin/ff_dpo_continuous.py
@@ -163,6 +163,7 @@ def _actor_loss_fn(
"actor_loss": loss_actor,
"entropy": entropy,
}

return total_loss_actor, loss_info

def _critic_loss_fn(
@@ -316,10 +317,8 @@ def learner_setup(
env: Environment, keys: chex.Array, config: DictConfig
) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]:
"""Initialise learner_fn, network, optimiser, environment and states."""
# Get available TPU cores.
n_devices = len(jax.local_devices())

# Get number of actions.

# Get number/dimension of actions.
num_actions = int(env.action_spec().shape[-1])
config.system.action_dim = num_actions
config.system.action_minimum = float(env.action_spec().minimum)
@@ -386,13 +385,13 @@ def learner_setup(

# Initialise environment states and timesteps: across devices and batches.
key, *env_keys = jax.random.split(
key, n_devices * config.arch.update_batch_size * config.arch.num_local_envs + 1
key, config.arch.num_local_devices * config.arch.update_batch_size * config.arch.num_local_envs + 1
)
env_states, timesteps = jax.vmap(env.reset, in_axes=(0))(
jnp.stack(env_keys),
)
reshape_states = lambda x: x.reshape(
(n_devices, config.arch.update_batch_size, config.arch.num_local_envs) + x.shape[1:]
(config.arch.num_local_devices, config.arch.update_batch_size, config.arch.num_local_envs) + x.shape[1:]
)
# (devices, update batch size, num_envs, ...)
env_states = jax.tree_util.tree_map(reshape_states, env_states)
@@ -411,8 +410,8 @@

# Define params to be replicated across devices and batches.
key, step_key = jax.random.split(key)
step_keys = jax.random.split(step_key, n_devices * config.arch.update_batch_size)
reshape_keys = lambda x: x.reshape((n_devices, config.arch.update_batch_size) + x.shape[1:])
step_keys = jax.random.split(step_key, config.arch.num_local_devices * config.arch.update_batch_size)
reshape_keys = lambda x: x.reshape((config.arch.num_local_devices, config.arch.update_batch_size) + x.shape[1:])
step_keys = reshape_keys(jnp.stack(step_keys))
opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states)
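The hunk above swaps the locally computed n_devices for the new config.arch.num_local_devices field when splitting and reshaping the step keys. A minimal standalone sketch of that split-and-reshape pattern follows; the sizes and variable names are illustrative placeholders, not Stoix's config:

import jax
import jax.numpy as jnp

num_local_devices = jax.local_device_count()
update_batch_size = 2  # placeholder for config.arch.update_batch_size

key = jax.random.PRNGKey(0)
step_keys = jax.random.split(key, num_local_devices * update_batch_size)
step_keys = jnp.stack(step_keys)  # shape: (num_local_devices * update_batch_size, 2)
# Leading axes line up with the (devices, update batch) nesting expected by pmap/vmap.
step_keys = step_keys.reshape(
    (num_local_devices, update_batch_size) + step_keys.shape[1:]
)
print(step_keys.shape)  # e.g. (1, 2, 2) on a single-device host

Each process then pmaps its learner over its own local devices, which is why only the local count appears here.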
@@ -434,10 +433,20 @@ def run_experiment(_config: DictConfig) -> float:
def run_experiment(_config: DictConfig) -> float:
"""Runs experiment."""
config = copy.deepcopy(_config)

# Get device and host information
config.arch.num_global_devices = jax.device_count()
config.arch.num_local_devices = jax.local_device_count()
config.arch.num_processes = jax.process_count()
config.arch.process_id = jax.process_index()
if jax.device_count() == jax.local_device_count():
print(f"{Fore.CYAN}{Style.BRIGHT}Running a single-host experiment with {jax.device_count()} devices.{Style.RESET_ALL}")
config.arch.is_multihost = False
else:
print(f"{Fore.CYAN}{Style.BRIGHT}Running a multi-host experiment with {jax.device_count()} devices on {jax.host_count()} hosts ({jax.local_device_count()} devices per host).{Style.RESET_ALL}")
config.arch.is_multihost = True

# Calculate total timesteps.
n_devices = len(jax.local_devices())
config.num_devices = n_devices
config = check_total_timesteps(config)
assert (
config.arch.num_updates >= config.arch.num_evaluation
@@ -448,7 +457,7 @@ def run_experiment(_config: DictConfig) -> float:

# PRNG keys.
key, key_e, actor_net_key, critic_net_key = jax.random.split(
jax.random.PRNGKey(config.arch.seed), num=4
jax.random.PRNGKey(config.arch.seed+config.arch.process_id), num=4
)

# Setup learner.
Expand All @@ -468,7 +477,7 @@ def run_experiment(_config: DictConfig) -> float:
# Calculate number of updates per evaluation.
config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation
steps_per_rollout = (
n_devices
config.arch.num_global_devices
* config.arch.num_updates_per_eval
* config.system.rollout_length
* config.arch.update_batch_size
@@ -496,7 +505,6 @@ def run_experiment(_config: DictConfig) -> float:
for eval_step in range(config.arch.num_evaluation):
# Train.
start_time = time.time()

learner_output = learn(learner_state)
jax.block_until_ready(learner_output)

@@ -525,9 +533,9 @@ def run_experiment(_config: DictConfig) -> float:
trained_params = unreplicate_batch_dim(
learner_output.learner_state.params.actor_params
) # Select only actor params
key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
key_e, *eval_keys = jax.random.split(key_e, config.arch.num_local_devices + 1)
eval_keys = jnp.stack(eval_keys)
eval_keys = eval_keys.reshape(n_devices, -1)
eval_keys = eval_keys.reshape(config.arch.num_local_devices, -1)

# Evaluate.
evaluator_output = evaluator(trained_params, eval_keys)
@@ -560,9 +568,9 @@ def run_experiment(_config: DictConfig) -> float:
if config.arch.absolute_metric:
start_time = time.time()

key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
key_e, *eval_keys = jax.random.split(key_e, config.arch.num_local_devices + 1)
eval_keys = jnp.stack(eval_keys)
eval_keys = eval_keys.reshape(n_devices, -1)
eval_keys = eval_keys.reshape(config.arch.num_local_devices, -1)

evaluator_output = absolute_metric_evaluator(best_params, eval_keys)
jax.block_until_ready(evaluator_output)
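The main change to run_experiment in this file (and in the other PPO systems below) replaces the single n_devices = len(jax.local_devices()) count with explicit global/local device and process bookkeeping. A minimal standalone sketch of that pattern, assuming a plain script rather than the Stoix config object:

import jax

num_global_devices = jax.device_count()        # devices across all processes
num_local_devices = jax.local_device_count()   # devices visible to this process
num_processes = jax.process_count()
process_id = jax.process_index()

# Single-host exactly when every device in the system is local to this process.
is_multihost = num_global_devices != num_local_devices
mode = "multi-host" if is_multihost else "single-host"
print(f"Running a {mode} experiment with {num_global_devices} devices "
      f"({num_local_devices} per process across {num_processes} process(es)).")

The split matters because per-process work (key splitting, environment resets, pmap) should use the local count, while global quantities such as logged environment steps should use the global count.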
1 change: 0 additions & 1 deletion stoix/systems/ppo/anakin/ff_ppo.py
@@ -494,7 +494,6 @@ def run_experiment(_config: DictConfig) -> float:
for eval_step in range(config.arch.num_evaluation):
# Train.
start_time = time.time()

learner_output = learn(learner_state)
jax.block_until_ready(learner_output)

42 changes: 25 additions & 17 deletions stoix/systems/ppo/anakin/ff_ppo_continuous.py
@@ -317,10 +317,8 @@ def learner_setup(
env: Environment, keys: chex.Array, config: DictConfig
) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]:
"""Initialise learner_fn, network, optimiser, environment and states."""
# Get available TPU cores.
n_devices = len(jax.local_devices())

# Get number of actions.

# Get number/dimension of actions.
num_actions = int(env.action_spec().shape[-1])
config.system.action_dim = num_actions
config.system.action_minimum = float(env.action_spec().minimum)
@@ -387,13 +385,13 @@ def learner_setup(

# Initialise environment states and timesteps: across devices and batches.
key, *env_keys = jax.random.split(
key, n_devices * config.arch.update_batch_size * config.arch.num_local_envs + 1
key, config.arch.num_local_devices * config.arch.update_batch_size * config.arch.num_local_envs + 1
)
env_states, timesteps = jax.vmap(env.reset, in_axes=(0))(
jnp.stack(env_keys),
)
reshape_states = lambda x: x.reshape(
(n_devices, config.arch.update_batch_size, config.arch.num_local_envs) + x.shape[1:]
(config.arch.num_local_devices, config.arch.update_batch_size, config.arch.num_local_envs) + x.shape[1:]
)
# (devices, update batch size, num_envs, ...)
env_states = jax.tree_util.tree_map(reshape_states, env_states)
@@ -412,8 +410,8 @@

# Define params to be replicated across devices and batches.
key, step_key = jax.random.split(key)
step_keys = jax.random.split(step_key, n_devices * config.arch.update_batch_size)
reshape_keys = lambda x: x.reshape((n_devices, config.arch.update_batch_size) + x.shape[1:])
step_keys = jax.random.split(step_key, config.arch.num_local_devices * config.arch.update_batch_size)
reshape_keys = lambda x: x.reshape((config.arch.num_local_devices, config.arch.update_batch_size) + x.shape[1:])
step_keys = reshape_keys(jnp.stack(step_keys))
opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states)
@@ -435,10 +433,20 @@ def run_experiment(_config: DictConfig) -> float:
def run_experiment(_config: DictConfig) -> float:
"""Runs experiment."""
config = copy.deepcopy(_config)

# Get device and host information
config.arch.num_global_devices = jax.device_count()
config.arch.num_local_devices = jax.local_device_count()
config.arch.num_processes = jax.process_count()
config.arch.process_id = jax.process_index()
if jax.device_count() == jax.local_device_count():
print(f"{Fore.CYAN}{Style.BRIGHT}Running a single-host experiment with {jax.device_count()} devices.{Style.RESET_ALL}")
config.arch.is_multihost = False
else:
print(f"{Fore.CYAN}{Style.BRIGHT}Running a multi-host experiment with {jax.device_count()} devices on {jax.host_count()} hosts ({jax.local_device_count()} devices per host).{Style.RESET_ALL}")
config.arch.is_multihost = True

# Calculate total timesteps.
n_devices = len(jax.local_devices())
config.num_devices = n_devices
config = check_total_timesteps(config)
assert (
config.arch.num_updates >= config.arch.num_evaluation
@@ -449,7 +457,7 @@ def run_experiment(_config: DictConfig) -> float:

# PRNG keys.
key, key_e, actor_net_key, critic_net_key = jax.random.split(
jax.random.PRNGKey(config.arch.seed), num=4
jax.random.PRNGKey(config.arch.seed+config.arch.process_id), num=4
)

# Setup learner.
@@ -469,7 +477,7 @@ def run_experiment(_config: DictConfig) -> float:
# Calculate number of updates per evaluation.
config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation
steps_per_rollout = (
n_devices
config.arch.num_global_devices
* config.arch.num_updates_per_eval
* config.system.rollout_length
* config.arch.update_batch_size
@@ -497,7 +505,6 @@ def run_experiment(_config: DictConfig) -> float:
for eval_step in range(config.arch.num_evaluation):
# Train.
start_time = time.time()

learner_output = learn(learner_state)
jax.block_until_ready(learner_output)

@@ -526,9 +533,9 @@ def run_experiment(_config: DictConfig) -> float:
trained_params = unreplicate_batch_dim(
learner_output.learner_state.params.actor_params
) # Select only actor params
key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
key_e, *eval_keys = jax.random.split(key_e, config.arch.num_local_devices + 1)
eval_keys = jnp.stack(eval_keys)
eval_keys = eval_keys.reshape(n_devices, -1)
eval_keys = eval_keys.reshape(config.arch.num_local_devices, -1)

# Evaluate.
evaluator_output = evaluator(trained_params, eval_keys)
@@ -561,9 +568,9 @@ def run_experiment(_config: DictConfig) -> float:
if config.arch.absolute_metric:
start_time = time.time()

key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
key_e, *eval_keys = jax.random.split(key_e, config.arch.num_local_devices + 1)
eval_keys = jnp.stack(eval_keys)
eval_keys = eval_keys.reshape(n_devices, -1)
eval_keys = eval_keys.reshape(config.arch.num_local_devices, -1)

evaluator_output = absolute_metric_evaluator(best_params, eval_keys)
jax.block_until_ready(evaluator_output)
@@ -594,6 +601,7 @@ def hydra_entry_point(cfg: DictConfig) -> float:

# Run experiment.
eval_performance = run_experiment(cfg)

print(f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed{Style.RESET_ALL}")
return eval_performance

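Another recurring change in these files is the seed offset: each process now builds its PRNG key from config.arch.seed + config.arch.process_id, so different hosts draw different environment and evaluation keys instead of replaying identical rollouts. A sketch of the idea, with base_seed as a placeholder value rather than the Stoix config field:

import jax

base_seed = 42                    # placeholder for config.arch.seed
process_id = jax.process_index()  # 0 on a single host

key, key_e, actor_net_key, critic_net_key = jax.random.split(
    jax.random.PRNGKey(base_seed + process_id), num=4
)

Runs stay reproducible from a single configured seed, because every process derives its offset deterministically from its process index.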
39 changes: 23 additions & 16 deletions stoix/systems/ppo/anakin/ff_ppo_penalty.py
@@ -320,9 +320,7 @@ def learner_setup(
env: Environment, keys: chex.Array, config: DictConfig
) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]:
"""Initialise learner_fn, network, optimiser, environment and states."""
# Get available TPU cores.
n_devices = len(jax.local_devices())


# Get number/dimension of actions.
num_actions = int(env.action_spec().num_values)
config.system.action_dim = num_actions
@@ -385,13 +383,13 @@ def learner_setup(

# Initialise environment states and timesteps: across devices and batches.
key, *env_keys = jax.random.split(
key, n_devices * config.arch.update_batch_size * config.arch.num_local_envs + 1
key, config.arch.num_local_devices * config.arch.update_batch_size * config.arch.num_local_envs + 1
)
env_states, timesteps = jax.vmap(env.reset, in_axes=(0))(
jnp.stack(env_keys),
)
reshape_states = lambda x: x.reshape(
(n_devices, config.arch.update_batch_size, config.arch.num_local_envs) + x.shape[1:]
(config.arch.num_local_devices, config.arch.update_batch_size, config.arch.num_local_envs) + x.shape[1:]
)
# (devices, update batch size, num_envs, ...)
env_states = jax.tree_util.tree_map(reshape_states, env_states)
@@ -410,8 +408,8 @@

# Define params to be replicated across devices and batches.
key, step_key = jax.random.split(key)
step_keys = jax.random.split(step_key, n_devices * config.arch.update_batch_size)
reshape_keys = lambda x: x.reshape((n_devices, config.arch.update_batch_size) + x.shape[1:])
step_keys = jax.random.split(step_key, config.arch.num_local_devices * config.arch.update_batch_size)
reshape_keys = lambda x: x.reshape((config.arch.num_local_devices, config.arch.update_batch_size) + x.shape[1:])
step_keys = reshape_keys(jnp.stack(step_keys))
opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states)
@@ -433,10 +431,20 @@ def run_experiment(_config: DictConfig) -> float:
def run_experiment(_config: DictConfig) -> float:
"""Runs experiment."""
config = copy.deepcopy(_config)

# Get device and host information
config.arch.num_global_devices = jax.device_count()
config.arch.num_local_devices = jax.local_device_count()
config.arch.num_processes = jax.process_count()
config.arch.process_id = jax.process_index()
if jax.device_count() == jax.local_device_count():
print(f"{Fore.CYAN}{Style.BRIGHT}Running a single-host experiment with {jax.device_count()} devices.{Style.RESET_ALL}")
config.arch.is_multihost = False
else:
print(f"{Fore.CYAN}{Style.BRIGHT}Running a multi-host experiment with {jax.device_count()} devices on {jax.host_count()} hosts ({jax.local_device_count()} devices per host).{Style.RESET_ALL}")
config.arch.is_multihost = True

# Calculate total timesteps.
n_devices = len(jax.local_devices())
config.num_devices = n_devices
config = check_total_timesteps(config)
assert (
config.arch.num_updates >= config.arch.num_evaluation
@@ -447,7 +455,7 @@ def run_experiment(_config: DictConfig) -> float:

# PRNG keys.
key, key_e, actor_net_key, critic_net_key = jax.random.split(
jax.random.PRNGKey(config.arch.seed), num=4
jax.random.PRNGKey(config.arch.seed+config.arch.process_id), num=4
)

# Setup learner.
@@ -467,7 +475,7 @@ def run_experiment(_config: DictConfig) -> float:
# Calculate number of updates per evaluation.
config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation
steps_per_rollout = (
n_devices
config.arch.num_global_devices
* config.arch.num_updates_per_eval
* config.system.rollout_length
* config.arch.update_batch_size
@@ -495,7 +503,6 @@ def run_experiment(_config: DictConfig) -> float:
for eval_step in range(config.arch.num_evaluation):
# Train.
start_time = time.time()

learner_output = learn(learner_state)
jax.block_until_ready(learner_output)

@@ -524,9 +531,9 @@ def run_experiment(_config: DictConfig) -> float:
trained_params = unreplicate_batch_dim(
learner_output.learner_state.params.actor_params
) # Select only actor params
key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
key_e, *eval_keys = jax.random.split(key_e, config.arch.num_local_devices + 1)
eval_keys = jnp.stack(eval_keys)
eval_keys = eval_keys.reshape(n_devices, -1)
eval_keys = eval_keys.reshape(config.arch.num_local_devices, -1)

# Evaluate.
evaluator_output = evaluator(trained_params, eval_keys)
@@ -559,9 +566,9 @@ def run_experiment(_config: DictConfig) -> float:
if config.arch.absolute_metric:
start_time = time.time()

key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
key_e, *eval_keys = jax.random.split(key_e, config.arch.num_local_devices + 1)
eval_keys = jnp.stack(eval_keys)
eval_keys = eval_keys.reshape(n_devices, -1)
eval_keys = eval_keys.reshape(config.arch.num_local_devices, -1)

evaluator_output = absolute_metric_evaluator(best_params, eval_keys)
jax.block_until_ready(evaluator_output)
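With the device bookkeeping in place, the per-evaluation step count switches from the local device count to config.arch.num_global_devices, since every device on every host contributes rollout steps to the logged total. A rough worked example with illustrative numbers; only the factors visible in the hunks above are included, as the diff truncates before any remaining terms of the product:

num_global_devices = 8       # e.g. 2 hosts with 4 local devices each
num_updates_per_eval = 50
rollout_length = 128
update_batch_size = 2

steps_per_rollout = (
    num_global_devices
    * num_updates_per_eval
    * rollout_length
    * update_batch_size
)
print(steps_per_rollout)  # 8 * 50 * 128 * 2 = 102_400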
(Diffs for the remaining two changed files are not shown.)
