未验证 提交 246a9b6a 编写于 作者: G Guoxia Wang 提交者: GitHub

fix prelu float16 bug (#35584)

上级 523f46fe
...@@ -9818,7 +9818,7 @@ def prelu(x, mode, param_attr=None, name=None): ...@@ -9818,7 +9818,7 @@ def prelu(x, mode, param_attr=None, name=None):
# [-0.2, 2., 3.] # [-0.2, 2., 3.]
""" """
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'prelu') check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu')
helper = LayerHelper('prelu', **locals()) helper = LayerHelper('prelu', **locals())
if mode not in ['all', 'channel', 'element']: if mode not in ['all', 'channel', 'element']:
...@@ -9843,7 +9843,7 @@ def prelu(x, mode, param_attr=None, name=None): ...@@ -9843,7 +9843,7 @@ def prelu(x, mode, param_attr=None, name=None):
alpha = helper.create_parameter( alpha = helper.create_parameter(
attr=helper.param_attr, attr=helper.param_attr,
shape=alpha_shape, shape=alpha_shape,
dtype='float32', dtype=dtype,
is_bias=False, is_bias=False,
default_initializer=Constant(0.25)) default_initializer=Constant(0.25))
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册