未验证 提交 211f5b03 编写于 作者: L lijianshe02 提交者: GitHub

enhance mul_op input error message test=develop (#20414)

* enhance mul_op input error message test=develop
上级 a1f54a89
...@@ -32,10 +32,12 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -32,10 +32,12 @@ class MulOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null."); "Input(X) of MulOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
"Output(Out) of MulOp should not be null."); "Input(Y) of MulOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of MulOp should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
......
...@@ -14352,6 +14352,23 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): ...@@ -14352,6 +14352,23 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
helper = LayerHelper("mul", **locals()) helper = LayerHelper("mul", **locals())
if not isinstance(x, Variable):
raise TypeError(
"The type of 'x' in mul must be Variable, but received %s" %
(type(x)))
if not isinstance(y, Variable):
raise TypeError(
"The type of 'y' in mul must be Variable, but received %s" %
(type(y)))
if convert_dtype(x.dtype) not in ['float32', 'float64']:
raise TypeError(
"The data type of 'x' in mul must be float32 or float64, but received %s."
% (convert_dtype(x.dtype)))
if convert_dtype(y.dtype) not in ['float32', 'float64']:
raise TypeError(
"The data type of 'y' in softmax must be float32 or float64, but received %s."
% (convert_dtype(y.dtype)))
if name is None: if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
......
...@@ -18,6 +18,8 @@ import unittest ...@@ -18,6 +18,8 @@ import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestMulOp(OpTest): class TestMulOp(OpTest):
...@@ -49,6 +51,21 @@ class TestMulOp(OpTest): ...@@ -49,6 +51,21 @@ class TestMulOp(OpTest):
['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y')) ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
class TestMulOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of mul_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.mul, x1, x2)
# The input dtype of mul_op must be float32 or float64.
x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32")
x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.mul, x3, x4)
class TestMulOp2(OpTest): class TestMulOp2(OpTest):
def setUp(self): def setUp(self):
self.op_type = "mul" self.op_type = "mul"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册