diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 424ae4524f7848853f974630c4e1b9a4bfb30b2e..dc436a70cb97db5adbd2f17dc8a7bdfa1f94517e 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9818,7 +9818,7 @@ def prelu(x, mode, param_attr=None, name=None): # [-0.2, 2., 3.] """ - check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'prelu') + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu') helper = LayerHelper('prelu', **locals()) if mode not in ['all', 'channel', 'element']: @@ -9843,7 +9843,7 @@ def prelu(x, mode, param_attr=None, name=None): alpha = helper.create_parameter( attr=helper.param_attr, shape=alpha_shape, - dtype='float32', + dtype=dtype, is_bias=False, default_initializer=Constant(0.25)) out = helper.create_variable_for_type_inference(dtype)