Unverified commit 7d0e9034, authored by MRXLT, committed by GitHub

update error message for unstack op and lamb op; test=develop (#24487)

Parent: 6885d156
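Every hunk below makes the same change: a bare-string PADDLE_ENFORCE or PADDLE_THROW becomes an explicit comparison macro (PADDLE_ENFORCE_EQ/_GE/_LT/_GT or OP_INOUT_CHECK) whose message is wrapped in platform::errors::InvalidArgument and reports the offending value. As a rough illustration of the pattern, here is a minimal standalone sketch in plain C++; FormatMessage and EnforceEq are hypothetical stand-ins, not Paddle's real macros.

#include <cstdio>
#include <stdexcept>
#include <string>

// Hypothetical printf-style formatter, standing in for Paddle's
// error-message formatting. Illustration only.
template <typename... Args>
std::string FormatMessage(const char* fmt, Args... args) {
  char buf[512];
  std::snprintf(buf, sizeof(buf), fmt, args...);
  return std::string(buf);
}

// Hypothetical stand-in for PADDLE_ENFORCE_EQ: compare explicitly,
// then throw a typed error carrying the formatted message.
template <typename T, typename U>
void EnforceEq(const T& actual, const U& expected, const std::string& msg) {
  if (!(actual == expected)) throw std::invalid_argument(msg);
}

// Usage mirroring the lamb_op hunk below:
//   EnforceEq(param_var_is_lod_tensor, true,
//             FormatMessage("The Var(%s)'s type should be LoDTensor, "
//                           "but the received is %s", name, type_name));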
@@ -177,11 +177,12 @@ class LambOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto* param_var = ctx.InputVar("Param");
-    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
-                   "The Var(%s)'s type should be LoDTensor, "
-                   "but the received is %s",
-                   ctx.InputNames("Param").front(),
-                   framework::ToTypeName(param_var->Type()));
+    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
+                      platform::errors::InvalidArgument(
+                          "The Var(%s)'s type should be LoDTensor, "
+                          "but the received is %s",
+                          ctx.InputNames("Param").front(),
+                          framework::ToTypeName(param_var->Type())));
 
     using paddle::framework::LoDTensor;
@@ -274,7 +275,10 @@ class LambOpKernel : public framework::OpKernel<T> {
           row_numel, grad_merge.rows().size());
       for_range(moment_update_functor);
     } else {
-      PADDLE_THROW("Variable type not supported by lamb_op.");
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Variable type not supported by lamb_op. Expect LoDTensor or "
+          "SelectedRows, but got %s",
+          framework::ToTypeName(param_var->Type())));
     }
 
     // Update parameter
...
@@ -27,24 +27,35 @@ class UnStackOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must exist.");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "UnStack");
 
     int axis = ctx->Attrs().Get<int>("axis");
     int num = ctx->Attrs().Get<int>("num");
     auto x_dim = ctx->GetInputDim("X");
     int rank = x_dim.size();
-    PADDLE_ENFORCE_GE(
-        axis, -rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
-        rank);
-    PADDLE_ENFORCE_LT(
-        axis, rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
-        rank);
+    PADDLE_ENFORCE_GE(axis, -rank,
+                      platform::errors::InvalidArgument(
+                          "The attribute axis is out of range, it must be "
+                          "inside [-rank, rank), where rank = %d",
+                          rank));
+    PADDLE_ENFORCE_LT(axis, rank,
+                      platform::errors::InvalidArgument(
+                          "The attribute axis is out of range, it must be "
+                          "inside [-rank, rank), where rank = %d",
+                          rank));
     if (axis < 0) axis += rank;
 
     PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
-                      "Number of Outputs(Y) is wrong");
+                      platform::errors::InvalidArgument(
+                          "Number of Outputs(Y) is wrong. Got %d, but it must "
+                          "equal the attribute num, which is %d.",
+                          ctx->Outputs("Y").size(), static_cast<size_t>(num)));
     if (x_dim[axis] > 0) {
-      PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
+      PADDLE_ENFORCE_EQ(
+          num, x_dim[axis],
+          platform::errors::InvalidArgument(
+              "The attribute num is not equal to the length of axis %d "
+              "of Input(X). Expect %d but got %d.",
+              axis, x_dim[axis], num));
     }
 
     auto vec = framework::vectorize<int>(x_dim);
     vec.erase(vec.begin() + axis);
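For reference, the shape rule these checks protect, as a minimal standalone C++ sketch; the UnstackOutDim helper is hypothetical, not Paddle's InferShape API. Negative axes wrap into [0, rank), num must match the length of the chosen axis, and each output keeps every dimension of X except the unstacked one.

#include <stdexcept>
#include <vector>

// Hypothetical helper mirroring the UnStackOp checks above.
std::vector<int> UnstackOutDim(std::vector<int> x_dim, int axis, int num) {
  int rank = static_cast<int>(x_dim.size());
  // Same range check as the PADDLE_ENFORCE_GE/_LT pair.
  if (axis < -rank || axis >= rank)
    throw std::invalid_argument("axis must be inside [-rank, rank)");
  if (axis < 0) axis += rank;  // negative axes count from the end
  // num must equal the length of the unstacked axis (when it is known).
  if (x_dim[axis] > 0 && num != x_dim[axis])
    throw std::invalid_argument("num must equal the length of x_dim[axis]");
  x_dim.erase(x_dim.begin() + axis);  // each output drops that axis
  return x_dim;  // shape shared by all num outputs
}

For example, unstacking a [2, 3, 4] tensor along axis 1 (so num = 3) yields three outputs of shape [2, 4].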
@@ -89,24 +100,29 @@ class UnStackGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
-                      "Number of Inputs(Y@Grad) must be larger than 0");
-    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
-                      "Output(X@Grad) must exist.");
+                      platform::errors::InvalidArgument(
+                          "Number of Inputs(Y@Grad) must be larger than 0"));
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", "X",
+                   "UnStackGrad");
 
     auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
     for (size_t i = 1; i < input_dims.size(); ++i) {
       PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
-                        "Dims of all Inputs(Y@Grad) must be the same");
+                        platform::errors::InvalidArgument(
+                            "Dims of all Inputs(Y@Grad) must be the same"));
     }
 
     int axis = ctx->Attrs().Get<int>("axis");
     int rank = input_dims[0].size();
-    PADDLE_ENFORCE_GE(
-        axis, -(rank + 1),
-        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
-    PADDLE_ENFORCE_LT(
-        axis, rank + 1,
-        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
+    PADDLE_ENFORCE_GE(axis, -(rank + 1),
+                      platform::errors::InvalidArgument(
+                          "The attribute axis is out of range, it must be "
+                          "inside [-(rank+1), rank+1), where rank = %d",
+                          rank));
+    PADDLE_ENFORCE_LT(axis, rank + 1,
+                      platform::errors::InvalidArgument(
+                          "The attribute axis is out of range, it must be "
+                          "inside [-(rank+1), rank+1), where rank = %d",
+                          rank));
     if (axis < 0) axis += (rank + 1);
 
     auto vec = framework::vectorize<int>(input_dims[0]);
...
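The wider range [-(rank+1), rank+1) in the grad op reflects that the backward pass re-stacks n gradients of rank r into one tensor of rank r + 1, so axis addresses a slot in the stacked shape. A hedged sketch of that inverse shape rule, with a hypothetical StackedGradDim helper in plain C++:

#include <stdexcept>
#include <vector>

// Hypothetical helper, illustration only: computes the X@Grad shape
// from the shared shape of the Y@Grad inputs.
std::vector<int> StackedGradDim(std::vector<int> y_grad_dim, int axis, int n) {
  int rank = static_cast<int>(y_grad_dim.size());
  // Stacking adds one dimension, hence the [-(rank+1), rank+1) range above.
  if (axis < -(rank + 1) || axis >= rank + 1)
    throw std::invalid_argument("axis must be inside [-(rank+1), rank+1)");
  if (axis < 0) axis += rank + 1;  // wrap negative axis
  y_grad_dim.insert(y_grad_dim.begin() + axis, n);  // new stacked axis of size n
  return y_grad_dim;  // shape of X@Grad
}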