未验证 提交 4b646fe0 编写于 作者: B Bo Zhou 提交者: GitHub

Fix precision issue in the action_mapping function (#368)

* Update common.py

* fix precision issue

fix #367
上级 3a27f407
......@@ -37,7 +37,7 @@ def fetch_framework_var(attr_name):
core_var = scope.find_var(attr_name)
if core_var == None:
raise KeyError(
"Unable to find the variable:{}. Synchronize paramsters before initialization or attr_name does not exist."
"Unable to find the variable:{}. Synchronize parameters before initialization or attr_name does not exist."
.format(attr_name))
shape = core_var.get_tensor().shape()
framework_var = fluid.layers.create_parameter(
......
......@@ -15,7 +15,7 @@
import numpy as np
import scipy.signal
__all__ = ['calc_discount_sum_rewards', 'calc_gae']
__all__ = ['calc_discount_sum_rewards', 'calc_gae', 'action_mapping']
def calc_discount_sum_rewards(rewards, gamma):
......@@ -49,3 +49,24 @@ def calc_gae(rewards, values, next_value, gamma, lam):
tds = rewards + gamma * np.append(values[1:], next_value) - values
advantages = calc_discount_sum_rewards(tds, gamma * lam)
return advantages
def action_mapping(model_output_act, low_bound, high_bound):
""" mapping action space [-1, 1] of model output
to new action space [low_bound, high_bound].
Args:
model_output_act: np.array, which value is in [-1, 1]
low_bound: float, low bound of env action space
high_bound: float, high bound of env action space
Returns:
action: np.array, which value is in [low_bound, high_bound]
"""
assert np.all(((model_output_act<=1.0 + 1e-3), (model_output_act>=-1.0 - 1e-3))), \
'the action should be in range [-1.0, 1.0]'
assert high_bound > low_bound
action = low_bound + (model_output_act - (-1.0)) * (
(high_bound - low_bound) / 2.0)
action = np.clip(action, low_bound, high_bound)
return action
......@@ -18,9 +18,9 @@ import subprocess
import numpy as np
__all__ = [
'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3',
'MAX_INT32', '_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC',
'kill_process', 'get_fluid_version'
'has_func', 'to_str', 'to_byte', 'is_PY2', 'is_PY3', 'MAX_INT32',
'_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC', 'kill_process',
'get_fluid_version'
]
......@@ -37,27 +37,6 @@ def has_func(obj, fun):
return callable(check_fun)
def action_mapping(model_output_act, low_bound, high_bound):
""" mapping action space [-1, 1] of model output
to new action space [low_bound, high_bound].
Args:
model_output_act: np.array, which value is in [-1, 1]
low_bound: float, low bound of env action space
high_bound: float, high bound of env action space
Returns:
action: np.array, which value is in [low_bound, high_bound]
"""
assert np.all(((model_output_act<=1.0), (model_output_act>=-1.0))), \
'the action should be in range [-1.0, 1.0]'
assert high_bound > low_bound
action = low_bound + (model_output_act - (-1.0)) * (
(high_bound - low_bound) / 2.0)
action = np.clip(action, low_bound, high_bound)
return action
def to_str(byte):
""" convert byte to string in pytohn2/3
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册