From 5c41805dc96549c56e48a798ab5632b5a0922299 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Mon, 14 Oct 2019 10:41:12 +0800 Subject: [PATCH] update mul_op input data type check test=develop (#20552) --- python/paddle/fluid/layers/nn.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 96553c0d99..93bc589b51 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14729,13 +14729,19 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): raise TypeError( "The type of 'y' in mul must be Variable, but received %s" % (type(y))) - if convert_dtype(x.dtype) not in ['float32', 'float64']: + if convert_dtype(x.dtype) in ['float16']: + warnings.warn( + "The data type of 'x' in mul only support float16 in GPU now.") + if convert_dtype(y.dtype) in ['float16']: + warnings.warn( + "The data type of 'y' in mul only support float16 in GPU now.") + if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']: raise TypeError( - "The data type of 'x' in mul must be float32 or float64, but received %s." + "The data type of 'x' in mul must be float16, float32 or float64, but received %s." % (convert_dtype(x.dtype))) - if convert_dtype(y.dtype) not in ['float32', 'float64']: + if convert_dtype(y.dtype) not in ['float16', 'float32', 'float64']: raise TypeError( - "The data type of 'y' in softmax must be float32 or float64, but received %s." + "The data type of 'y' in mul must be float16, float32 or float64, but received %s." % (convert_dtype(y.dtype))) if name is None: -- GitLab