未验证 提交 52d0967a 编写于 作者: W wawltor 提交者: GitHub

Fix the bug of support fp16 in scale op, cherry-pick from #23793

Fix the support the float16 of scale op, add delete the raise case for fp16
上级 a3851278
......@@ -10713,10 +10713,6 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
"""
check_variable_and_dtype(
x, "x",
['float32', 'float64', 'uint8', 'int16', 'int32', 'in64', 'uint8'],
"scale")
if in_dygraph_mode():
_scale = scale.numpy().item(0) if isinstance(scale, Variable) else scale
out = core.ops.scale(x, 'scale',
......@@ -10724,6 +10720,10 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
float(bias), 'bias_after_scale', bias_after_scale)
return dygraph_utils._append_activation_in_dygraph(out)
check_variable_and_dtype(x, "x", [
'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
'uint8'
], "scale")
inputs = {'X': [x]}
attrs = {
'bias': float(bias),
......
......@@ -131,12 +131,6 @@ class TestScaleRaiseError(unittest.TestCase):
self.assertRaises(TypeError, test_type)
def test_dtype():
data = fluid.data(shape=[10], dtype="float16", name="input")
fluid.layers.scale(data)
self.assertRaises(TypeError, test_dtype)
# Add FP16 test
@unittest.skipIf(not core.is_compiled_with_cuda(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册