From 16abaa9365fa0e6b3888be4f1907f7cd1222f96c Mon Sep 17 00:00:00 2001 From: Gaiejj <524339208@qq.com> Date: Sun, 5 Nov 2023 20:01:48 +0800 Subject: [PATCH 1/6] feat: support discrete environment --- docs/source/envs/discrete_env.rst | 18 ++ docs/source/index.rst | 1 + docs/source/model/actor.rst | 14 ++ docs/source/saferl/lag.rst | 2 +- omnisafe/adapter/online_adapter.py | 14 +- omnisafe/algorithms/algo_wrapper.py | 5 +- .../on_policy/base/policy_gradient.py | 4 +- omnisafe/algorithms/on_policy/base/ppo.py | 4 +- omnisafe/common/buffer/base.py | 10 +- omnisafe/configs/on-policy/CPO.yaml | 36 +++ omnisafe/configs/on-policy/CPPOPID.yaml | 36 +++ omnisafe/configs/on-policy/IPO.yaml | 36 +++ omnisafe/configs/on-policy/NaturalPG.yaml | 36 +++ omnisafe/configs/on-policy/OnCRPO.yaml | 36 +++ omnisafe/configs/on-policy/P3O.yaml | 36 +++ omnisafe/configs/on-policy/PCPO.yaml | 36 +++ omnisafe/configs/on-policy/PDO.yaml | 37 ++++ omnisafe/configs/on-policy/PPO.yaml | 36 +++ omnisafe/configs/on-policy/PPOLag.yaml | 36 +++ .../configs/on-policy/PolicyGradient.yaml | 36 +++ omnisafe/configs/on-policy/RCPO.yaml | 36 +++ omnisafe/configs/on-policy/TRPO.yaml | 36 +++ omnisafe/configs/on-policy/TRPOLag.yaml | 36 +++ omnisafe/configs/on-policy/TRPOPID.yaml | 36 +++ omnisafe/envs/__init__.py | 27 +++ omnisafe/envs/core.py | 1 + omnisafe/envs/discrete_env.py | 208 ++++++++++++++++++ omnisafe/envs/mujoco_env.py | 5 +- omnisafe/envs/safety_gymnasium_env.py | 6 +- omnisafe/envs/safety_gymnasium_modelbased.py | 2 + omnisafe/envs/wrapper.py | 5 +- omnisafe/evaluator.py | 25 ++- omnisafe/models/actor/__init__.py | 1 + omnisafe/models/actor/actor_builder.py | 47 ++-- omnisafe/models/actor/categorical_actor.py | 128 +++++++++++ omnisafe/models/base.py | 12 +- omnisafe/typing.py | 2 +- omnisafe/utils/config.py | 38 +++- tests/simple_env.py | 1 + tests/test_env.py | 48 +++- tests/test_model.py | 59 +++-- tests/test_utils.py | 9 +- 42 files changed, 1165 insertions(+), 72 deletions(-) create mode 100644 docs/source/envs/discrete_env.rst create mode 100644 omnisafe/envs/discrete_env.py create mode 100644 omnisafe/models/actor/categorical_actor.py diff --git a/docs/source/envs/discrete_env.rst b/docs/source/envs/discrete_env.rst new file mode 100644 index 000000000..3dc30b223 --- /dev/null +++ b/docs/source/envs/discrete_env.rst @@ -0,0 +1,18 @@ +OmniSafe Discrete Environment +============================= + +.. currentmodule:: omnisafe.envs.discrete_env + +Discrete Environment Interface +------------------------------ + +.. card:: + :class-header: sd-bg-success sd-text-white + :class-card: sd-outline-success sd-rounded-1 + + Documentation + ^^^ + + .. autoclass:: DiscreteEnv + :members: + :private-members: diff --git a/docs/source/index.rst b/docs/source/index.rst index e759bebee..ed029b413 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -461,6 +461,7 @@ this project, don't hesitate to ask your question on `the GitHub issue page Config: self.algo in ALGORITHMS['all'] ), f"{self.algo} doesn't exist. Please choose from {ALGORITHMS['all']}." self.algo_type = ALGORITHM2TYPE.get(self.algo, '') + self.env_type = ENVIRONMNET2TYPE.get(self.env_id, '') if self.train_terminal_cfgs is not None: if self.algo_type in ['model-based', 'offline']: assert ( @@ -146,7 +147,7 @@ def _init_checks(self) -> None: def _init_algo(self) -> None: """Initialize the algorithm.""" - check_all_configs(self.cfgs, self.algo_type) + check_all_configs(self.cfgs, self.algo_type, self.env_type) if distributed.fork( self.cfgs.train_cfgs.parallel, device=self.cfgs.train_cfgs.device, diff --git a/omnisafe/algorithms/on_policy/base/policy_gradient.py b/omnisafe/algorithms/on_policy/base/policy_gradient.py index 01409b574..d6656f365 100644 --- a/omnisafe/algorithms/on_policy/base/policy_gradient.py +++ b/omnisafe/algorithms/on_policy/base/policy_gradient.py @@ -555,15 +555,15 @@ def _loss_pi( """ distribution = self._actor_critic.actor(obs) logp_ = self._actor_critic.actor.log_prob(act) - std = self._actor_critic.actor.std ratio = torch.exp(logp_ - logp) loss = -(ratio * adv).mean() entropy = distribution.entropy().mean().item() + if self._cfgs.model_cfgs.actor_type == 'gaussian_learning': + self._logger.store({'Train/PolicyStd': self._actor_critic.actor.std}) self._logger.store( { 'Train/Entropy': entropy, 'Train/PolicyRatio': ratio, - 'Train/PolicyStd': std, 'Loss/Loss_pi': loss.mean().item(), }, ) diff --git a/omnisafe/algorithms/on_policy/base/ppo.py b/omnisafe/algorithms/on_policy/base/ppo.py index 69f0ce4e9..814dca535 100644 --- a/omnisafe/algorithms/on_policy/base/ppo.py +++ b/omnisafe/algorithms/on_policy/base/ppo.py @@ -65,7 +65,6 @@ def _loss_pi( """ distribution = self._actor_critic.actor(obs) logp_ = self._actor_critic.actor.log_prob(act) - std = self._actor_critic.actor.std ratio = torch.exp(logp_ - logp) ratio_cliped = torch.clamp( ratio, @@ -76,11 +75,12 @@ def _loss_pi( loss -= self._cfgs.algo_cfgs.entropy_coef * distribution.entropy().mean() # useful extra info entropy = distribution.entropy().mean().item() + if self._cfgs.model_cfgs.actor_type == 'gaussian_learning': + self._logger.store({'Train/PolicyStd': self._actor_critic.actor.std}) self._logger.store( { 'Train/Entropy': entropy, 'Train/PolicyRatio': ratio, - 'Train/PolicyStd': std, 'Loss/Loss_pi': loss.mean().item(), }, ) diff --git a/omnisafe/common/buffer/base.py b/omnisafe/common/buffer/base.py index 08864ecb0..fe7be93fc 100644 --- a/omnisafe/common/buffer/base.py +++ b/omnisafe/common/buffer/base.py @@ -19,7 +19,7 @@ from abc import ABC, abstractmethod import torch -from gymnasium.spaces import Box +from gymnasium.spaces import Box, Discrete from omnisafe.typing import DEVICE_CPU, OmnisafeSpace @@ -57,8 +57,8 @@ class BaseBuffer(ABC): data (dict[str, torch.Tensor]): The data of the buffer. Raises: - NotImplementedError: If the observation space or the action space is not Box. - NotImplementedError: If the action space or the action space is not Box. + NotImplementedError: If the observation space or the action space is not Box nor Discrete. + NotImplementedError: If the action space or the action space is not Box nor Discrete. """ def __init__( @@ -72,10 +72,14 @@ def __init__( self._device: torch.device = device if isinstance(obs_space, Box): obs_buf = torch.zeros((size, *obs_space.shape), dtype=torch.float32, device=device) + elif isinstance(obs_space, Discrete): + obs_buf = torch.zeros((size, 1), dtype=torch.float32, device=device) else: raise NotImplementedError if isinstance(act_space, Box): act_buf = torch.zeros((size, *act_space.shape), dtype=torch.float32, device=device) + elif isinstance(act_space, Discrete): + act_buf = torch.zeros((size), dtype=torch.float32, device=device) else: raise NotImplementedError diff --git a/omnisafe/configs/on-policy/CPO.yaml b/omnisafe/configs/on-policy/CPO.yaml index 6620a9702..012e29761 100644 --- a/omnisafe/configs/on-policy/CPO.yaml +++ b/omnisafe/configs/on-policy/CPO.yaml @@ -126,3 +126,39 @@ defaults: activation: tanh # learning rate lr: 0.001 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/CPPOPID.yaml b/omnisafe/configs/on-policy/CPPOPID.yaml index 5054ffaa0..bca1e5ae3 100644 --- a/omnisafe/configs/on-policy/CPPOPID.yaml +++ b/omnisafe/configs/on-policy/CPPOPID.yaml @@ -142,3 +142,39 @@ defaults: penalty_max: 100.0 # Initial value of lagrangian multiplier lagrangian_multiplier_init: 0.001 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/IPO.yaml b/omnisafe/configs/on-policy/IPO.yaml index 852b08344..594e387bb 100644 --- a/omnisafe/configs/on-policy/IPO.yaml +++ b/omnisafe/configs/on-policy/IPO.yaml @@ -134,3 +134,39 @@ defaults: lambda_lr: 0.035 # Type of lagrangian optimizer lambda_optimizer: "Adam" + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/NaturalPG.yaml b/omnisafe/configs/on-policy/NaturalPG.yaml index e737e3873..d1ba57d92 100644 --- a/omnisafe/configs/on-policy/NaturalPG.yaml +++ b/omnisafe/configs/on-policy/NaturalPG.yaml @@ -126,3 +126,39 @@ defaults: activation: tanh # learning rate lr: 0.001 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/OnCRPO.yaml b/omnisafe/configs/on-policy/OnCRPO.yaml index 73e558700..469820165 100644 --- a/omnisafe/configs/on-policy/OnCRPO.yaml +++ b/omnisafe/configs/on-policy/OnCRPO.yaml @@ -128,3 +128,39 @@ defaults: activation: tanh # learning rate lr: 0.001 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/P3O.yaml b/omnisafe/configs/on-policy/P3O.yaml index cbb7317dd..f065f954f 100644 --- a/omnisafe/configs/on-policy/P3O.yaml +++ b/omnisafe/configs/on-policy/P3O.yaml @@ -122,3 +122,39 @@ defaults: activation: tanh # learning rate lr: 0.0003 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/PCPO.yaml b/omnisafe/configs/on-policy/PCPO.yaml index 6620a9702..012e29761 100644 --- a/omnisafe/configs/on-policy/PCPO.yaml +++ b/omnisafe/configs/on-policy/PCPO.yaml @@ -126,3 +126,39 @@ defaults: activation: tanh # learning rate lr: 0.001 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/PDO.yaml b/omnisafe/configs/on-policy/PDO.yaml index 1bb19bc09..e5c39b91a 100644 --- a/omnisafe/configs/on-policy/PDO.yaml +++ b/omnisafe/configs/on-policy/PDO.yaml @@ -127,3 +127,40 @@ defaults: lambda_lr: 0.035 # Type of lagrangian optimizer lambda_optimizer: "Adam" + + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/PPO.yaml b/omnisafe/configs/on-policy/PPO.yaml index e23621510..4bca1f265 100644 --- a/omnisafe/configs/on-policy/PPO.yaml +++ b/omnisafe/configs/on-policy/PPO.yaml @@ -118,3 +118,39 @@ defaults: activation: tanh # learning rate lr: 0.0003 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/PPOLag.yaml b/omnisafe/configs/on-policy/PPOLag.yaml index 4d673fdbe..b39e2b329 100644 --- a/omnisafe/configs/on-policy/PPOLag.yaml +++ b/omnisafe/configs/on-policy/PPOLag.yaml @@ -128,3 +128,39 @@ defaults: lambda_lr: 0.035 # Type of lagrangian optimizer lambda_optimizer: "Adam" + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/PolicyGradient.yaml b/omnisafe/configs/on-policy/PolicyGradient.yaml index 4f72d15ae..18b19c1dd 100644 --- a/omnisafe/configs/on-policy/PolicyGradient.yaml +++ b/omnisafe/configs/on-policy/PolicyGradient.yaml @@ -116,3 +116,39 @@ defaults: activation: tanh # learning rate lr: 0.0003 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/RCPO.yaml b/omnisafe/configs/on-policy/RCPO.yaml index 713a04193..001570e77 100644 --- a/omnisafe/configs/on-policy/RCPO.yaml +++ b/omnisafe/configs/on-policy/RCPO.yaml @@ -134,3 +134,39 @@ defaults: lambda_lr: 0.035 # Type of lagrangian optimizer lambda_optimizer: "Adam" + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/TRPO.yaml b/omnisafe/configs/on-policy/TRPO.yaml index 455ba163f..997e167bc 100644 --- a/omnisafe/configs/on-policy/TRPO.yaml +++ b/omnisafe/configs/on-policy/TRPO.yaml @@ -124,3 +124,39 @@ defaults: activation: tanh # learning rate lr: 0.001 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/TRPOLag.yaml b/omnisafe/configs/on-policy/TRPOLag.yaml index 713a04193..001570e77 100644 --- a/omnisafe/configs/on-policy/TRPOLag.yaml +++ b/omnisafe/configs/on-policy/TRPOLag.yaml @@ -134,3 +134,39 @@ defaults: lambda_lr: 0.035 # Type of lagrangian optimizer lambda_optimizer: "Adam" + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/configs/on-policy/TRPOPID.yaml b/omnisafe/configs/on-policy/TRPOPID.yaml index 33f9ea076..42b71a545 100644 --- a/omnisafe/configs/on-policy/TRPOPID.yaml +++ b/omnisafe/configs/on-policy/TRPOPID.yaml @@ -150,3 +150,39 @@ defaults: penalty_max: 100.0 # Initial value of lagrangian multiplier lagrangian_multiplier_init: 0.001 + +CartPole-v1: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 500 + # total number of steps to train + total_steps: 1000000 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" + +Taxi-v3: + # logger configurations + logger_cfgs: + # save model frequency + save_model_freq: 5 + # training configurations + train_cfgs: + # max time-step for each episode + time_limit: 200 + # algorithm configurations + algo_cfgs: + # normalize observation + obs_normalize: False + # entropy coefficient + entropy_coef: 0.01 + # model configurations + model_cfgs: + # actor type, options: gaussian, gaussian_learning + actor_type: "discrete" diff --git a/omnisafe/envs/__init__.py b/omnisafe/envs/__init__.py index 7a9f2ea2b..e805252cb 100644 --- a/omnisafe/envs/__init__.py +++ b/omnisafe/envs/__init__.py @@ -14,7 +14,34 @@ # ============================================================================== """Environment API for OmniSafe.""" +import itertools +from types import MappingProxyType + from omnisafe.envs.core import CMDP, env_register, make, support_envs +from omnisafe.envs.discrete_env import DiscreteEnv from omnisafe.envs.mujoco_env import MujocoEnv from omnisafe.envs.safety_gymnasium_env import SafetyGymnasiumEnv from omnisafe.envs.safety_gymnasium_modelbased import SafetyGymnasiumModelBased + + +ENVIRONMENTS = { + 'box': tuple( + MujocoEnv.support_envs() + + SafetyGymnasiumEnv.support_envs() + + SafetyGymnasiumModelBased.support_envs(), + ), + 'discrete': tuple(DiscreteEnv.support_envs()), +} + +ENVIRONMNET2TYPE = { + env: env_type for env_type, environments in ENVIRONMENTS.items() for env in environments +} + +__all__ = ENVIRONMENTS['all'] = tuple(itertools.chain.from_iterable(ENVIRONMENTS.values())) + +assert len(ENVIRONMNET2TYPE) == len(__all__), 'Duplicate algorithm names found.' + +ENVIRONMENTS = MappingProxyType(ENVIRONMENTS) # make this immutable +ENVIRONMNET2TYPE = MappingProxyType(ENVIRONMNET2TYPE) # make this immutable + +del itertools, MappingProxyType diff --git a/omnisafe/envs/core.py b/omnisafe/envs/core.py index 999ac45fe..137d8eb51 100644 --- a/omnisafe/envs/core.py +++ b/omnisafe/envs/core.py @@ -53,6 +53,7 @@ class CMDP(ABC): _time_limit: int | None = None need_time_limit_wrapper: bool need_auto_reset_wrapper: bool + need_action_scale_wrapper: bool _support_envs: ClassVar[list[str]] diff --git a/omnisafe/envs/discrete_env.py b/omnisafe/envs/discrete_env.py new file mode 100644 index 000000000..92662288f --- /dev/null +++ b/omnisafe/envs/discrete_env.py @@ -0,0 +1,208 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Environments with discrete observation/action space in Gymnasium.""" + +from __future__ import annotations + +from typing import Any, ClassVar + +import gymnasium +import numpy as np +import torch +from gymnasium import spaces + +from omnisafe.envs.core import CMDP, env_register +from omnisafe.typing import DEVICE_CPU, Discrete + + +@env_register +class DiscreteEnv(CMDP): + """Discrete Gymnasium Environment. + + This environment only served as an example to integrate discrete action and + observation environment into OmniSafe. We support ``CartPole-v1`` and ``Taxi-v3``. + The former is ``Box`` observation space and ``Discrete`` action space, while + the latter is ``Discrete`` observation and ``Discrete`` action space. + + Args: + env_id (str): Environment id. + num_envs (int, optional): Number of environments. Defaults to 1. + device (torch.device, optional): Device to store the data. Defaults to + ``torch.device('cpu')``. + + Keyword Args: + render_mode (str, optional): The render mode ranges from 'human' to 'rgb_array' and 'rgb_array_list'. + Defaults to 'rgb_array'. + + Attributes: + need_auto_reset_wrapper (bool): Whether to use auto reset wrapper. + need_time_limit_wrapper (bool): Whether to use time limit wrapper. + need_action_repeat_wrapper (bool): Whether to use action repeat wrapper. + need_action_scale_wrapper (bool): Whether to use action scale wrapper. + """ + + need_action_scale_wrapper = False + need_obs_normalize_wrapper = False + need_auto_reset_wrapper = False + need_time_limit_wrapper = False + + _support_envs: ClassVar[list[str]] = [ + 'CartPole-v1', + 'Taxi-v3', + ] + + def __init__( + self, + env_id: str, + num_envs: int = 1, + device: torch.device = DEVICE_CPU, + **kwargs: Any, + ) -> None: + """Initialize an instance of :class:`DiscreteEnv`.""" + super().__init__(env_id) + self._num_envs = num_envs + self._device = torch.device(device) + + if num_envs > 1: + self._env = gymnasium.vector.make( + id=env_id, + num_envs=num_envs, + render_mode=kwargs.get('render_mode'), + ) + assert isinstance( + self._env.single_action_space, + Discrete, + ), 'Only support Discrete action space.' + self._action_space = self._env.single_action_space + self._observation_space = self._env.single_observation_space # type: ignore + else: + self.need_time_limit_wrapper = True + self.need_auto_reset_wrapper = True + self._env = gymnasium.make(id=env_id, autoreset=True, render_mode=kwargs.get('render_mode')) # type: ignore + self._action_space = self._env.action_space # type: ignore + self._observation_space = self._env.observation_space # type: ignore + self._metadata = self._env.metadata + + def step( + self, + action: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + dict[str, Any], + ]: + """Step the environment. + + .. note:: + OmniSafe uses auto reset wrapper to reset the environment when the episode is + terminated. So the ``obs`` will be the first observation of the next episode. And the + true ``final_observation`` in ``info`` will be stored in the ``final_observation`` key + of ``info``. + + Args: + action (torch.Tensor): Action to take. + + Returns: + observation: The agent's observation of the current environment. + reward: The amount of reward returned after previous action. + cost: The amount of cost returned after previous action. + terminated: Whether the episode has ended. + truncated: Whether the episode has been truncated due to a time limit. + info: Some information logged by the environment. + """ + obs, reward, terminated, truncated, info = self._env.step( + action.detach().cpu().numpy().tolist(), + ) + obs, reward, terminated, truncated = ( + torch.as_tensor(x, dtype=torch.float32, device=self._device) + for x in (obs, reward, terminated, truncated) + ) + if isinstance(self._observation_space, spaces.Discrete): + obs = obs.unsqueeze(-1) + if 'final_observation' in info: + if isinstance(info['final_observation'], np.ndarray): + info['final_observation'] = np.array( + [ + array if array is not None else np.zeros(obs.shape[-1]) + for array in info['final_observation'] + ], + ) + info['final_observation'] = torch.as_tensor( + info['final_observation'], + dtype=torch.float32, + device=self._device, + ) + if isinstance(self._observation_space, spaces.Discrete): + info['final_observation'] = info['final_observation'].unsqueeze(-1) + + return obs, reward, torch.zeros_like(reward), terminated, truncated, info + + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + """Reset the environment. + + Args: + seed (int, optional): The random seed. Defaults to None. + options (dict[str, Any], optional): The options for the environment. Defaults to None. + + + Returns: + observation: Agent's observation of the current environment. + info: Some information logged by the environment. + """ + obs, info = self._env.reset(seed=seed, options=options) + obs = torch.as_tensor(obs, dtype=torch.float32, device=self._device) + if isinstance(self._observation_space, spaces.Discrete): + obs = obs.unsqueeze(-1) + return obs, info + + def set_seed(self, seed: int) -> None: + """Set the seed for the environment. + + Args: + seed (int): Seed to set. + """ + self.reset(seed=seed) + + def sample_action(self) -> torch.Tensor: + """Sample a random action. + + Returns: + A random action. + """ + return torch.as_tensor( + self._env.action_space.sample(), + dtype=torch.int64, + device=self._device, + ) + + def render(self) -> Any: + """Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment. + + Returns: + The render frames: we recommend to use `np.ndarray` + which could construct video by moviepy. + """ + return self._env.render() + + def close(self) -> None: + """Close the environment.""" + self._env.close() diff --git a/omnisafe/envs/mujoco_env.py b/omnisafe/envs/mujoco_env.py index f5ba67f53..02c490e9c 100644 --- a/omnisafe/envs/mujoco_env.py +++ b/omnisafe/envs/mujoco_env.py @@ -33,12 +33,15 @@ class MujocoEnv(CMDP): Attributes: need_auto_reset_wrapper (bool): Whether to use auto reset wrapper. need_time_limit_wrapper (bool): Whether to use time limit wrapper. + need_action_repeat_wrapper (bool): Whether to use action repeat wrapper. + need_action_scale_wrapper (bool): Whether to use action scale wrapper. """ need_auto_reset_wrapper = True - need_time_limit_wrapper = False need_action_repeat_wrapper = True + need_action_scale_wrapper = True + _support_envs: ClassVar[list[str]] = [ 'Ant-v4', 'Hopper-v4', diff --git a/omnisafe/envs/safety_gymnasium_env.py b/omnisafe/envs/safety_gymnasium_env.py index f9eb75d20..d0abbf228 100644 --- a/omnisafe/envs/safety_gymnasium_env.py +++ b/omnisafe/envs/safety_gymnasium_env.py @@ -47,10 +47,12 @@ class SafetyGymnasiumEnv(CMDP): Attributes: need_auto_reset_wrapper (bool): Whether to use auto reset wrapper. need_time_limit_wrapper (bool): Whether to use time limit wrapper. + need_action_scale_wrapper (bool): Whether to use action scale wrapper. """ - need_auto_reset_wrapper: bool = False - need_time_limit_wrapper: bool = False + need_auto_reset_wrapper = False + need_time_limit_wrapper = False + need_action_scale_wrapper = True _support_envs: ClassVar[list[str]] = [ 'SafetyPointGoal0-v0', diff --git a/omnisafe/envs/safety_gymnasium_modelbased.py b/omnisafe/envs/safety_gymnasium_modelbased.py index 3edc7584b..5fb325065 100644 --- a/omnisafe/envs/safety_gymnasium_modelbased.py +++ b/omnisafe/envs/safety_gymnasium_modelbased.py @@ -36,10 +36,12 @@ class SafetyGymnasiumModelBased(CMDP): # pylint: disable=too-many-instance-attr _support_envs (list[str]): List of supported environments. need_auto_reset_wrapper (bool): Whether to use auto reset wrapper. need_time_limit_wrapper (bool): Whether to use time limit wrapper. + need_action_scale_wrapper (bool): Whether to use action scale wrapper. """ need_auto_reset_wrapper = False need_time_limit_wrapper = False + need_action_scale_wrapper = True _support_envs: ClassVar[list[str]] = [ 'SafetyPointGoal0-v0-modelbased', diff --git a/omnisafe/envs/wrapper.py b/omnisafe/envs/wrapper.py index 5fe8525e7..01623c066 100644 --- a/omnisafe/envs/wrapper.py +++ b/omnisafe/envs/wrapper.py @@ -584,7 +584,10 @@ def __init__(self, env: CMDP, device: torch.device) -> None: """Initialize an instance of :class:`Unsqueeze`.""" super().__init__(env=env, device=device) assert self.num_envs == 1, 'Unsqueeze only works with single environment' - assert isinstance(self.observation_space, spaces.Box), 'Observation space must be Box' + assert isinstance( + self.observation_space, + (spaces.Box, spaces.Discrete), + ), 'Observation space must be Box or Discrete' def step( self, diff --git a/omnisafe/evaluator.py b/omnisafe/evaluator.py index 7664fd29e..4d25da862 100644 --- a/omnisafe/evaluator.py +++ b/omnisafe/evaluator.py @@ -139,9 +139,8 @@ def __load_model_and_env( FileNotFoundError: If the model is not found. """ # load the saved model - model_path = os.path.join(save_dir, 'torch_save', model_name) try: - model_params = torch.load(model_path) + model_params = torch.load(os.path.join(save_dir, 'torch_save', model_name)) except FileNotFoundError as error: raise FileNotFoundError('The model is not found in the save directory.') from error @@ -158,16 +157,16 @@ def __load_model_and_env( / self._cfgs.algo_cfgs.max_ep_len * torch.ones(1) ) - assert isinstance(observation_space, Box), 'The observation space must be Box.' - assert isinstance(action_space, Box), 'The action space must be Box.' - if self._cfgs['algo_cfgs']['obs_normalize']: + if self._cfgs['algo_cfgs']['obs_normalize'] and isinstance(observation_space, Box): obs_normalizer = Normalizer(shape=observation_space.shape, clip=5) obs_normalizer.load_state_dict(model_params['obs_normalizer']) self._env = ObsNormalize(self._env, device=torch.device('cpu'), norm=obs_normalizer) if self._env.need_time_limit_wrapper: + self._cfgs['train_cfgs'].get('time_limit', 1000) self._env = TimeLimit(self._env, device=torch.device('cpu'), time_limit=1000) - self._env = ActionScale(self._env, device=torch.device('cpu'), low=-1.0, high=1.0) + if self._env.need_action_scale_wrapper: + self._env = ActionScale(self._env, device=torch.device('cpu'), low=-1.0, high=1.0) if hasattr(self._cfgs['algo_cfgs'], 'action_repeat'): self._env = ActionRepeat( @@ -183,11 +182,19 @@ def __load_model_and_env( 'RCEPETS', 'CCEPETS', ]: + assert isinstance(observation_space, Box), 'The observation space must be Box.' + assert isinstance(action_space, Box), 'The action space must be Box.' dynamics_state_space = ( self._env.coordinate_observation_space - if self._env.coordinate_observation_space is not None + if hasattr(self._env, 'coordinate_observation_space') else self._env.observation_space ) + get_cost_from_obs_tensor = ( + self._env.get_cost_from_obs_tensor + if hasattr(self._env, 'get_cost_from_obs_tensor') + else None + ) + assert self._env.action_space is not None and isinstance( self._env.action_space.shape, tuple, @@ -213,7 +220,7 @@ def __load_model_and_env( action_shape=action_space.shape, actor_critic=self._actor_critic, rew_func=None, - cost_func=self._env.get_cost_from_obs_tensor, + cost_func=get_cost_from_obs_tensor, terminal_func=None, ) self._dynamics.ensemble_model.load_state_dict(model_params['dynamics']) @@ -273,6 +280,8 @@ def __load_model_and_env( else: if 'Saute' in self._cfgs['algo'] or 'Simmer' in self._cfgs['algo']: + assert isinstance(observation_space, Box), 'The observation space must be Box.' + assert isinstance(action_space, Box), 'The action space must be Box.' observation_space = Box( low=np.hstack((observation_space.low, -np.inf)), high=np.hstack((observation_space.high, np.inf)), diff --git a/omnisafe/models/actor/__init__.py b/omnisafe/models/actor/__init__.py index db2e52daa..020224923 100644 --- a/omnisafe/models/actor/__init__.py +++ b/omnisafe/models/actor/__init__.py @@ -15,6 +15,7 @@ """The abstract interfaces of Actor networks for the Actor-Critic algorithm.""" from omnisafe.models.actor.actor_builder import ActorBuilder +from omnisafe.models.actor.categorical_actor import CategoricalActor from omnisafe.models.actor.gaussian_actor import GaussianActor from omnisafe.models.actor.gaussian_learning_actor import GaussianLearningActor from omnisafe.models.actor.gaussian_sac_actor import GaussianSACActor diff --git a/omnisafe/models/actor/actor_builder.py b/omnisafe/models/actor/actor_builder.py index cd1a0df15..b63a0a0a0 100644 --- a/omnisafe/models/actor/actor_builder.py +++ b/omnisafe/models/actor/actor_builder.py @@ -16,6 +16,7 @@ from __future__ import annotations +from omnisafe.models.actor.categorical_actor import CategoricalActor from omnisafe.models.actor.gaussian_learning_actor import GaussianLearningActor from omnisafe.models.actor.gaussian_sac_actor import GaussianSACActor from omnisafe.models.actor.mlp_actor import MLPActor @@ -29,13 +30,16 @@ class ActorBuilder: """Class for building actor networks. - Args: - obs_space (OmnisafeSpace): Observation space. - act_space (OmnisafeSpace): Action space. - hidden_sizes (list of int): List of hidden layer sizes. - activation (Activation, optional): Activation function. Defaults to ``'relu'``. - weight_initialization_mode (InitFunction, optional): Weight initialization mode. Defaults to - ``'kaiming_uniform'``. + Actor networks are used in the Actor design of Reinforcement Learning (RL) + to choose actions based on the current state of the environment. + + Attributes: + obs_space (OmnisafeSpace): The space that defines valid observations. + act_space (OmnisafeSpace): The space that defines valid actions. + hidden_sizes (list[int]): The number of nodes at each hidden layer in the network. + activation (str, optional): The activation function used after each layer. Defaults to ``'relu'``. + weight_initialization_mode (str, optional): The method to initialize weights in the network. + Defaults to ``'kaiming_uniform'``. """ def __init__( @@ -58,21 +62,24 @@ def build_actor( self, actor_type: ActorType, ) -> Actor: - """Build actor network. + """Generate an actor model of the given type using preset parameters. - Currently, we support the following actor types: - - ``gaussian_learning``: Gaussian actor with learnable standard deviation parameters. - - ``gaussian_sac``: Gaussian actor with learnable standard deviation network. - - ``mlp``: Multi-layer perceptron actor, used in ``DDPG`` and ``TD3``. + The supported actor types include: + - `gaussian_learning`: Gaussian actor with learnable standard deviation parameters. + - `gaussian_sac`: Gaussian actor with learnable standard deviation network. + - `mlp`: Multi-layer perceptron actor, typically used in DDPG and TD3. + - `vae`: Variational AutoEncoder actor, used for continual and low-data learning. + - `perturbation`: Perturbation Actor for domain randomization. + - `discrete`: Discrete/Categorical actor, used in environments with discrete action spaces. Args: - actor_type (ActorType): Type of actor network, e.g. ``gaussian_learning``. + actor_type (str): The type of actor network to build. Returns: - Actor network, ranging from GaussianLearningActor, GaussianSACActor to MLPActor. + Actor: An instance of the requested actor model. Raises: - NotImplementedError: If the actor type is not implemented. + NotImplementedError: If the requested actor type has not been implemented. """ if actor_type == 'gaussian_learning': return GaussianLearningActor( @@ -114,7 +121,15 @@ def build_actor( activation=self._activation, weight_initialization_mode=self._weight_initialization_mode, ) + if actor_type == 'discrete': + return CategoricalActor( + self._obs_space, + self._act_space, + self._hidden_sizes, + activation=self._activation, + weight_initialization_mode=self._weight_initialization_mode, + ) raise NotImplementedError( f'Actor type {actor_type} is not implemented! ' - f'Available actor types are: gaussian_learning, gaussian_sac, mlp, vae, perturbation.', + f'Available actor types are: gaussian_learning, gaussian_sac, mlp, vae, perturbation, discrete.', ) diff --git a/omnisafe/models/actor/categorical_actor.py b/omnisafe/models/actor/categorical_actor.py new file mode 100644 index 000000000..b5c7c65ec --- /dev/null +++ b/omnisafe/models/actor/categorical_actor.py @@ -0,0 +1,128 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of CategoricalActor.""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.distributions import Categorical, Distribution + +from omnisafe.models.base import Actor +from omnisafe.typing import Activation, InitFunction, OmnisafeSpace +from omnisafe.utils.model import build_mlp_network + + +# pylint: disable-next=too-many-instance-attributes +class CategoricalActor(Actor): + """Implementation of CategoricalActor. + + CategoricalActor is an actor suitable for discrete action. It is used in + discrete action space environment such as ``CartPole-v1`` and so on. + + Args: + obs_space (OmnisafeSpace): Observation space. + act_space (OmnisafeSpace): Action space. + hidden_sizes (list of int): List of hidden layer sizes. + activation (Activation, optional): Activation function. Defaults to ``'relu'``. + weight_initialization_mode (InitFunction, optional): Weight initialization mode. Defaults to + ``'kaiming_uniform'``. + """ + + _current_dist: Categorical + + def __init__( + self, + obs_space: OmnisafeSpace, + act_space: OmnisafeSpace, + hidden_sizes: list[int], + activation: Activation = 'relu', + weight_initialization_mode: InitFunction = 'kaiming_uniform', + ) -> None: + """Initialize an instance of :class:`CategoricalActor`.""" + super().__init__(obs_space, act_space, hidden_sizes, activation, weight_initialization_mode) + + self.logits: nn.Module = build_mlp_network( + sizes=[self._obs_dim, *self._hidden_sizes, self._act_dim], + activation=activation, + weight_initialization_mode=weight_initialization_mode, + ) + + def _distribution(self, obs: torch.Tensor) -> Categorical: + """Get the distribution of the actor. + + .. warning:: + This method is not supposed to be called by users. You should call :meth:`forward` + instead. + + Args: + obs (torch.Tensor): Observation from environments. + + Returns: + Categorical distribution over actions based on the actor's logits. + """ + logits = self.logits(obs) + return Categorical(logits=logits) + + def predict(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor: + """Predict the action based on given observations. + + The predicted action depends on the ``deterministic`` flag. + + - If ``deterministic`` is ``True``, the predicted action is the action with highest probability. + - If ``deterministic`` is ``False``, the predicted action is sampled from the distribution. + + Args: + obs (torch.Tensor): Observation from environments. + deterministic (bool, optional): Whether to use deterministic policy. Defaults to False. + + Returns: + The action with highest probability if deterministic is True, + otherwise a sampled action from the distribution. + """ + self._current_dist = self._distribution(obs=obs) + self._after_inference = True + if deterministic: + return torch.argmax(self._current_dist.logits, dim=0, keepdim=False) + return self._current_dist.sample() + + def forward(self, obs: torch.Tensor) -> Distribution: + """Forward method. + + Args: + obs (torch.Tensor): Observation from environments. + + Returns: + The current distribution. + """ + self._current_dist = self._distribution(obs) + self._after_inference = True + return self._current_dist + + def log_prob(self, act: torch.Tensor) -> torch.Tensor: + """Compute the log probability of the action given the current distribution. + + .. warning:: + You must call :meth:`forward` or :meth:`predict` before calling this method. + + Args: + act (torch.Tensor): Action from :meth:`predict` or :meth:`forward` . + + Returns: + Log probability of the action. + """ + assert self._after_inference, 'log_prob() should be called after predict() or forward()' + self._after_inference = False + return self._current_dist.log_prob(act) diff --git a/omnisafe/models/base.py b/omnisafe/models/base.py index 97c4db308..e46b63116 100644 --- a/omnisafe/models/base.py +++ b/omnisafe/models/base.py @@ -65,11 +65,15 @@ def __init__( if isinstance(self._obs_space, spaces.Box) and len(self._obs_space.shape) == 1: self._obs_dim: int = self._obs_space.shape[0] + elif isinstance(self._obs_space, spaces.Discrete): + self._obs_dim = 1 else: raise NotImplementedError if isinstance(self._act_space, spaces.Box) and len(self._act_space.shape) == 1: self._act_dim: int = self._act_space.shape[0] + elif isinstance(self._act_space, spaces.Discrete): + self._act_dim = int(self._act_space.n) else: raise NotImplementedError @@ -201,11 +205,15 @@ def __init__( self._use_obs_encoder: bool = use_obs_encoder if isinstance(self._obs_space, spaces.Box) and len(self._obs_space.shape) == 1: - self._obs_dim = self._obs_space.shape[0] + self._obs_dim: int = self._obs_space.shape[0] + elif isinstance(self._obs_space, spaces.Discrete): + self._obs_dim = 1 else: raise NotImplementedError if isinstance(self._act_space, spaces.Box) and len(self._act_space.shape) == 1: - self._act_dim = self._act_space.shape[0] + self._act_dim: int = self._act_space.shape[0] + elif isinstance(self._act_space, spaces.Discrete): + self._act_dim = int(self._act_space.n) else: raise NotImplementedError diff --git a/omnisafe/typing.py b/omnisafe/typing.py index bf73b558f..e536c4acb 100644 --- a/omnisafe/typing.py +++ b/omnisafe/typing.py @@ -39,7 +39,7 @@ AdvatageEstimator = Literal['gae', 'gae-rtg', 'vtrace', 'plain'] InitFunction = Literal['kaiming_uniform', 'xavier_normal', 'glorot', 'xavier_uniform', 'orthogonal'] CriticType = Literal['v', 'q'] -ActorType = Literal['gaussian_learning', 'gaussian_sac', 'mlp', 'vae', 'perturbation'] +ActorType = Literal['gaussian_learning', 'gaussian_sac', 'mlp', 'vae', 'perturbation', 'discrete'] DEVICE_CPU = torch.device('cpu') diff --git a/omnisafe/utils/config.py b/omnisafe/utils/config.py index 29698b5f6..83671fc64 100644 --- a/omnisafe/utils/config.py +++ b/omnisafe/utils/config.py @@ -255,7 +255,7 @@ def get_default_kwargs_yaml(algo: str, env_id: str, algo_type: str) -> Config: return default_kwargs -def check_all_configs(configs: Config, algo_type: str) -> None: +def check_all_configs(configs: Config, algo_type: str, env_type: str) -> None: """Check all configs. This function is used to check the configs. @@ -263,12 +263,48 @@ def check_all_configs(configs: Config, algo_type: str) -> None: Args: configs (Config): The configs to be checked. algo_type (str): The algorithm type. + env_type (str): The environment type """ __check_algo_configs(configs.algo_cfgs, algo_type) + __check_env_configs(configs, env_type) __check_parallel_and_vectorized(configs, algo_type) __check_logger_configs(configs.logger_cfgs) +def __check_env_configs(configs: Config, env_type: str) -> None: + """Check whether configs are aligned with the type of environment. + + Args: + configs (Config): The model configs to be checked. + env_type (str): The environment type. + """ + if env_type == 'discrete': + assert ( + configs.model_cfgs.actor_type == 'discrete' + ), 'Discrete environments only support discrete actor!' + assert configs.algo in [ + 'NaturalPG', + 'PolicyGradient', + 'PPO', + 'TRPO', + 'RCPO', + 'PDO', + 'PPOLag', + 'TRPOLag', + 'OnCRPO', + 'P3O', + 'IPO', + 'CPPOPID', + 'TRPOPID', + 'CPO', + 'PCPO', + ], f'Currently, OmniSafe does not support {configs.algo} running on discrete environments!' + if env_type == 'box': + assert ( + configs.model_cfgs.actor_type != 'discrete' + ), 'Box environments do not support discrete actor!' + + def __check_parallel_and_vectorized(configs: Config, algo_type: str) -> None: """Check parallel and vectorized configs. diff --git a/tests/simple_env.py b/tests/simple_env.py index 7d8b7eba9..814fd496f 100644 --- a/tests/simple_env.py +++ b/tests/simple_env.py @@ -35,6 +35,7 @@ class SimpleEnv(CMDP): metadata: ClassVar[dict[str, int]] = {'render_fps': 30} need_auto_reset_wrapper = True need_time_limit_wrapper = True + need_action_scale_wrapper: bool = False _num_envs = 1 _coordinate_observation_space: OmnisafeSpace diff --git a/tests/test_env.py b/tests/test_env.py index 8bafb457a..51336d78d 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -14,7 +14,7 @@ # ============================================================================== """Test envs.""" -from gymnasium.spaces import Box +from gymnasium.spaces import Box, Discrete import helpers from omnisafe.envs.core import make @@ -150,3 +150,49 @@ def test_mujoco(num_envs, env_id) -> None: assert isinstance(info, dict) env.close() + + +@helpers.parametrize( + num_envs=[1, 2], +) +def test_discrete(num_envs) -> None: + """Test envs.""" + env_id = 'CartPole-v1' + env = make(env_id, num_envs=num_envs) + + obs_space = env.observation_space + act_space = env.action_space + + assert isinstance(obs_space, Box) + assert isinstance(act_space, Discrete) + + env.set_seed(0) + obs, _ = env.reset() + if num_envs > 1: + assert obs.shape == (num_envs, obs_space.shape[0]) + else: + assert obs.shape == (obs_space.shape[0],) + + act = env.sample_action() + + obs, reward, cost, terminated, truncated, info = env.step(act) + + if num_envs > 1: + assert obs.shape == (num_envs, obs_space.shape[0]) + assert reward.shape == (num_envs,) + assert cost.shape == (num_envs,) + assert terminated.shape == (num_envs,) + assert truncated.shape == (num_envs,) + assert isinstance(info, dict) + else: + assert obs.shape == (obs_space.shape[0],) + assert reward.shape == () + assert cost.shape == () + assert terminated.shape == () + assert truncated.shape == () + assert isinstance(info, dict) + + env.close() + + +test_discrete(num_envs=2) diff --git a/tests/test_model.py b/tests/test_model.py index fdd2108b4..8789087ec 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -42,11 +42,11 @@ def test_critic( use_obs_encoder: bool, ) -> None: """Test critic.""" - obs_sapce = Box(low=-1.0, high=1.0, shape=(obs_dim,)) + obs_space = Box(low=-1.0, high=1.0, shape=(obs_dim,)) act_space = Box(low=-1.0, high=1.0, shape=(act_dim,)) builder = CriticBuilder( - obs_space=obs_sapce, + obs_space=obs_space, act_space=act_space, hidden_sizes=[hidden_sizes, hidden_sizes], activation=activation, @@ -81,11 +81,11 @@ def test_actor( deterministic: bool, ) -> None: """Test actor.""" - obs_sapce = Box(low=-1.0, high=1.0, shape=(obs_dim,)) + obs_space = Box(low=-1.0, high=1.0, shape=(obs_dim,)) act_space = Box(low=-1.0, high=1.0, shape=(act_dim,)) builder = ActorBuilder( - obs_space=obs_sapce, + obs_space=obs_space, act_space=act_space, hidden_sizes=[hidden_sizes, hidden_sizes], activation=activation, @@ -136,7 +136,7 @@ def test_actor_critic( """Test actor critic.""" obs_dim = 10 act_dim = 5 - obs_sapce = Box(low=-1.0, high=1.0, shape=(obs_dim,)) + obs_space = Box(low=-1.0, high=1.0, shape=(obs_dim,)) act_space = Box(low=-1.0, high=1.0, shape=(act_dim,)) model_cfgs = Config( @@ -150,7 +150,7 @@ def test_actor_critic( ) ac = ActorCritic( - obs_space=obs_sapce, + obs_space=obs_space, act_space=act_space, model_cfgs=model_cfgs, epochs=10, @@ -164,7 +164,7 @@ def test_actor_critic( ac.annealing(5) cac = ConstraintActorCritic( - obs_space=obs_sapce, + obs_space=obs_space, act_space=act_space, model_cfgs=model_cfgs, epochs=10, @@ -179,25 +179,38 @@ def test_actor_critic( cac.annealing(5) -@helpers.parametrize(obs_act_type=[('discrete', 'continuous'), ('continuous', 'discrete')]) -def test_raise_error(obs_act_type): - obs_type, act_type = obs_act_type - - obs_sapce = Discrete(10) if obs_type == 'discrete' else Box(low=-1.0, high=1.0, shape=(10,)) - act_space = Discrete(5) if act_type == 'discrete' else Box(low=-1.0, high=1.0, shape=(5,)) +@helpers.parametrize( + obs_dim=[10], + act_dim=[5], + hidden_sizes=[64], + activation=['tanh', 'relu'], + deterministic=[True, False], +) +def test_discrete_actor( + obs_dim: int, + act_dim: int, + hidden_sizes: int, + activation: Activation, + deterministic: bool, +) -> None: + """Test actor.""" + box_obs_space = Box(low=-1.0, high=1.0, shape=(obs_dim,)) + # discrete_obs_space = Discrete(1) + act_space = Discrete(act_dim) builder = ActorBuilder( - obs_space=obs_sapce, + obs_space=box_obs_space, act_space=act_space, - hidden_sizes=[3, 3], + hidden_sizes=[hidden_sizes, hidden_sizes], + activation=activation, ) + obs = torch.randn(obs_dim, dtype=torch.float32) + actor_discrete = builder.build_actor(actor_type='discrete') with pytest.raises(NotImplementedError): - builder.build_actor(actor_type='gaussian_learning') + builder.build_actor(actor_type='invalid') - builder = CriticBuilder( - obs_space=obs_sapce, - act_space=act_space, - hidden_sizes=[3, 3], - ) - with pytest.raises(NotImplementedError): - builder.build_critic(critic_type='q') + _ = actor_discrete(obs) + action = actor_discrete.predict(obs, deterministic) + assert action.shape == torch.Size([]), f'actor output shape is {action.shape}' + logp = actor_discrete.log_prob(action) + assert logp.shape == torch.Size([]), f'actor log_prob shape is {logp.shape}' diff --git a/tests/test_utils.py b/tests/test_utils.py index fe84974cc..5ba9524ed 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -80,12 +80,17 @@ def test_custom_cfgs_to_dict(): def test_config(): """Test config""" - cfg = Config(a=1, b={'c': 2}) + cfg = Config(a=1, b={'c': 2}, model_cfgs={'actor_type': 'gaussian_learning'}) cfg.a = 2 cfg.recurisve_update({'a': {'d': 3}, 'e': {'f': 4}}) cfg = get_default_kwargs_yaml('PPO', 'Simple-v0', 'on-policy') cfg.recurisve_update({'exp_name': 'test_configs', 'env_id': 'Simple-v0', 'algo': 'PPO'}) - check_all_configs(cfg, 'on-policy') + check_all_configs(cfg, 'on-policy', 'box') + with pytest.raises(AssertionError): + check_all_configs(cfg, 'off-pocliy', 'discrete') + cfg.recurisve_update({'model_cfgs': {'actor_type': 'discrete'}}) + with pytest.raises(AssertionError): + check_all_configs(cfg, 'on-pocliy', 'box') def test_distributed(): From f698e4ff523ebc6b8d54c29c1463db780182b7f2 Mon Sep 17 00:00:00 2001 From: Gaiejj <524339208@qq.com> Date: Tue, 7 Nov 2023 13:35:15 +0800 Subject: [PATCH 2/6] chore: update config and docstring) --- omnisafe/common/buffer/base.py | 2 +- omnisafe/common/buffer/onpolicy_buffer.py | 2 +- omnisafe/common/buffer/vector_onpolicy_buffer.py | 2 +- omnisafe/configs/on-policy/CPO.yaml | 8 ++++---- omnisafe/configs/on-policy/CPPOPID.yaml | 6 ++++-- omnisafe/configs/on-policy/IPO.yaml | 6 ++++-- omnisafe/configs/on-policy/NaturalPG.yaml | 8 ++++---- omnisafe/configs/on-policy/OnCRPO.yaml | 8 ++++---- omnisafe/configs/on-policy/P3O.yaml | 6 ++++-- omnisafe/configs/on-policy/PCPO.yaml | 8 ++++---- omnisafe/configs/on-policy/PDO.yaml | 6 ++++-- omnisafe/configs/on-policy/PPO.yaml | 6 ++++-- omnisafe/configs/on-policy/PPOLag.yaml | 6 ++++-- omnisafe/configs/on-policy/PolicyGradient.yaml | 6 ++++-- omnisafe/configs/on-policy/RCPO.yaml | 8 ++++---- omnisafe/configs/on-policy/TRPO.yaml | 8 ++++---- omnisafe/configs/on-policy/TRPOLag.yaml | 8 ++++---- omnisafe/configs/on-policy/TRPOPID.yaml | 8 ++++---- omnisafe/envs/discrete_env.py | 4 +++- 19 files changed, 66 insertions(+), 50 deletions(-) diff --git a/omnisafe/common/buffer/base.py b/omnisafe/common/buffer/base.py index fe7be93fc..6b32c3583 100644 --- a/omnisafe/common/buffer/base.py +++ b/omnisafe/common/buffer/base.py @@ -28,7 +28,7 @@ class BaseBuffer(ABC): r"""Abstract base class for buffer. .. warning:: - The buffer only supports Box spaces. + The buffer only supports ``Box`` and ``Discrete`` spaces. In base buffer, we store the following data: diff --git a/omnisafe/common/buffer/onpolicy_buffer.py b/omnisafe/common/buffer/onpolicy_buffer.py index b6f9586df..8703bbabb 100644 --- a/omnisafe/common/buffer/onpolicy_buffer.py +++ b/omnisafe/common/buffer/onpolicy_buffer.py @@ -31,7 +31,7 @@ class OnPolicyBuffer(BaseBuffer): # pylint: disable=too-many-instance-attribute state-action pairs, ranging from ``GAE``, ``GAE-RTG`` , ``V-trace`` to ``Plain`` method. .. warning:: - The buffer only supports Box spaces. + The buffer only supports ``Box`` and ``Discrete`` spaces. Compared to the base buffer, the on-policy buffer stores extra data: diff --git a/omnisafe/common/buffer/vector_onpolicy_buffer.py b/omnisafe/common/buffer/vector_onpolicy_buffer.py index a920d8e6a..63c4d7fdf 100644 --- a/omnisafe/common/buffer/vector_onpolicy_buffer.py +++ b/omnisafe/common/buffer/vector_onpolicy_buffer.py @@ -30,7 +30,7 @@ class VectorOnPolicyBuffer(OnPolicyBuffer): stored in a list of on-policy buffers, each of which corresponds to one environment. .. warning:: - The buffer only supports Box spaces. + The buffer only supports ``Box`` and ``Discrete`` spaces. Args: obs_space (OmnisafeSpace): Observation space. diff --git a/omnisafe/configs/on-policy/CPO.yaml b/omnisafe/configs/on-policy/CPO.yaml index 012e29761..6645beaf9 100644 --- a/omnisafe/configs/on-policy/CPO.yaml +++ b/omnisafe/configs/on-policy/CPO.yaml @@ -131,7 +131,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -147,17 +147,17 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation obs_normalize: False - # entropy coefficient - entropy_coef: 0.01 # model configurations model_cfgs: # actor type, options: gaussian, gaussian_learning diff --git a/omnisafe/configs/on-policy/CPPOPID.yaml b/omnisafe/configs/on-policy/CPPOPID.yaml index bca1e5ae3..7875f9cd9 100644 --- a/omnisafe/configs/on-policy/CPPOPID.yaml +++ b/omnisafe/configs/on-policy/CPPOPID.yaml @@ -147,7 +147,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -163,11 +163,13 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation diff --git a/omnisafe/configs/on-policy/IPO.yaml b/omnisafe/configs/on-policy/IPO.yaml index 594e387bb..54bbe1e74 100644 --- a/omnisafe/configs/on-policy/IPO.yaml +++ b/omnisafe/configs/on-policy/IPO.yaml @@ -139,7 +139,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -155,11 +155,13 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation diff --git a/omnisafe/configs/on-policy/NaturalPG.yaml b/omnisafe/configs/on-policy/NaturalPG.yaml index d1ba57d92..8a446fb7d 100644 --- a/omnisafe/configs/on-policy/NaturalPG.yaml +++ b/omnisafe/configs/on-policy/NaturalPG.yaml @@ -131,7 +131,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -147,17 +147,17 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation obs_normalize: False - # entropy coefficient - entropy_coef: 0.01 # model configurations model_cfgs: # actor type, options: gaussian, gaussian_learning diff --git a/omnisafe/configs/on-policy/OnCRPO.yaml b/omnisafe/configs/on-policy/OnCRPO.yaml index 469820165..a47df1bad 100644 --- a/omnisafe/configs/on-policy/OnCRPO.yaml +++ b/omnisafe/configs/on-policy/OnCRPO.yaml @@ -133,7 +133,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -149,17 +149,17 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation obs_normalize: False - # entropy coefficient - entropy_coef: 0.01 # model configurations model_cfgs: # actor type, options: gaussian, gaussian_learning diff --git a/omnisafe/configs/on-policy/P3O.yaml b/omnisafe/configs/on-policy/P3O.yaml index f065f954f..651a47492 100644 --- a/omnisafe/configs/on-policy/P3O.yaml +++ b/omnisafe/configs/on-policy/P3O.yaml @@ -127,7 +127,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -143,11 +143,13 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation diff --git a/omnisafe/configs/on-policy/PCPO.yaml b/omnisafe/configs/on-policy/PCPO.yaml index 012e29761..6645beaf9 100644 --- a/omnisafe/configs/on-policy/PCPO.yaml +++ b/omnisafe/configs/on-policy/PCPO.yaml @@ -131,7 +131,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -147,17 +147,17 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation obs_normalize: False - # entropy coefficient - entropy_coef: 0.01 # model configurations model_cfgs: # actor type, options: gaussian, gaussian_learning diff --git a/omnisafe/configs/on-policy/PDO.yaml b/omnisafe/configs/on-policy/PDO.yaml index e5c39b91a..0a628935c 100644 --- a/omnisafe/configs/on-policy/PDO.yaml +++ b/omnisafe/configs/on-policy/PDO.yaml @@ -133,7 +133,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -149,11 +149,13 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation diff --git a/omnisafe/configs/on-policy/PPO.yaml b/omnisafe/configs/on-policy/PPO.yaml index 4bca1f265..313f89701 100644 --- a/omnisafe/configs/on-policy/PPO.yaml +++ b/omnisafe/configs/on-policy/PPO.yaml @@ -123,7 +123,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -139,11 +139,13 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation diff --git a/omnisafe/configs/on-policy/PPOLag.yaml b/omnisafe/configs/on-policy/PPOLag.yaml index b39e2b329..acfbd0ec9 100644 --- a/omnisafe/configs/on-policy/PPOLag.yaml +++ b/omnisafe/configs/on-policy/PPOLag.yaml @@ -133,7 +133,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -149,11 +149,13 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation diff --git a/omnisafe/configs/on-policy/PolicyGradient.yaml b/omnisafe/configs/on-policy/PolicyGradient.yaml index 18b19c1dd..6cb7780ee 100644 --- a/omnisafe/configs/on-policy/PolicyGradient.yaml +++ b/omnisafe/configs/on-policy/PolicyGradient.yaml @@ -121,7 +121,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -137,11 +137,13 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation diff --git a/omnisafe/configs/on-policy/RCPO.yaml b/omnisafe/configs/on-policy/RCPO.yaml index 001570e77..0ce722353 100644 --- a/omnisafe/configs/on-policy/RCPO.yaml +++ b/omnisafe/configs/on-policy/RCPO.yaml @@ -139,7 +139,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -155,17 +155,17 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation obs_normalize: False - # entropy coefficient - entropy_coef: 0.01 # model configurations model_cfgs: # actor type, options: gaussian, gaussian_learning diff --git a/omnisafe/configs/on-policy/TRPO.yaml b/omnisafe/configs/on-policy/TRPO.yaml index 997e167bc..69ef8177e 100644 --- a/omnisafe/configs/on-policy/TRPO.yaml +++ b/omnisafe/configs/on-policy/TRPO.yaml @@ -129,7 +129,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -145,17 +145,17 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation obs_normalize: False - # entropy coefficient - entropy_coef: 0.01 # model configurations model_cfgs: # actor type, options: gaussian, gaussian_learning diff --git a/omnisafe/configs/on-policy/TRPOLag.yaml b/omnisafe/configs/on-policy/TRPOLag.yaml index 001570e77..0ce722353 100644 --- a/omnisafe/configs/on-policy/TRPOLag.yaml +++ b/omnisafe/configs/on-policy/TRPOLag.yaml @@ -139,7 +139,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -155,17 +155,17 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation obs_normalize: False - # entropy coefficient - entropy_coef: 0.01 # model configurations model_cfgs: # actor type, options: gaussian, gaussian_learning diff --git a/omnisafe/configs/on-policy/TRPOPID.yaml b/omnisafe/configs/on-policy/TRPOPID.yaml index 42b71a545..9c645640e 100644 --- a/omnisafe/configs/on-policy/TRPOPID.yaml +++ b/omnisafe/configs/on-policy/TRPOPID.yaml @@ -155,7 +155,7 @@ CartPole-v1: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode @@ -171,17 +171,17 @@ Taxi-v3: # logger configurations logger_cfgs: # save model frequency - save_model_freq: 5 + save_model_freq: 10 # training configurations train_cfgs: # max time-step for each episode time_limit: 200 + # total number of steps to train + total_steps: 1000000 # algorithm configurations algo_cfgs: # normalize observation obs_normalize: False - # entropy coefficient - entropy_coef: 0.01 # model configurations model_cfgs: # actor type, options: gaussian, gaussian_learning diff --git a/omnisafe/envs/discrete_env.py b/omnisafe/envs/discrete_env.py index 92662288f..4e7544ccc 100644 --- a/omnisafe/envs/discrete_env.py +++ b/omnisafe/envs/discrete_env.py @@ -32,7 +32,9 @@ class DiscreteEnv(CMDP): """Discrete Gymnasium Environment. This environment only served as an example to integrate discrete action and - observation environment into OmniSafe. We support ``CartPole-v1`` and ``Taxi-v3``. + observation environment into OmniSafe. We support + `CartPole-v1 `_ + and `Taxi-v3 `_. The former is ``Box`` observation space and ``Discrete`` action space, while the latter is ``Discrete`` observation and ``Discrete`` action space. From d8c32992f72d5b3ff0914ba571d7f7967707b092 Mon Sep 17 00:00:00 2001 From: Gaiejj <524339208@qq.com> Date: Mon, 13 Nov 2023 10:50:36 +0800 Subject: [PATCH 3/6] chore: update test --- omnisafe/common/buffer/base.py | 24 ++++++++++++++-------- omnisafe/envs/core.py | 2 +- omnisafe/envs/discrete_env.py | 2 +- omnisafe/models/actor/categorical_actor.py | 9 +++++--- omnisafe/models/base.py | 8 +++----- tests/test_buffer.py | 17 ++++++++------- 6 files changed, 37 insertions(+), 25 deletions(-) diff --git a/omnisafe/common/buffer/base.py b/omnisafe/common/buffer/base.py index 6b32c3583..e7dd63da8 100644 --- a/omnisafe/common/buffer/base.py +++ b/omnisafe/common/buffer/base.py @@ -18,6 +18,7 @@ from abc import ABC, abstractmethod +import numpy as np import torch from gymnasium.spaces import Box, Discrete @@ -70,16 +71,23 @@ def __init__( ) -> None: """Initialize an instance of :class:`BaseBuffer`.""" self._device: torch.device = device - if isinstance(obs_space, Box): - obs_buf = torch.zeros((size, *obs_space.shape), dtype=torch.float32, device=device) - elif isinstance(obs_space, Discrete): - obs_buf = torch.zeros((size, 1), dtype=torch.float32, device=device) + + if isinstance(obs_space, (Box, Discrete)): + obs_buf = torch.zeros( + (size, int(np.array(obs_space.shape).prod())), + dtype=torch.float32, + device=device, + ) else: raise NotImplementedError - if isinstance(act_space, Box): - act_buf = torch.zeros((size, *act_space.shape), dtype=torch.float32, device=device) - elif isinstance(act_space, Discrete): - act_buf = torch.zeros((size), dtype=torch.float32, device=device) + + if isinstance(act_space, (Box, Discrete)): + act_buf = torch.zeros( + (size, int(np.array(act_space.shape).prod())), + dtype=torch.float32, + device=device, + ) + else: raise NotImplementedError diff --git a/omnisafe/envs/core.py b/omnisafe/envs/core.py index 137d8eb51..098c7f8e3 100644 --- a/omnisafe/envs/core.py +++ b/omnisafe/envs/core.py @@ -259,7 +259,7 @@ def reset( observation: The initial observation of the space. info: Some information logged by the environment. """ - return self._env.reset(seed=seed, options=options) + return self._env.reset(seed=seed) def set_seed(self, seed: int) -> None: """Set the seed for this env's random number generator(s). diff --git a/omnisafe/envs/discrete_env.py b/omnisafe/envs/discrete_env.py index 4e7544ccc..29e46f1e3 100644 --- a/omnisafe/envs/discrete_env.py +++ b/omnisafe/envs/discrete_env.py @@ -128,7 +128,7 @@ def step( info: Some information logged by the environment. """ obs, reward, terminated, truncated, info = self._env.step( - action.detach().cpu().numpy().tolist(), + action.detach().cpu().squeeze().numpy(), ) obs, reward, terminated, truncated = ( torch.as_tensor(x, dtype=torch.float32, device=self._device) diff --git a/omnisafe/models/actor/categorical_actor.py b/omnisafe/models/actor/categorical_actor.py index b5c7c65ec..f2375e7b8 100644 --- a/omnisafe/models/actor/categorical_actor.py +++ b/omnisafe/models/actor/categorical_actor.py @@ -16,6 +16,7 @@ from __future__ import annotations +import numpy as np import torch import torch.nn as nn from torch.distributions import Categorical, Distribution @@ -95,8 +96,10 @@ def predict(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tenso self._current_dist = self._distribution(obs=obs) self._after_inference = True if deterministic: - return torch.argmax(self._current_dist.logits, dim=0, keepdim=False) - return self._current_dist.sample() + action = torch.argmax(self._current_dist.logits, dim=-1, keepdim=True) + else: + action = self._current_dist.sample() + return action.view(-1, int(np.array(self._act_space.shape).prod())) def forward(self, obs: torch.Tensor) -> Distribution: """Forward method. @@ -125,4 +128,4 @@ def log_prob(self, act: torch.Tensor) -> torch.Tensor: """ assert self._after_inference, 'log_prob() should be called after predict() or forward()' self._after_inference = False - return self._current_dist.log_prob(act) + return self._current_dist.log_prob(act.squeeze()) diff --git a/omnisafe/models/base.py b/omnisafe/models/base.py index e46b63116..18c96b1d7 100644 --- a/omnisafe/models/base.py +++ b/omnisafe/models/base.py @@ -18,6 +18,7 @@ from abc import ABC, abstractmethod +import numpy as np import torch import torch.nn as nn from gymnasium import spaces @@ -62,11 +63,8 @@ def __init__( self._activation: Activation = activation self._hidden_sizes: list[int] = hidden_sizes self._after_inference: bool = False - - if isinstance(self._obs_space, spaces.Box) and len(self._obs_space.shape) == 1: - self._obs_dim: int = self._obs_space.shape[0] - elif isinstance(self._obs_space, spaces.Discrete): - self._obs_dim = 1 + if isinstance(self._obs_space, (spaces.Box, spaces.Discrete)): + self._obs_dim: int = int(np.array(self._obs_space.shape).prod()) else: raise NotImplementedError diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 0fee90a46..03a90d82f 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -14,8 +14,11 @@ # ============================================================================== """Test Buffers.""" +from __future__ import annotations + +import numpy as np import torch -from gymnasium.spaces import Box +from gymnasium.spaces import Box, Discrete import helpers from omnisafe.common.buffer import ( @@ -27,8 +30,8 @@ @helpers.parametrize( - obs_space=[Box(low=-1, high=1, shape=(1,))], - act_space=[Box(low=-1, high=1, shape=(1,))], + obs_space=[Box(low=-1, high=1, shape=(1,)), Discrete(n=5)], + act_space=[Box(low=-1, high=1, shape=(1,)), Discrete(n=5)], size=[100], gamma=[0.9], lam=[0.9], @@ -41,8 +44,8 @@ num_envs=[2], ) def test_vector_onpolicy_buffer( - obs_space: Box, - act_space: Box, + obs_space: Box | Discrete, + act_space: Box | Discrete, size: int, gamma: float, lam: float, @@ -82,8 +85,8 @@ def test_vector_onpolicy_buffer( assert vector_buffer.buffers is not [], f'vector_buffer.buffers is {vector_buffer.buffers}' # checking the store function - obs_dim = obs_space.shape[0] - act_dim = act_space.shape[0] + obs_dim = int(np.array(obs_space.shape).prod()) + act_dim = int(np.array(act_space.shape).prod()) for _ in range(size): obs = torch.rand((num_envs, obs_dim), dtype=torch.float32, device=device) act = torch.rand((num_envs, act_dim), dtype=torch.float32, device=device) From 69bd4639068e3e245d86f509d5843f9fc2581010 Mon Sep 17 00:00:00 2001 From: Gaiejj <524339208@qq.com> Date: Mon, 13 Nov 2023 10:50:58 +0800 Subject: [PATCH 4/6] chore: update test --- omnisafe/envs/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnisafe/envs/core.py b/omnisafe/envs/core.py index 098c7f8e3..137d8eb51 100644 --- a/omnisafe/envs/core.py +++ b/omnisafe/envs/core.py @@ -259,7 +259,7 @@ def reset( observation: The initial observation of the space. info: Some information logged by the environment. """ - return self._env.reset(seed=seed) + return self._env.reset(seed=seed, options=options) def set_seed(self, seed: int) -> None: """Set the seed for this env's random number generator(s). From a112651b778267205770b00fa613b4d338d2a38e Mon Sep 17 00:00:00 2001 From: Gaiejj <524339208@qq.com> Date: Mon, 13 Nov 2023 11:11:35 +0800 Subject: [PATCH 5/6] chore: update test --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 8789087ec..42b6ceaaa 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -211,6 +211,6 @@ def test_discrete_actor( _ = actor_discrete(obs) action = actor_discrete.predict(obs, deterministic) - assert action.shape == torch.Size([]), f'actor output shape is {action.shape}' + assert action.shape == torch.Size([1, 1]), f'actor output shape is {action.shape}' logp = actor_discrete.log_prob(action) assert logp.shape == torch.Size([]), f'actor log_prob shape is {logp.shape}' From e68bf5ca139c435474a7a5cc579524d41cff9966 Mon Sep 17 00:00:00 2001 From: Gaiejj <524339208@qq.com> Date: Wed, 22 Nov 2023 21:41:14 +0800 Subject: [PATCH 6/6] chore: clean the code --- omnisafe/envs/discrete_env.py | 1 - omnisafe/evaluator.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/omnisafe/envs/discrete_env.py b/omnisafe/envs/discrete_env.py index 29e46f1e3..60eb5ccb2 100644 --- a/omnisafe/envs/discrete_env.py +++ b/omnisafe/envs/discrete_env.py @@ -165,7 +165,6 @@ def reset( seed (int, optional): The random seed. Defaults to None. options (dict[str, Any], optional): The options for the environment. Defaults to None. - Returns: observation: Agent's observation of the current environment. info: Some information logged by the environment. diff --git a/omnisafe/evaluator.py b/omnisafe/evaluator.py index 4d25da862..e46051d5c 100644 --- a/omnisafe/evaluator.py +++ b/omnisafe/evaluator.py @@ -163,8 +163,8 @@ def __load_model_and_env( obs_normalizer.load_state_dict(model_params['obs_normalizer']) self._env = ObsNormalize(self._env, device=torch.device('cpu'), norm=obs_normalizer) if self._env.need_time_limit_wrapper: - self._cfgs['train_cfgs'].get('time_limit', 1000) - self._env = TimeLimit(self._env, device=torch.device('cpu'), time_limit=1000) + time_limit = self._cfgs['train_cfgs'].get('time_limit', 1000) + self._env = TimeLimit(self._env, device=torch.device('cpu'), time_limit=time_limit) if self._env.need_action_scale_wrapper: self._env = ActionScale(self._env, device=torch.device('cpu'), low=-1.0, high=1.0) @@ -287,7 +287,6 @@ def __load_model_and_env( high=np.hstack((observation_space.high, np.inf)), shape=(observation_space.shape[0] + 1,), ) - actor_type = self._cfgs['model_cfgs']['actor_type'] pi_cfg = self._cfgs['model_cfgs']['actor'] weight_initialization_mode = self._cfgs['model_cfgs']['weight_initialization_mode'] actor_builder = ActorBuilder( @@ -297,7 +296,7 @@ def __load_model_and_env( activation=pi_cfg['activation'], weight_initialization_mode=weight_initialization_mode, ) - self._actor = actor_builder.build_actor(actor_type) + self._actor = actor_builder.build_actor(self._cfgs['model_cfgs']['actor_type']) self._actor.load_state_dict(model_params['pi']) # pylint: disable-next=too-many-locals