Multi-Discrete Example
Gym uses the term multi-discrete to describe environments that have multiple discrete action spaces. A simple example is shown as follows:
```python
import gym
from gym.spaces import Discrete, MultiDiscrete

# e.g. Nintendo Game Controller
# - Can be conceptualized as 3 discrete action spaces:
#   1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4
#   2) Button A:   Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
#   3) Button B:   Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
controller_space = MultiDiscrete([5, 2, 2])

# discrete action space env: Pong has 6 discrete actions
env = gym.make('PongNoFrameskip-v4')
assert env.action_space == Discrete(6)

# multi-discrete action space with the same number of joint actions
md_space = MultiDiscrete([2, 3])  # 6 = 2 * 3
```
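A multi-discrete space with component sizes nvec contains one joint action per combination of sub-actions, so it has prod(nvec) joint actions in total. The sketch below uses only the standard library (it does not call gym) and enumerates those joint actions for the two spaces above; the helper name joint_actions is hypothetical.

```python
from itertools import product
from math import prod


def joint_actions(nvec):
    """List every joint action of a multi-discrete space with component sizes nvec."""
    # Component i takes values 0..nvec[i]-1; a joint action picks one value per component.
    return list(product(*(range(n) for n in nvec)))


md_actions = joint_actions([2, 3])
print(len(md_actions))   # 6, the same count as Discrete(6) for Pong
print(md_actions[:4])    # [(0, 0), (0, 1), (0, 2), (1, 0)]

controller_actions = joint_actions([5, 2, 2])  # arrow keys x button A x button B
print(len(controller_actions), prod([5, 2, 2]))  # 20 20
```

Because itertools.product enumerates in row-major order, the joint action (1, 1) sits at index 4, which matches the 1 * 3 + 1 = 4 mapping described for the MultiDiscreteEnv wrapper in the next section.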
This page provides a simple multi-discrete environment together with a multi-discrete implementation of DQN.
Example
Here we provide a toy multi-discrete environment, derived by factorizing the single action space of an Atari game into the Cartesian product of multiple action spaces, e.g. 6 = 2 * 3.
```python
import gym
import numpy as np


class MultiDiscreteEnv(gym.Wrapper):
    """Map the actions from the factorized action spaces to the original single action space.

    :param gym.Env env: the environment to wrap.
    :param list action_shape: dims of the factorized action spaces.
    """

    def __init__(self, env, action_shape):
        super().__init__(env)
        # e.g. action_shape=[2, 3] -> [6, 3]; entry i+1 is the product of the dims after dimension i
        self.action_shape = np.flip(np.cumprod(np.flip(np.array(action_shape))))

    def step(self, action):
        """
        Overview:
            Step the environment with the given factorized actions.
        Arguments:
            - action (:obj:`list`): a list containing the action output of each discrete dimension,
              e.g. [1, 1] means 1 * 3 + 1 = 4 for a factorized action space 2 * 3 = 6
        """
        # Combine the two factorized sub-actions into one index of the original action space.
        action = action[0] * self.action_shape[1] + action[1]
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info
```
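The wrapper's step combines exactly two factorized dimensions, but the flipped cumulative-product array it stores already holds the place values needed for any number of dimensions: for action_shape [n0, n1, ..., nk], entry i + 1 equals n(i+1) * ... * nk. The sketch below is plain NumPy with hypothetical helper names (place_values, encode_action, decode_action are not part of DI-engine); it uses that array to flatten a factorized action as a mixed-radix number and to invert the mapping.

```python
import numpy as np


def place_values(action_shape):
    """Same array the wrapper stores: flip(cumprod(flip(shape)))."""
    return np.flip(np.cumprod(np.flip(np.array(action_shape))))


def encode_action(action, action_shape):
    """Map a factorized action [a0, ..., ak] to a single flat index (mixed radix)."""
    pv = place_values(action_shape)
    # a0 * (n1*...*nk) + a1 * (n2*...*nk) + ... + ak
    return int(np.dot(action[:-1], pv[1:]) + action[-1])


def decode_action(index, action_shape):
    """Inverse of encode_action: recover the per-dimension sub-actions."""
    action = []
    for n in reversed(action_shape):
        index, a = divmod(index, n)  # peel off the least significant digit
        action.append(a)
    return action[::-1]


print(place_values([2, 3]))                 # [6 3]
print(encode_action([1, 1], [2, 3]))        # 4, as in the wrapper docstring
print(decode_action(4, [2, 3]))             # [1, 1]
print(encode_action([4, 1, 0], [5, 2, 2]))  # 18 = 4*4 + 1*2 + 0
```

For two dimensions, encode_action reduces to action[0] * self.action_shape[1] + action[1], the expression used in step.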
Accordingly, the config of a multi-discrete experiment must change action_shape from an integer to the list of dims of the factorized action spaces (e.g. 6 becomes [2, 3]). This field appears in config.policy.model and in the information returned by env.info(). In addition, the key multi_discrete in config.env should be set to True so that the MultiDiscreteEnv wrapper is applied.
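As a rough illustration, the two config changes can be expressed on a plain Python dict. The layout below (env and policy.model sub-dicts, env_id, obs_shape) and the helper to_multi_discrete are assumptions made for this sketch, not DI-engine's actual config file; the action_shape reported by env.info() would need the same update.

```python
import math
from copy import deepcopy

# Hypothetical single-action-space config; only env.multi_discrete and
# policy.model.action_shape are the keys discussed in the text.
pong_dqn_config = dict(
    env=dict(env_id='PongNoFrameskip-v4'),
    policy=dict(model=dict(obs_shape=[4, 84, 84], action_shape=6)),
)


def to_multi_discrete(cfg, factorized_shape):
    """Return a copy of cfg with a factorized action_shape and multi_discrete enabled."""
    n_actions = cfg['policy']['model']['action_shape']
    # The factorized dims must cover exactly the original action set (e.g. 2 * 3 == 6).
    assert math.prod(factorized_shape) == n_actions, 'factorization must preserve the action count'
    new_cfg = deepcopy(cfg)  # leave the original config untouched
    new_cfg['policy']['model']['action_shape'] = list(factorized_shape)
    new_cfg['env']['multi_discrete'] = True
    return new_cfg


md_config = to_multi_discrete(pong_dqn_config, [2, 3])
print(md_config['policy']['model']['action_shape'])  # [2, 3]
print(md_config['env']['multi_discrete'])            # True
```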
Next, we provide a multi-discrete version of DQN. DQNMultiDiscretePolicy inherits from DQNPolicy and overrides only the _forward_learn method. In the Q-learning forward pass of the overridden method, each action space computes its own Q-value, action and TD loss from the shared (global) reward, following the same procedure as in the single-action-space case; the per-dimension losses and TD errors are then averaged.
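For a list-valued action_shape, the Q-learning forward pass relies on the model returning a list of per-dimension Q-values under 'logit' and a list of per-dimension greedy sub-actions under 'action' (it takes len(q_value) and indexes q_value[i] and target_q_action[i]). The toy NumPy stand-in below shows that output format; the class MultiHeadQ and its random linear heads are hypothetical and are not DI-engine's model.

```python
import numpy as np


class MultiHeadQ:
    """Hypothetical multi-head Q model: one linear head per factorized action dimension."""

    def __init__(self, obs_dim, action_shape, seed=0):
        rng = np.random.default_rng(seed)
        self.heads = [rng.standard_normal((obs_dim, n)) for n in action_shape]

    def forward(self, obs):
        # One (batch, n_i) array of Q-values per dimension, plus the greedy sub-action of each.
        logit = [obs @ w for w in self.heads]
        action = [q.argmax(axis=-1) for q in logit]
        return {'logit': logit, 'action': action}


model = MultiHeadQ(obs_dim=4, action_shape=[2, 3])
out = model.forward(np.ones((5, 4)))
print([q.shape for q in out['logit']])   # [(5, 2), (5, 3)]
print([a.shape for a in out['action']])  # [(5,), (5,)]
```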
```python
import torch
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error

# Excerpt from DQNMultiDiscretePolicy._forward_learn; `data` is the collated training batch.
# ====================
# Q-learning forward
# ====================
self._learn_model.train()
self._target_model.train()
# Current q value (main model)
q_value = self._learn_model.forward(data['obs'])['logit']
# Target q value
with torch.no_grad():
    target_q_value = self._target_model.forward(data['next_obs'])['logit']
    # Max q value action (main model)
    target_q_action = self._learn_model.forward(data['next_obs'])['action']

value_gamma = data.get('value_gamma')
if isinstance(q_value, torch.Tensor):
    # Single action space: same computation as DQNPolicy
    data_n = q_nstep_td_data(
        q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
    )
    loss, td_error_per_sample = q_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma)
else:
    # Multi-discrete: one entry per action dimension, all sharing the same reward
    action_num = len(q_value)
    loss, td_error_per_sample = [], []
    for i in range(action_num):
        td_data = q_nstep_td_data(
            q_value[i], target_q_value[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
            data['weight']
        )
        loss_, td_error_per_sample_ = q_nstep_td_error(td_data, self._gamma, nstep=self._nstep)
        loss.append(loss_)
        td_error_per_sample.append(td_error_per_sample_.abs())
    # Average the loss and the absolute TD error over the action dimensions
    loss = sum(loss) / (len(loss) + 1e-8)
    td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)
```
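To make the averaging in the multi-discrete branch concrete, here is a simplified NumPy version of the same loop. It is a sketch under stated assumptions: it uses a 1-step target with an MSE loss and ignores n-step returns, importance weights and value_gamma, so it is not a reimplementation of ding.rl_utils.q_nstep_td_error. As in the excerpt, the next action comes from the main model and its value from the target model. The function names are hypothetical, and the average divides by the plain count instead of count + 1e-8.

```python
import numpy as np


def td_loss_1step(q, next_q_target, act, next_act, reward, done, gamma):
    """1-step TD loss for one action dimension (MSE averaged over the batch)."""
    batch = np.arange(q.shape[0])
    q_sa = q[batch, act]  # Q(s, a_i) from the main model
    target = reward + gamma * (1.0 - done) * next_q_target[batch, next_act]
    td_error = q_sa - target
    return np.mean(td_error ** 2), np.abs(td_error)


def multi_discrete_td_loss(q_list, next_q_target_list, act_list, next_act_list, reward, done, gamma=0.99):
    """Each dimension gets its own TD loss against the shared reward; results are averaged."""
    losses, td_errors = [], []
    for q, nq, a, na in zip(q_list, next_q_target_list, act_list, next_act_list):
        loss_i, td_i = td_loss_1step(q, nq, a, na, reward, done, gamma)
        losses.append(loss_i)
        td_errors.append(td_i)
    return sum(losses) / len(losses), sum(td_errors) / len(td_errors)


B = 2
reward = np.array([1.0, 0.0])
done = np.array([0.0, 1.0])
q_list = [np.zeros((B, 2)), np.zeros((B, 3))]
next_q_target_list = [np.ones((B, 2)), np.ones((B, 3))]
act_list = [np.array([0, 1]), np.array([2, 0])]
next_act_list = [np.array([1, 1]), np.array([0, 2])]
loss, td = multi_discrete_td_loss(q_list, next_q_target_list, act_list, next_act_list, reward, done, gamma=0.5)
print(loss)  # 1.125
```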
For the complete code, you can refer to dizoo/common/policy/md_dqn.py