From 408231671cca75298a6c54a97c7b49b98ab5e3e3 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Mon, 14 Oct 2019 20:23:03 +0800 Subject: [PATCH] update mul_op input data type check test=develop (#20598) --- python/paddle/fluid/layers/nn.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 91d82837bc4..659d77635e4 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14812,6 +14812,29 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): helper = LayerHelper("mul", **locals()) + if not isinstance(x, Variable): + raise TypeError( + "The type of 'x' in mul must be Variable, but received %s" % + (type(x))) + if not isinstance(y, Variable): + raise TypeError( + "The type of 'y' in mul must be Variable, but received %s" % + (type(y))) + if convert_dtype(x.dtype) in ['float16']: + warnings.warn( + "The data type of 'x' in mul only support float16 in GPU now.") + if convert_dtype(y.dtype) in ['float16']: + warnings.warn( + "The data type of 'y' in mul only support float16 in GPU now.") + if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']: + raise TypeError( + "The data type of 'x' in mul must be float16, float32 or float64, but received %s." + % (convert_dtype(x.dtype))) + if convert_dtype(y.dtype) not in ['float16', 'float32', 'float64']: + raise TypeError( + "The data type of 'y' in mul must be float16, float32 or float64, but received %s." + % (convert_dtype(y.dtype))) + if name is None: out = helper.create_variable_for_type_inference(dtype=x.dtype) else: -- GitLab