From 246a9b6a7a463810438239daad72b08cef515634 Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Fri, 10 Sep 2021 21:46:28 +0800 Subject: [PATCH] fix prelu float16 bug (#35584) --- python/paddle/fluid/layers/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 424ae4524f..dc436a70cb 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9818,7 +9818,7 @@ def prelu(x, mode, param_attr=None, name=None): # [-0.2, 2., 3.] """ - check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'prelu') + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu') helper = LayerHelper('prelu', **locals()) if mode not in ['all', 'channel', 'element']: @@ -9843,7 +9843,7 @@ def prelu(x, mode, param_attr=None, name=None): alpha = helper.create_parameter( attr=helper.param_attr, shape=alpha_shape, - dtype='float32', + dtype=dtype, is_bias=False, default_initializer=Constant(0.25)) out = helper.create_variable_for_type_inference(dtype) -- GitLab