未验证 提交 d003573f 编写于 作者: W wawltor 提交者: GitHub

add the error message check for the some operator

add the error message check for the some operator
上级 389a9a7e
......@@ -110,10 +110,12 @@ struct VisitDataArgMinMaxFunctor {
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_THROW(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
PADDLE_ENFORCE_LE(
x_dims.size(), 6,
platform::errors::InvalidArgument(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
......@@ -164,7 +166,8 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LT(
axis, x_dims.size(),
platform::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).", axis,
x_dims.size()));
const int& dtype = ctx->Attrs().Get<int>("dtype");
PADDLE_ENFORCE_EQ(
......@@ -192,10 +195,11 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
}
PADDLE_ENFORCE_LE(
all_element_num, INT_MAX,
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'.",
all_element_num, INT_MAX);
platform::errors::InvalidArgument(
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'.",
all_element_num, INT_MAX));
}
}
std::vector<int64_t> vec;
......
......@@ -52,7 +52,10 @@ class AssignFunctor {
template <typename T>
void operator()(const T &v) const {
PADDLE_THROW("Not support type for assign op %s", typeid(T).name());
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::PermissionDenied(
"Not support type for assign op with type %s", typeid(T).name()));
}
private:
......
......@@ -43,7 +43,11 @@ class OverflowOp : public framework::OperatorWithKernel {
} else if (x_var->IsType<framework::SelectedRows>()) {
dtype = x_var->Get<framework::SelectedRows>().value().type();
} else {
PADDLE_THROW("Cannot find the input data type by all input data");
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::InvalidArgument(
"The input type mismatch, the type of Input(X) must be Tensor or "
"SelectedRows, please check your input."));
}
return framework::OpKernelType(framework::proto::VarType::Type(dtype),
ctx.GetPlace());
......
......@@ -57,7 +57,11 @@ class OverflowKernel : public framework::OpKernel<T> {
auto& in = ctx.Input<framework::SelectedRows>("X")->value();
functor(in, out);
} else {
PADDLE_THROW("Unsupported input type.");
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::InvalidArgument(
"The input type mismatch, the type of Input(X) must be Tensor or "
"SelectedRows, please check your input."));
}
}
};
......
......@@ -22,8 +22,6 @@ class LinspaceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Start"),
"Input(Start) of LinspaceOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace");
......
......@@ -63,7 +63,10 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
int32_t num = n.data<int32_t>()[0];
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
"The num of linspace op should be larger "
"than 0, but received num is %d",
num));
out->Resize(framework::make_ddim({num}));
T* out_data = out->mutable_data<T>(context.GetPlace());
......
......@@ -46,7 +46,10 @@ class CPULinspaceKernel : public framework::OpKernel<T> {
T start = start_t.data<T>()[0];
T stop = stop_t.data<T>()[0];
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
"The num of linspace op should be larger "
"than 0, but received num is %d",
num));
out->Resize(framework::make_ddim({num}));
......
......@@ -60,7 +60,10 @@ class ScaleKernel : public framework::OpKernel<T> {
out->mutable_data<T>(in->place());
PADDLE_ENFORCE_EQ(in->dims(), out->dims(),
"in and out should have the same dim");
paddle::platform::errors::InvalidArgument(
"the input and output should have the same dim"
"but input dim is %s, output dim is %s",
in->dims(), out->dims()));
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册