未验证 提交 5c41805d 编写于 作者: L lijianshe02 提交者: GitHub

update mul_op input data type check test=develop (#20552)

上级 40effc61
...@@ -14729,13 +14729,19 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): ...@@ -14729,13 +14729,19 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
raise TypeError( raise TypeError(
"The type of 'y' in mul must be Variable, but received %s" % "The type of 'y' in mul must be Variable, but received %s" %
(type(y))) (type(y)))
if convert_dtype(x.dtype) not in ['float32', 'float64']: if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in mul only support float16 in GPU now.")
if convert_dtype(y.dtype) in ['float16']:
warnings.warn(
"The data type of 'y' in mul only support float16 in GPU now.")
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError( raise TypeError(
"The data type of 'x' in mul must be float32 or float64, but received %s." "The data type of 'x' in mul must be float16, float32 or float64, but received %s."
% (convert_dtype(x.dtype))) % (convert_dtype(x.dtype)))
if convert_dtype(y.dtype) not in ['float32', 'float64']: if convert_dtype(y.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError( raise TypeError(
"The data type of 'y' in softmax must be float32 or float64, but received %s." "The data type of 'y' in mul must be float16, float32 or float64, but received %s."
% (convert_dtype(y.dtype))) % (convert_dtype(y.dtype)))
if name is None: if name is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册