未验证 提交 d003573f 编写于 作者: W wawltor 提交者: GitHub

Add error message checks for some operators

Add error message checks for some operators
上级 389a9a7e
...@@ -110,10 +110,12 @@ struct VisitDataArgMinMaxFunctor { ...@@ -110,10 +110,12 @@ struct VisitDataArgMinMaxFunctor {
CALL_ARG_MINMAX_FUNCTOR(6); CALL_ARG_MINMAX_FUNCTOR(6);
break; break;
default: default:
PADDLE_THROW( PADDLE_ENFORCE_LE(
"%s operator doesn't supports tensors whose ranks are greater " x_dims.size(), 6,
"than 6.", platform::errors::InvalidArgument(
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")); "%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
break; break;
#undef CALL_ARG_MINMAX_FUNCTOR #undef CALL_ARG_MINMAX_FUNCTOR
} }
...@@ -164,7 +166,8 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { ...@@ -164,7 +166,8 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
axis, x_dims.size(), axis, x_dims.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size())); "'axis'(%d) must be less than Rank(X)(%d) of Input(X).", axis,
x_dims.size()));
const int& dtype = ctx->Attrs().Get<int>("dtype"); const int& dtype = ctx->Attrs().Get<int>("dtype");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -192,10 +195,11 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { ...@@ -192,10 +195,11 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
} }
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
all_element_num, INT_MAX, all_element_num, INT_MAX,
"The element num of the argmin/argmax input at axis is " platform::errors::InvalidArgument(
"%d, is larger than int32 maximum value:%d, you must " "The element num of the argmin/argmax input at axis is "
"set the dtype of argmin/argmax to 'int64'.", "%d, is larger than int32 maximum value:%d, you must "
all_element_num, INT_MAX); "set the dtype of argmin/argmax to 'int64'.",
all_element_num, INT_MAX));
} }
} }
std::vector<int64_t> vec; std::vector<int64_t> vec;
......
...@@ -52,7 +52,10 @@ class AssignFunctor { ...@@ -52,7 +52,10 @@ class AssignFunctor {
template <typename T> template <typename T>
void operator()(const T &v) const { void operator()(const T &v) const {
PADDLE_THROW("Not support type for assign op %s", typeid(T).name()); PADDLE_ENFORCE_EQ(
true, false,
platform::errors::PermissionDenied(
"Not support type for assign op with type %s", typeid(T).name()));
} }
private: private:
......
...@@ -43,7 +43,11 @@ class OverflowOp : public framework::OperatorWithKernel { ...@@ -43,7 +43,11 @@ class OverflowOp : public framework::OperatorWithKernel {
} else if (x_var->IsType<framework::SelectedRows>()) { } else if (x_var->IsType<framework::SelectedRows>()) {
dtype = x_var->Get<framework::SelectedRows>().value().type(); dtype = x_var->Get<framework::SelectedRows>().value().type();
} else { } else {
PADDLE_THROW("Cannot find the input data type by all input data"); PADDLE_ENFORCE_EQ(
true, false,
platform::errors::InvalidArgument(
"The input type mismatch, the type of Input(X) must be Tensor or "
"SelectedRows, please check your input."));
} }
return framework::OpKernelType(framework::proto::VarType::Type(dtype), return framework::OpKernelType(framework::proto::VarType::Type(dtype),
ctx.GetPlace()); ctx.GetPlace());
......
...@@ -57,7 +57,11 @@ class OverflowKernel : public framework::OpKernel<T> { ...@@ -57,7 +57,11 @@ class OverflowKernel : public framework::OpKernel<T> {
auto& in = ctx.Input<framework::SelectedRows>("X")->value(); auto& in = ctx.Input<framework::SelectedRows>("X")->value();
functor(in, out); functor(in, out);
} else { } else {
PADDLE_THROW("Unsupported input type."); PADDLE_ENFORCE_EQ(
true, false,
platform::errors::InvalidArgument(
"The input type mismatch, the type of Input(X) must be Tensor or "
"SelectedRows, please check your input."));
} }
} }
}; };
......
...@@ -22,8 +22,6 @@ class LinspaceOp : public framework::OperatorWithKernel { ...@@ -22,8 +22,6 @@ class LinspaceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Start"),
"Input(Start) of LinspaceOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace"); OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace"); OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace"); OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace");
......
...@@ -63,7 +63,10 @@ class CUDALinspaceKernel : public framework::OpKernel<T> { ...@@ -63,7 +63,10 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
framework::TensorCopy(*num_t, platform::CPUPlace(), &n); framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
int32_t num = n.data<int32_t>()[0]; int32_t num = n.data<int32_t>()[0];
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0."); PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
"The num of linspace op should be larger "
"than 0, but received num is %d",
num));
out->Resize(framework::make_ddim({num})); out->Resize(framework::make_ddim({num}));
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
......
...@@ -46,7 +46,10 @@ class CPULinspaceKernel : public framework::OpKernel<T> { ...@@ -46,7 +46,10 @@ class CPULinspaceKernel : public framework::OpKernel<T> {
T start = start_t.data<T>()[0]; T start = start_t.data<T>()[0];
T stop = stop_t.data<T>()[0]; T stop = stop_t.data<T>()[0];
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0."); PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
"The num of linspace op should be larger "
"than 0, but received num is %d",
num));
out->Resize(framework::make_ddim({num})); out->Resize(framework::make_ddim({num}));
......
...@@ -60,7 +60,10 @@ class ScaleKernel : public framework::OpKernel<T> { ...@@ -60,7 +60,10 @@ class ScaleKernel : public framework::OpKernel<T> {
out->mutable_data<T>(in->place()); out->mutable_data<T>(in->place());
PADDLE_ENFORCE_EQ(in->dims(), out->dims(), PADDLE_ENFORCE_EQ(in->dims(), out->dims(),
"in and out should have the same dim"); paddle::platform::errors::InvalidArgument(
"the input and output should have the same dim"
"but input dim is %s, output dim is %s",
in->dims(), out->dims()));
auto eigen_out = framework::EigenVector<T>::Flatten(*out); auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in); auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册