提交 2b181eda 编写于 作者: N niuyazhe

fix(nyz): rename sum keepdims to keepdim for compatibility and remove sql wrapper

上级 f087d2c7
......@@ -338,7 +338,7 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
for i, l in enumerate(logit):
if np.random.random() > eps:
prob = torch.softmax(output['logit'] / alpha, dim=-1)
prob = prob / torch.sum(prob, 1, keepdims=True)
prob = prob / torch.sum(prob, 1, keepdim=True)
pi_action = torch.zeros(prob.shape)
pi_action = Categorical(prob)
pi_action = pi_action.sample()
......@@ -386,7 +386,7 @@ class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper):
for i, l in enumerate(logit):
if np.random.random() > eps:
prob = torch.softmax(l, dim=-1)
prob = prob / torch.sum(prob, 1, keepdims=True)
prob = prob / torch.sum(prob, 1, keepdim=True)
pi_action = Categorical(prob)
pi_action = pi_action.sample()
action.append(pi_action)
......@@ -441,51 +441,6 @@ class EpsGreedySampleNGUWrapper(IModelWrapper):
return output
class EpsGreedySampleWrapperSql(IModelWrapper):
    r"""
    Overview:
        Epsilon greedy sampler coupled with multinomial (temperature-scaled softmax) sampling,
        used in collector_model to help balance exploration and exploitation.
    Interfaces:
        forward
    """

    def forward(self, *args, **kwargs):
        r"""
        Overview:
            Run the wrapped model's ``forward`` and store a sampled action under ``output['action']``.
            For each logit head a single random draw decides the branch for the whole head:
            if the draw is greater than ``eps`` the action is sampled from
            ``softmax(logit / alpha)``; otherwise it is sampled from the action mask (when the
            model returned one) or uniformly over the last logit dimension.
        Arguments:
            - eps: required keyword; threshold compared against ``np.random.random()``.
            - alpha: required keyword; divisor applied to the logits before the softmax.
            - args, kwargs: everything else is forwarded unchanged to the wrapped model.
        Returns:
            - output (dict): the model output with an extra ``'action'`` entry. The entry is a
              single value when there is one logit head, otherwise a list with one item per head.
        Raises:
            - KeyError: if ``eps`` or ``alpha`` is not passed as a keyword argument.
        """
        eps = kwargs.pop('eps')
        alpha = kwargs.pop('alpha')
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
        logit = output['logit']
        assert isinstance(logit, (torch.Tensor, list))
        # Normalise to a list of heads so single-head and multi-head models share one code path.
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            # In-place: entries whose mask value is 0 are lowered by 1e8 so that the softmax
            # below gives them (numerically) zero probability. This also modifies the tensors
            # held in ``output['logit']``.
            logit = [head_logit.sub_(1e8 * (1 - m)) for head_logit, m in zip(logit, mask)]
        else:
            mask = None
        action = []
        for i, head_logit in enumerate(logit):
            if np.random.random() > eps:
                # Use this head's own logit (not ``output['logit']``, which may be a list).
                prob = torch.softmax(head_logit / alpha, dim=-1)
                # Renormalise along the same axis as the softmax; ``keepdim`` is the documented
                # torch keyword (``keepdims`` is not accepted by older releases).
                prob = prob / torch.sum(prob, -1, keepdim=True)
                action.append(Categorical(prob).sample())
            elif mask is not None:
                # NOTE(review): sample_action is defined elsewhere; presumably it draws from the
                # given (unnormalised) probabilities - confirm against its definition.
                action.append(sample_action(prob=mask[i].float()))
            else:
                action.append(torch.randint(0, head_logit.shape[-1], size=head_logit.shape[:-1]))
        # A single head returns its action directly rather than a one-element list.
        if len(action) == 1:
            action = action[0]
        output['action'] = action
        return output
class ActionNoiseWrapper(IModelWrapper):
r"""
Overview:
......@@ -629,7 +584,6 @@ wrapper_name_map = {
'hybrid_argmax_sample': HybridArgmaxSampleWrapper,
'eps_greedy_sample': EpsGreedySampleWrapper,
'eps_greedy_sample_ngu': EpsGreedySampleNGUWrapper,
'eps_greedy_sample_sql': EpsGreedySampleWrapperSql,
'eps_greedy_multinomial_sample': EpsGreedyMultinomialSampleWrapper,
'hybrid_eps_greedy_sample': HybridEpsGreedySampleWrapper,
'hybrid_eps_greedy_multinomial_sample': HybridEpsGreedyMultinomialSampleWrapper,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册