未验证 提交 40823167 编写于 作者: L lijianshe02 提交者: GitHub

update mul_op input data type check test=develop (#20598)

上级 5da8db61
......@@ -14812,6 +14812,29 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
helper = LayerHelper("mul", **locals())
if not isinstance(x, Variable):
raise TypeError(
"The type of 'x' in mul must be Variable, but received %s" %
(type(x)))
if not isinstance(y, Variable):
raise TypeError(
"The type of 'y' in mul must be Variable, but received %s" %
(type(y)))
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in mul only support float16 in GPU now.")
if convert_dtype(y.dtype) in ['float16']:
warnings.warn(
"The data type of 'y' in mul only support float16 in GPU now.")
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'x' in mul must be float16, float32 or float64, but received %s."
% (convert_dtype(x.dtype)))
if convert_dtype(y.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'y' in mul must be float16, float32 or float64, but received %s."
% (convert_dtype(y.dtype)))
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册