Unverified commit 16833c62, authored by 蒲源, committed by GitHub

polish(pu): polish eps_greedy_multinomial_sample in model_wrapper (#154)

* polish(pu):polish eps_greedy_multinomial_sample in model_wrappers

* polish(pu): delete masac wrapper

* polish(pu): delete sql wrapper
Parent e6604502
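For context, the core change in this commit is optional temperature scaling before multinomial sampling: when an `alpha` keyword is supplied to the wrapper's forward call, the logit is divided by `alpha` before the action is drawn. Below is a minimal sketch of that idea, using `torch.distributions.Categorical` directly instead of DI-engine's internal `sample_action` helper (the function name `multinomial_sample` is only illustrative):

```python
import torch
from torch.distributions import Categorical


def multinomial_sample(logit: torch.Tensor, alpha: float = None) -> torch.Tensor:
    # If alpha (a temperature) is given, divide the logit by it before sampling:
    # alpha < 1 sharpens the softmax distribution, alpha > 1 flattens it.
    if alpha is not None:
        logit = logit / alpha
    return Categorical(logits=logit).sample()


# Draw one discrete action per batch element from a (batch, n_action) logit tensor.
actions = multinomial_sample(torch.randn(4, 6), alpha=0.5)
```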
......@@ -5,7 +5,6 @@ import numpy as np
import torch
from ding.torch_utils import get_tensor_data
from ding.rl_utils import create_noise_generator
from torch.distributions import Categorical
class IModelWrapper(ABC):
......@@ -210,12 +209,16 @@ class HybridArgmaxSampleWrapper(IModelWrapper):
class MultinomialSampleWrapper(IModelWrapper):
r"""
Overview:
Used to helper the model get the corresponding action from the output['logits']
Used to help the model get the corresponding action from the output['logit']
Interfaces:
register
"""
def forward(self, *args, **kwargs):
alpha = kwargs.pop('alpha', None)
output = self._model.forward(*args, **kwargs)
assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
logit = output['logit']
......@@ -227,7 +230,11 @@ class MultinomialSampleWrapper(IModelWrapper):
if isinstance(mask, torch.Tensor):
mask = [mask]
logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
action = [sample_action(logit=l) for l in logit]
if alpha is None:
action = [sample_action(logit=l) for l in logit]
else:
# Note that if alpha is passed in here, we will divide logit by alpha.
action = [sample_action(logit=l / alpha) for l in logit]
if len(action) == 1:
action, logit = action[0], logit[0]
output['action'] = action
......@@ -272,17 +279,21 @@ class EpsGreedySampleWrapper(IModelWrapper):
return output
class HybridEpsGreedySampleWrapper(IModelWrapper):
class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
r"""
Overview:
Epsilon greedy sampler used in collector_model to help balance exploration and exploitation.
In hybrid action space, i.e.{'action_type': discrete, 'action_args', continuous}
Epsilon greedy sampler coupled with multinomial sample used in collector_model
to help balance exploration and exploitation.
Interfaces:
register, forward
register
"""
def forward(self, *args, **kwargs):
eps = kwargs.pop('eps')
alpha = kwargs.pop('alpha', None)
output = self._model.forward(*args, **kwargs)
assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
logit = output['logit']
......@@ -299,7 +310,11 @@ class HybridEpsGreedySampleWrapper(IModelWrapper):
action = []
for i, l in enumerate(logit):
if np.random.random() > eps:
action.append(l.argmax(dim=-1))
if alpha is None:
action.append(sample_action(logit=l))
else:
# Note that if alpha is passed in here, we will divide logit by alpha.
action.append(sample_action(logit=l / alpha))
else:
if mask:
action.append(sample_action(prob=mask[i].float()))
......@@ -307,22 +322,21 @@ class HybridEpsGreedySampleWrapper(IModelWrapper):
action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
if len(action) == 1:
action, logit = action[0], logit[0]
output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
output['action'] = action
return output
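The rule this wrapper implements can be summarized as: with probability `1 - eps`, sample an action from the (optionally temperature-scaled) categorical distribution over the logits; otherwise, pick a uniformly random action. A rough, self-contained sketch of that decision rule, ignoring the action-mask handling shown above (the helper name `eps_greedy_multinomial` is only illustrative):

```python
import numpy as np
import torch
from torch.distributions import Categorical


def eps_greedy_multinomial(logit: torch.Tensor, eps: float, alpha: float = None) -> torch.Tensor:
    # Exploit branch: multinomial sample from softmax(logit / alpha), or softmax(logit) if alpha is None.
    if np.random.random() > eps:
        scaled = logit if alpha is None else logit / alpha
        return Categorical(logits=scaled).sample()
    # Explore branch: uniformly random action over the last (action) dimension.
    return torch.randint(0, logit.shape[-1], size=logit.shape[:-1])


action = eps_greedy_multinomial(torch.randn(4, 6), eps=0.1, alpha=1.0)
```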
class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
class HybridEpsGreedySampleWrapper(IModelWrapper):
r"""
Overview:
Epsilon greedy sampler coupled with multinomial sample used in collector_model
to help balance exploration and exploitation.
Epsilon greedy sampler used in collector_model to help balance exploration and exploitation.
In hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}
Interfaces:
register
register, forward
"""
def forward(self, *args, **kwargs):
eps = kwargs.pop('eps')
alpha = kwargs.pop('alpha')
output = self._model.forward(*args, **kwargs)
assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
logit = output['logit']
......@@ -339,12 +353,7 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
action = []
for i, l in enumerate(logit):
if np.random.random() > eps:
prob = torch.softmax(output['logit'] / alpha, dim=-1)
prob = prob / torch.sum(prob, 1, keepdim=True)
pi_action = torch.zeros(prob.shape)
pi_action = Categorical(prob)
pi_action = pi_action.sample()
action.append(pi_action)
action.append(l.argmax(dim=-1))
else:
if mask:
action.append(sample_action(prob=mask[i].float()))
......@@ -352,8 +361,8 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
if len(action) == 1:
action, logit = action[0], logit[0]
output['action'] = action
return output
output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
return output
class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper):
......@@ -387,11 +396,7 @@ class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper):
action = []
for i, l in enumerate(logit):
if np.random.random() > eps:
prob = torch.softmax(l, dim=-1)
prob = prob / torch.sum(prob, 1, keepdim=True)
pi_action = Categorical(prob)
pi_action = pi_action.sample()
action.append(pi_action)
action.append(sample_action(logit=l))
else:
if mask:
action.append(sample_action(prob=mask[i].float()))
......@@ -414,7 +419,7 @@ class EpsGreedySampleNGUWrapper(IModelWrapper):
def forward(self, *args, **kwargs):
kwargs.pop('eps')
eps = {i: 0.4 ** (1 + 8 * i / (args[0]['obs'].shape[0] - 1)) for i in range(args[0]['obs'].shape[0])} # TODO
eps = {i: 0.4 ** (1 + 8 * i / (args[0]['obs'].shape[0] - 1)) for i in range(args[0]['obs'].shape[0])}
output = self._model.forward(*args, **kwargs)
assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
logit = output['logit']
......
......@@ -19,7 +19,7 @@ from .common_utils import default_preprocess_learn
class SACDiscretePolicy(Policy):
r"""
Overview:
Policy class of Discrete SAC algorithm.
Policy class of discrete SAC algorithm.
Config:
== ==================== ======== ============= ================================= =======================
......@@ -407,7 +407,10 @@ class SACDiscretePolicy(Policy):
"""
self._unroll_len = self._cfg.collect.unroll_len
self._multi_agent = self._cfg.multi_agent
self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
# Empirically, we found that eps_greedy_multinomial_sample works better than multinomial_sample
# and eps_greedy_sample, and we do not divide the logit by alpha here.
# For details, please refer to ding/model/wrapper/model_wrappers.
self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample')
self._collect_model.reset()
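As a hedged usage illustration of the new collect wrapper: the `ToyModel` class and tensor shapes below are assumptions, not part of this repository; the only contract relied on is that the model's forward returns a dict with a 'logit' field, which the wrapper post-processes into an action.

```python
import torch
import torch.nn as nn
from ding.model import model_wrap  # assumed import path, matching its use in ding/policy/sac.py


class ToyModel(nn.Module):
    """A stand-in discrete-action model that returns {'logit': ...}, as the sample wrappers expect."""

    def __init__(self, obs_dim: int = 8, act_dim: int = 4):
        super().__init__()
        self.head = nn.Linear(obs_dim, act_dim)

    def forward(self, obs: torch.Tensor) -> dict:
        return {'logit': self.head(obs)}


collect_model = model_wrap(ToyModel(), wrapper_name='eps_greedy_multinomial_sample')
collect_model.reset()
with torch.no_grad():
    # eps controls the random-exploration probability; omitting alpha keeps the
    # logit unscaled, which is the behaviour SACDiscretePolicy uses here.
    out = collect_model.forward(torch.randn(4, 8), eps=0.05)
print(out['action'])  # one sampled discrete action per batch element
```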
def _forward_collect(self, data: dict, eps: float) -> dict:
......@@ -516,7 +519,7 @@ class SACDiscretePolicy(Policy):
class SACPolicy(Policy):
r"""
Overview:
Policy class of SAC algorithm.
Policy class of continuous SAC algorithm.
https://arxiv.org/pdf/1801.01290.pdf
......