未验证 提交 808df649 编写于 作者: W wanghuancoder 提交者: GitHub

fix scale bug (#45705)

* fix scale bug
上级 ae6a8271
...@@ -177,13 +177,14 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): ...@@ -177,13 +177,14 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.scale(x, scale, float(bias), bias_after_scale) out = _C_ops.scale(x, scale, float(bias), bias_after_scale)
if _non_static_mode(): return dygraph_utils._append_activation_in_dygraph(out, act)
elif _in_legacy_dygraph():
_scale = scale.numpy().item(0) if isinstance(scale, Variable) else scale _scale = scale.numpy().item(0) if isinstance(scale, Variable) else scale
out = _legacy_C_ops.scale(x, 'scale', out = _legacy_C_ops.scale(x, 'scale',
float(_scale), 'bias', float(_scale), 'bias',
float(bias), 'bias_after_scale', bias_after_scale) float(bias), 'bias_after_scale', bias_after_scale)
return dygraph_utils._append_activation_in_dygraph(out) return dygraph_utils._append_activation_in_dygraph(out, act)
check_variable_and_dtype(x, "x", [ check_variable_and_dtype(x, "x", [
'float16', 'uint16', 'float32', 'float64', 'int8', 'int16', 'int32', 'float16', 'uint16', 'float32', 'float64', 'int8', 'int16', 'int32',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册