diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 96553c0d9930bebb5e6dc10d3a11a6de1cef3f6f..93bc589b5101f665b62fec493350eee826be9344 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14729,13 +14729,19 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): raise TypeError( "The type of 'y' in mul must be Variable, but received %s" % (type(y))) - if convert_dtype(x.dtype) not in ['float32', 'float64']: + if convert_dtype(x.dtype) in ['float16']: + warnings.warn( + "The data type of 'x' in mul only support float16 in GPU now.") + if convert_dtype(y.dtype) in ['float16']: + warnings.warn( + "The data type of 'y' in mul only support float16 in GPU now.") + if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']: raise TypeError( - "The data type of 'x' in mul must be float32 or float64, but received %s." + "The data type of 'x' in mul must be float16, float32 or float64, but received %s." % (convert_dtype(x.dtype))) - if convert_dtype(y.dtype) not in ['float32', 'float64']: + if convert_dtype(y.dtype) not in ['float16', 'float32', 'float64']: raise TypeError( - "The data type of 'y' in softmax must be float32 or float64, but received %s." + "The data type of 'y' in mul must be float16, float32 or float64, but received %s." % (convert_dtype(y.dtype))) if name is None: