td3_vae.py 32.0 KB
Newer Older
P
puyuan1996 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
from .ddpg import DDPGPolicy
P
puyuan1996 已提交
14
from ding.model.template.vae import VanillaVAE
15 16
from ding.utils import RunningMeanStd
from torch.nn import functional as F
P
puyuan1996 已提交
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99


@POLICY_REGISTRY.register('td3-vae')
class TD3VAEPolicy(DDPGPolicy):
    r"""
    Overview:
        Policy class of TD3 with a VAE-based latent action space (TD3-VAE).

        The actor and critic of TD3 work with continuous latent actions. A ``VanillaVAE`` is trained on
        (original action, obs) pairs and additionally predicts the state residual ``next_obs - obs``; in collect
        mode the actor's latent action is decoded into an original env action by this VAE. Two tricks keep the
        latent space usable: latent space constraint (LSC) clamps collected latent actions into per-dim
        percentile bounds of recent VAE latents, and representation shift correction (RSC) relabels stored
        latent actions whose residual prediction error is large.

        Since DDPG and TD3 share many common things, we can easily derive this TD3
        class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and noise in model wrapper.

        https://arxiv.org/pdf/1802.09477.pdf

    Property:
        learn_mode, collect_mode, eval_mode

    Config:

    == ====================  ========    ==================  =================================   =======================
    ID Symbol                Type        Default Value       Description                         Other(Shape)
    == ====================  ========    ==================  =================================   =======================
    1  ``type``              str         td3-vae             | RL policy register name, refer    | this arg is optional,
                                                             | to registry ``POLICY_REGISTRY``   | a placeholder
    2  ``cuda``              bool        False               | Whether to use cuda for network   |
    3  | ``random_``         int         25000               | Number of randomly collected      | Default to 25000 for
       | ``collect_size``                                    | training samples in replay        | DDPG/TD3, 10000 for
       |                                                     | buffer when training starts.      | sac.
    4  | ``model.twin_``     bool        True                | Whether to use two critic         | Default True for TD3,
       | ``critic``                                          | networks or only one.             | Clipped Double
       |                                                     |                                   | Q-learning method in
       |                                                     |                                   | TD3 paper.
    5  | ``learn.learning``  float       1e-3                | Learning rate for actor           |
       | ``_rate_actor``                                     | network(aka. policy).             |
    6  | ``learn.learning``  float       1e-3                | Learning rates for critic         |
       | ``_rate_critic``                                    | network (aka. Q-network).         |
    7  | ``learn.actor_``    int         2                   | When critic network updates       | Default 2 for TD3, 1
       | ``update_freq``                                     | once, how many times will actor   | for DDPG. Delayed
       |                                                     | network update.                   | Policy Updates method
       |                                                     |                                   | in TD3 paper.
    8  | ``learn.noise``     bool        True                | Whether to add noise on target    | Default True for TD3,
       |                                                     | network's action.                 | False for DDPG.
       |                                                     |                                   | Target Policy Smoo-
       |                                                     |                                   | thing Regularization
       |                                                     |                                   | in TD3 paper.
    9  | ``learn.noise_``    dict        | dict(min=-0.5,    | Limit for range of target         |
       | ``range``                       |      max=0.5,)    | policy smoothing noise,           |
       |                                 |                   | aka. noise_clip.                  |
    10 | ``learn.-``         bool        False               | Determine whether to ignore       | Use ignore_done only
       | ``ignore_done``                                     | done flag.                        | in halfcheetah env.
    11 | ``learn.-``         float       0.005               | Used for soft update of the       | aka. Interpolation
       | ``target_theta``                                    | target network.                   | factor in polyak aver
       |                                                     |                                   | aging for target
       |                                                     |                                   | networks.
    12 | ``collect.-``       float       0.1                 | Used for add noise during co-     | Sample noise from dis
       | ``noise_sigma``                                     | llection, through controlling     | tribution, Ornstein-
       |                                                     | the sigma of distribution         | Uhlenbeck process in
       |                                                     |                                   | DDPG paper, Gaussian
       |                                                     |                                   | process in ours.
    13 | ``original_``       int         2                   | Action dim of the original env    |
       | ``action_shape``                                    | action space, reconstructed by    |
       |                                                     | the VAE.                          |
    14 | ``learn.learning``  float       1e-4                | Learning rate for VAE network.    |
       | ``_rate_vae``                                       |                                   |
    == ====================  ========    ==================  =================================   =======================
    """

    # You can refer to DDPG's default config for more details.
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        # Kept equal to the name this class is registered under (``@POLICY_REGISTRY.register('td3-vae')``).
        type='td3-vae',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool type) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        # Default False in TD3.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        # Default False in TD3.
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples(randomly collected) in replay buffer when training starts.
        # Default 25000 in DDPG/TD3.
        random_collect_size=25000,
        # (str) Action space type
        action_space='continuous',  # ['continuous', 'hybrid']
        # (bool) Whether to normalize rewards with the mean/std of each training batch.
        reward_batch_norm=False,
        # (int) Dim of the original env action, i.e. the action reconstructed by the VAE.
        # ``model.action_shape`` is the dim of the latent action that the actor and critic work with.
        original_action_shape=2,
        model=dict(
            # (bool) Whether to use two critic networks or only one.
            # Clipped Double Q-Learning for Actor-Critic in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
            # Default True for TD3, False for DDPG.
            twin_critic=True,
        ),
        learn=dict(
            multi_gpu=False,
            # How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy-> collect data -> ...
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float) Learning rates for actor network(aka. policy).
            learning_rate_actor=1e-3,
            # (float) Learning rates for critic network(aka. Q-network).
            learning_rate_critic=1e-3,
            # (float) Learning rate for the VAE that decodes latent actions into original actions.
            # Read by ``_init_learn``; without a default here every user config had to set it.
            learning_rate_vae=1e-4,
            # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
            # However, interaction with HalfCheetah always gets done with False,
            # Since we inplace done==True with done==False to keep
            # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``),
            # when the episode step is greater than max episode step.
            ignore_done=False,
            # (float type) target_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            # Default to 0.005.
            target_theta=0.005,
            # (float) discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (int) When critic network updates once, how many times will actor network update.
            # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
            # Default 1 for DDPG, 2 for TD3.
            actor_update_freq=2,
            # (bool) Whether to add noise on target network's action.
            # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
            # Default True for TD3, False for DDPG.
            noise=True,
            # (float) Sigma for smoothing noise added to target policy.
            noise_sigma=0.2,
            # (dict) Limit for range of target policy smoothing noise, aka. noise_clip.
            # Also used to clip the noise added to decoded original actions in collect mode.
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
        ),
        collect=dict(
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
            # (float) It is a must to add noise during collection. So here omits "noise" and only set "noise_sigma".
            noise_sigma=0.1,
        ),
        eval=dict(
            evaluator=dict(
                # (int) Evaluate every "eval_freq" training iterations.
                eval_freq=5000,
            ),
        ),
        other=dict(
            replay_buffer=dict(
                # (int) Maximum size of replay buffer.
                replay_buffer_size=100000,
            ),
        ),
    )

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init actor, critic and VAE optimizers, algorithm config, main and target models, the VAE that decodes
            latent actions into original actions, and the latent space constraint (LSC) bounds used in collect mode.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # actor and critic optimizer
        self._optimizer_actor = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_actor,
        )
        self._optimizer_critic = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_critic,
        )
        self._reward_batch_norm = self._cfg.reward_batch_norm

        self._gamma = self._cfg.learn.discount_factor
        self._actor_update_freq = self._cfg.learn.actor_update_freq
        self._twin_critic = self._cfg.model.twin_critic  # True for TD3, False for DDPG

        # main and target models
        self._target_model = copy.deepcopy(self._model)
        if self._cfg.action_space == 'hybrid':
            self._target_model = model_wrap(self._target_model, wrapper_name='hybrid_argmax_sample')
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        if self._cfg.learn.noise:
            self._target_model = model_wrap(
                self._target_model,
                wrapper_name='action_noise',
                noise_type='gauss',
                noise_kwargs={
                    'mu': 0.0,
                    'sigma': self._cfg.learn.noise_sigma
                },
                noise_range=self._cfg.learn.noise_range
            )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        if self._cfg.action_space == 'hybrid':
            self._learn_model = model_wrap(self._learn_model, wrapper_name='hybrid_argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

        self._forward_learn_cnt = 0  # count iterations

        # The actor/critic work with latent actions of dim ``model.action_shape``; the VAE maps them (together
        # with obs) to original env actions of dim ``original_action_shape``.
        latent_action_dim = self._cfg.model.action_shape
        # VanillaVAE(action_shape, obs_shape, latent_action_dim, hidden_size_list)
        self._vae_model = VanillaVAE(
            self._cfg.original_action_shape, self._cfg.model.obs_shape, latent_action_dim, [256]
        )
        # The VAE is a standalone module built here (not part of ``self._model``), so nothing else moves it to
        # the policy device. Learn/collect batches are moved to ``self._device`` in cuda mode, so move it
        # explicitly, before the optimizer is built on its parameters.
        if self._cuda:
            self._vae_model.to(self._device)
        self._optimizer_vae = Adam(
            self._vae_model.parameters(),
            lr=self._cfg.learn.learning_rate_vae,
        )
        # Running statistics of the VAE ``predict_loss`` (updated in the warm-up phase). Representation shift
        # correction (RSC) in the RL phase relabels samples whose prediction loss exceeds 4x this mean.
        self._running_mean_std_predict_loss = RunningMeanStd(epsilon=1e-4)
        # LSC bounds: per-dim clamp range for collected latent actions. They start as the tanh range [-1, 1]
        # and are refreshed from VAE latents in every VAE-phase learn step. They are sized by the latent dim
        # (instead of a hard-coded 6) so that any ``model.action_shape`` works in collect mode.
        self.c_percentage_bound_lower = -1 * torch.ones([latent_action_dim])
        self.c_percentage_bound_upper = torch.ones([latent_action_dim])

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode. Depending on flags carried by the batch, one call does
            exactly one of the following:

            - warm-up phase (``data[0]['warm_up'] is True``): train only the VAE and update the running \
                statistics of its ``predict_loss``;
            - VAE phase (collated ``data['vae_phase'][0]`` is True): train only the VAE and refresh the \
                latent space constraint (LSC) bounds used in collect mode;
            - RL phase (otherwise): relabel stale latent actions (RSC), run a TD3 critic update and, every \
                ``actor_update_freq`` iterations, an actor update.
        Arguments:
            - data (:obj:`List[dict]`): Transitions (collated inside this method), including at least \
                ['obs', 'action', 'reward', 'next_obs']; outside warm-up also ``vae_phase`` and ``latent_action``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses. \
                Entries that the current phase does not compute (e.g. ``td_error`` in the VAE phases) are ``0.0``.
        """
        # ====================
        # warm-up phase: train the VAE only
        # ====================
        if data[0].get('warm_up') is True:
            data = self._preprocess_learn_data(data)
            _, vae_loss = self._train_vae(data)
            # The RL phase uses the running mean of this loss as the RSC relabel threshold.
            self._running_mean_std_predict_loss.update(vae_loss['predict_loss'].unsqueeze(-1).cpu().detach().numpy())
            loss_dict = {
                'vae_loss': vae_loss['loss'].item(),
                'reconstruction_loss': vae_loss['reconstruction_loss'].item(),
                'kld_loss': vae_loss['kld_loss'].item(),
                'predict_loss': vae_loss['predict_loss'].item(),
                # For compatibility with the log keys of the RL phase.
                'actor_loss': 0.0,
                'critic_loss': 0.0,
                'critic_twin_loss': 0.0,
                'total_loss': 0.0,
            }
            return {
                'cur_lr_actor': self._optimizer_actor.defaults['lr'],
                'cur_lr_critic': self._optimizer_critic.defaults['lr'],
                'action': 0.0,
                'priority': 0.0,
                'td_error': 0.0,
                **loss_dict,
                'q_value': 0.0,
                'q_value_twin': 0.0,
            }

        self._forward_learn_cnt += 1
        # Collate AND move to device before anything is read from the batch, so that ``next_obs`` / ``reward``
        # below live on the same device as the models.
        data = self._preprocess_learn_data(data)

        # ====================
        # VAE phase: train the VAE only and refresh the LSC bounds
        # ====================
        if data['vae_phase'][0].item() is True:
            result, vae_loss = self._train_vae(data)
            # latent space constraint (LSC): per-dim 2% / 98% percentiles of tanh(z) over this batch become the
            # clamp range for latent actions in collect mode.
            # NOTE: using tanh is important
            sorted_latent = torch.tanh(result['z'].detach()).sort(dim=0)[0]
            batch_size = result['recons_action'].shape[0]
            self.c_percentage_bound_lower = sorted_latent[int(batch_size * 0.02), :]
            self.c_percentage_bound_upper = sorted_latent[int(batch_size * 0.98), :]
            return {
                'cur_lr_actor': self._optimizer_actor.defaults['lr'],
                'cur_lr_critic': self._optimizer_critic.defaults['lr'],
                'action': 0.0,
                'priority': 0.0,
                'td_error': 0.0,
                'vae_loss': vae_loss['loss'],
                'reconstruction_loss': vae_loss['reconstruction_loss'],
                'kld_loss': vae_loss['kld_loss'],
                'predict_loss': vae_loss['predict_loss'],
            }

        # ====================
        # RL phase: critic learn forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Representation shift correction (RSC): refresh drifted latent actions before the critic sees them.
        self._relabel_latent_action(data)
        next_obs = data['next_obs']
        reward = data['reward']
        if self._reward_batch_norm:
            reward = (reward - reward.mean()) / (reward.std() + 1e-8)

        # current q value
        q_value = self._learn_model.forward(
            {
                'obs': data['obs'],
                'action': data['latent_action']
            }, mode='compute_critic'
        )['q_value']
        q_value_dict = {}
        if self._twin_critic:
            q_value_dict['q_value'] = q_value[0].mean()
            q_value_dict['q_value_twin'] = q_value[1].mean()
        else:
            q_value_dict['q_value'] = q_value.mean()
        # target q value.
        with torch.no_grad():
            # NOTE: here next_actor_data['action'] is a latent action
            next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')
            next_actor_data['obs'] = next_obs
            target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']
        loss_dict = {}
        if self._twin_critic:
            # TD3: two critic networks
            target_q_value = torch.min(target_q_value[0], target_q_value[1])  # find min one as target q value
            # critic network1
            td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
            critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
            loss_dict['critic_loss'] = critic_loss
            # critic network2(twin network)
            td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
            critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
            loss_dict['critic_twin_loss'] = critic_twin_loss
            td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
        else:
            # DDPG: single critic network
            td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
            critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
            loss_dict['critic_loss'] = critic_loss
        # ================
        # critic update
        # ================
        self._optimizer_critic.zero_grad()
        for k in loss_dict:
            if 'critic' in k:
                loss_dict[k].backward()
        self._optimizer_critic.step()
        # ===============================
        # actor learn forward and update
        # ===============================
        # actor updates every ``self._actor_update_freq`` iters
        if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
            # NOTE: actor_data['action'] is a latent action
            actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
            actor_data['obs'] = data['obs']
            if self._twin_critic:
                actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
            else:
                actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
            loss_dict['actor_loss'] = actor_loss
            # actor update
            self._optimizer_actor.zero_grad()
            actor_loss.backward()
            self._optimizer_actor.step()
        # =============
        # after update
        # =============
        loss_dict['total_loss'] = sum(loss_dict.values())
        self._target_model.update(self._learn_model.state_dict())
        if self._cfg.action_space == 'hybrid':
            action_log_value = -1.  # TODO(nyz) better way to viz hybrid action
        else:
            action_log_value = data['action'].mean()

        return {
            'cur_lr_actor': self._optimizer_actor.defaults['lr'],
            'cur_lr_critic': self._optimizer_critic.defaults['lr'],
            'action': action_log_value,
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.abs().mean(),
            **loss_dict,
            **q_value_dict,
        }

    def _preprocess_learn_data(self, data: list) -> dict:
        r"""
        Overview:
            Collate the raw transitions with ``default_preprocess_learn`` and, in cuda mode, move the whole batch
            to ``self._device``, so every tensor read from the result is already on the policy device.
        Arguments:
            - data (:obj:`list`): Raw transitions as passed to ``_forward_learn``.
        Returns:
            - data (:obj:`dict`): Collated (and possibly device-placed) batch.
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._cfg.priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)
        return data

    def _train_vae(self, data: dict) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Overview:
            Run one gradient step of the VAE on a collated batch. The VAE is fed ``action`` and ``obs``; its
            targets are the original action and the state residual ``next_obs - obs``.
        Arguments:
            - data (:obj:`dict`): Collated, device-placed batch with ``obs``, ``next_obs`` and ``action``.
        Returns:
            - result (:obj:`dict`): VAE forward output, extended with the ``original_action`` and \
                ``true_residual`` targets consumed by ``loss_function``.
            - vae_loss (:obj:`dict`): Output of ``loss_function``, with at least ``loss``, \
                ``reconstruction_loss``, ``kld_loss`` and ``predict_loss``.
        """
        result = self._vae_model({'action': data['action'], 'obs': data['obs']})
        result['original_action'] = data['action']
        result['true_residual'] = data['next_obs'] - data['obs']
        vae_loss = self._vae_model.loss_function(result, kld_weight=0.01, predict_weight=0.01)  # TODO(pu): weight
        self._optimizer_vae.zero_grad()
        vae_loss['loss'].backward()
        self._optimizer_vae.step()
        return result, vae_loss

    def _relabel_latent_action(self, data: dict) -> None:
        r"""
        Overview:
            Representation shift correction (RSC). Stored latent actions may be stale with respect to the
            current VAE. For every sample whose residual-prediction MSE exceeds 4x the running mean of the
            warm-up ``predict_loss``, ``data['latent_action']`` is replaced in place by ``tanh(z)`` from a fresh
            VAE encoding of (action, obs).
        Arguments:
            - data (:obj:`dict`): Collated batch with ``obs``, ``next_obs``, ``action`` and ``latent_action``; \
                ``latent_action`` is modified in place.
        """
        # The VAE is not updated here, so no autograd graph is needed.
        with torch.no_grad():
            result = self._vae_model({'action': data['action'], 'obs': data['obs']})
            true_residual = data['next_obs'] - data['obs']
            # Per-sample MSE averaged over all non-batch elements, i.e. ``F.mse_loss(pred[i], true[i])`` for
            # every i, in one vectorized call instead of a Python loop with a device sync per sample.
            per_sample_loss = F.mse_loss(result['prediction_residual'], true_residual, reduction='none')
            per_sample_loss = per_sample_loss.reshape(per_sample_loss.shape[0], -1).mean(dim=1)
            # The running mean is fed one scalar loss per update, so it holds a single value; normalize it to a
            # Python float whatever container (float / ndarray / tensor) it is returned in.
            threshold = 4 * float(torch.as_tensor(self._running_mean_std_predict_loss.mean).reshape(-1)[0])
            relabel_mask = per_sample_loss > threshold
            # NOTE: using tanh is important, update latent_action using z
            data['latent_action'][relabel_mask] = torch.tanh(result['z'][relabel_mask])

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the learn-mode state dict: main and target models, actor/critic optimizers, and the VAE with
            its optimizer. The VAE must be saved too: the actor's latent actions are decoded by it, so resuming
            without it would silently change what every latent action means.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): State dicts keyed by component name.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_actor': self._optimizer_actor.state_dict(),
            'optimizer_critic': self._optimizer_critic.state_dict(),
            'vae_model': self._vae_model.state_dict(),
            'optimizer_vae': self._optimizer_vae.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Load a learn-mode state dict. The VAE entries (``vae_model``, ``optimizer_vae``) are optional, so
            checkpoints saved without them still load; the VAE then keeps its current weights.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): State dicts keyed by component name.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])
        self._optimizer_critic.load_state_dict(state_dict['optimizer_critic'])
        if 'vae_model' in state_dict:
            self._vae_model.load_state_dict(state_dict['vae_model'])
        if 'optimizer_vae' in state_dict:
            self._optimizer_vae.load_state_dict(state_dict['optimizer_vae'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Read the unroll length and build the collect model: the shared model wrapped with gaussian action
            noise (sigma from ``collect.noise_sigma``, no clipping), plus hybrid sampling for hybrid action spaces.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        gauss_kwargs = {'mu': 0.0, 'sigma': self._cfg.collect.noise_sigma}
        collect_model = model_wrap(
            self._model, wrapper_name='action_noise', noise_type='gauss', noise_kwargs=gauss_kwargs, noise_range=None
        )
        if self._cfg.action_space == 'hybrid':
            collect_model = model_wrap(collect_model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
        self._collect_model = collect_model
        self._collect_model.reset()

    def _forward_collect(self, data: dict, **kwargs) -> dict:
        r"""
        Overview:
            Forward function of collect mode. The noisy actor outputs a latent action. That latent action is
            clamped per dim into the LSC bounds ``[c_percentage_bound_lower, c_percentage_bound_upper]`` and
            decoded into an original action by the VAE. The decoded action is then perturbed with gaussian noise
            (clipped by ``learn.noise_range``) and finally clamped into ``[-1, 1]``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
        ReturnsKeys
            - necessary: ``action``, ``latent_action``
            - optional: ``logit``
        """
        env_ids = list(data.keys())
        obs_batch = default_collate(list(data.values()))
        if self._cuda:
            obs_batch = to_device(obs_batch, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(obs_batch, mode='compute_actor', **kwargs)
            latent_action = output['action']
            # NOTE: ``output['latent_action']`` is the very same tensor as ``latent_action``, so the in-place LSC
            # clamp below also bounds the latent action that ends up in the transition.
            output['latent_action'] = latent_action
            # latent space constraint (LSC)
            for dim in range(latent_action.shape[-1]):
                lower = self.c_percentage_bound_lower[dim].item()
                upper = self.c_percentage_bound_upper[dim].item()
                latent_action[:, dim].clamp_(lower, upper)
            # TODO(pu): decode into original hybrid actions; ``obs_batch`` is the collated obs.
            # This is very important to generate self.obs_encoding used in the decode phase.
            output['action'] = self._vae_model.decode_with_obs(latent_action, obs_batch)[0]

        # NOTE: add noise in the original actions
        from ding.rl_utils.exploration import GaussianNoise
        original_action = output['action']
        noise = GaussianNoise(mu=0.0, sigma=0.1)(original_action.shape, original_action.device)
        noise_range = self._cfg.learn.noise_range
        if noise_range is not None:
            noise = noise.clamp(noise_range['min'], noise_range['max'])
        original_action += noise
        self.action_range = {'min': -1, 'max': 1}
        output['action'] = original_action.clamp(self.action_range['min'], self.action_range['max'])

        if self._cuda:
            output = to_device(output, 'cpu')
        return dict(zip(env_ids, default_decollate(output)))

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> Dict[str, Any]:
        r"""
        Overview:
            Build a dict type transition from the obs before the step, the collect model output and the env
            timestep.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']; \
                ``latent_action`` is present when the output came from ``_forward_collect``.
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step, i.e. next_obs).
        Return:
            - transition (:obj:`Dict[str, Any]`): Dict type transition data. ``latent_action`` is the integer \
                placeholder 999 when ``model_output`` carries no latent action (random collection at the start).
        """
        has_latent_action = 'latent_action' in model_output
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': model_output['action'],
            'latent_action': model_output['latent_action'] if has_latent_action else 999,
            'reward': timestep.reward,
            'done': timestep.done,
        }
        if self._cfg.action_space == 'hybrid':
            transition['logit'] = model_output['logit']
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Convert the collected transitions into training samples by delegating to ``get_train_sample`` \
            with this policy's ``self._unroll_len``.
        Arguments:
            - data (:obj:`list`): Collected transitions.
        Returns:
            - samples (:obj:`Union[None, List[Any]]`): Whatever ``get_train_sample`` returns.
        """
        unroll_len = self._unroll_len
        return get_train_sample(data, unroll_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Wraps ``self._model`` with the ``base`` wrapper to form the eval model; unlike the learn and \
            collect models, no noise is added. For a hybrid action space a ``hybrid_argmax_sample`` wrapper \
            is stacked on top. The resulting model is reset before use.
        """
        wrapped = model_wrap(self._model, wrapper_name='base')
        if self._cfg.action_space == 'hybrid':
            wrapped = model_wrap(wrapped, wrapper_name='hybrid_argmax_sample')
        self._eval_model = wrapped
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect`` but without exploration noise.
            The actor's latent action is clamped per dimension (latent space constraint) and then decoded by
            ``self._vae_model.decode_with_obs`` into the action sent to the environment.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env, \
                keyed by the same env ids as ``data``.
        ReturnsKeys
            - necessary: ``action`` (decoded action), ``latent_action``
            - optional: ``logit``
        """
        env_ids = list(data.keys())
        obs_batch = default_collate(list(data.values()))
        if self._cuda:
            obs_batch = to_device(obs_batch, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(obs_batch, mode='compute_actor')
            latent = output['action']
            # NOTE: ``latent_action`` refers to the very same tensor object as ``latent``, so the in-place
            # clamp below is also visible through ``output['latent_action']``.
            output['latent_action'] = latent

            # latent space constraint (LSC): clamp each latent dimension to its own [lower, upper] bound
            for dim in range(latent.shape[-1]):
                latent[:, dim].clamp_(
                    self.c_percentage_bound_lower[dim].item(), self.c_percentage_bound_upper[dim].item()
                )

            # TODO(pu): decode into original hybrid actions. The observation batch is passed along because
            # the decode phase relies on it to generate ``self.obs_encoding``.
            output['action'] = self._vae_model.decode_with_obs(latent, obs_batch)[0]
        if self._cuda:
            output = to_device(output, 'cpu')
        return dict(zip(env_ids, default_decollate(output)))

    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return the name of this policy's default model together with its associated module path list.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): ``('qac', ['ding.model.template.qac'])``.
        """
        model_name = 'qac'
        module_paths = ['ding.model.template.qac']
        return model_name, module_paths

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return variables' names if variables are to used in monitor.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        ret = [
            'cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'q_value_twin',
650
            'action', 'td_error', 'vae_loss', 'reconstruction_loss', 'kld_loss', 'predict_loss'
P
puyuan1996 已提交
651 652 653
        ]
        if self._twin_critic:
            ret += ['critic_twin_loss']
P
puyuan1996 已提交
654
        return ret