Unverified commit 7d0e9034 authored by MRXLT, committed by GitHub

update error message for unstack op and lamb op; test=develop (#24487)

Parent 6885d156
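The diff below migrates checks from the legacy two-argument `PADDLE_ENFORCE(cond, fmt, ...)` form to the comparison macros (`PADDLE_ENFORCE_EQ` / `_GE` / `_LT`) paired with an explicit error object such as `platform::errors::InvalidArgument(fmt, ...)`, and to `OP_INOUT_CHECK` for input/output existence. As a rough, Paddle-free sketch of what such a compare-and-throw macro amounts to (the stand-in names `ENFORCE_EQ` and `InvalidArgument` below are invented for illustration and are not Paddle's implementation):

```cpp
// Not Paddle's implementation -- a minimal, self-contained illustration of the
// "compare two values, on mismatch throw a typed, formatted error" pattern
// that PADDLE_ENFORCE_EQ(a, b, platform::errors::InvalidArgument(...)) follows.
#include <cstdio>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for platform::errors::InvalidArgument.
struct InvalidArgument {
  std::string msg;
  template <typename... Args>
  explicit InvalidArgument(const char* fmt, Args... args) {
    char buf[512];
    std::snprintf(buf, sizeof(buf), fmt, args...);
    msg = buf;
  }
};

// Hypothetical stand-in for PADDLE_ENFORCE_EQ(a, b, error).
#define ENFORCE_EQ(a, b, error)                 \
  do {                                          \
    if (!((a) == (b))) {                        \
      throw std::invalid_argument((error).msg); \
    }                                           \
  } while (0)

int main() {
  int num = 3;
  int dim_along_axis = 4;
  try {
    ENFORCE_EQ(num, dim_along_axis,
               InvalidArgument("Expect %d but got %d.", dim_along_axis, num));
  } catch (const std::invalid_argument& e) {
    std::printf("check failed: %s\n", e.what());  // Expect 4 but got 3.
  }
  return 0;
}
```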
@@ -177,11 +177,12 @@ class LambOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto* param_var = ctx.InputVar("Param");
-    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
-                   "The Var(%s)'s type should be LoDTensor, "
-                   "but the received is %s",
-                   ctx.InputNames("Param").front(),
-                   framework::ToTypeName(param_var->Type()));
+    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
+                      platform::errors::InvalidArgument(
+                          "The Var(%s)'s type should be LoDTensor, "
+                          "but the received is %s",
+                          ctx.InputNames("Param").front(),
+                          framework::ToTypeName(param_var->Type())));
     using paddle::framework::LoDTensor;
@@ -274,7 +275,10 @@ class LambOpKernel : public framework::OpKernel<T> {
           row_numel, grad_merge.rows().size());
       for_range(moment_update_functor);
     } else {
-      PADDLE_THROW("Variable type not supported by lamb_op.");
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Variable type not supported by lamb_op. Expect LoDTensor or "
+          "SelectedRows, but got %s",
+          framework::ToTypeName(param_var->Type())));
     }
     // Update parameter
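In the lamb_op kernel this throw sits at the end of a type dispatch: the parameter/gradient variable is handled either as a dense LoDTensor or as sparse SelectedRows, and any other variable type now raises InvalidArgument naming the offending type instead of a bare message string. A self-contained sketch of that dispatch-or-throw shape, with all types and names invented for illustration:

```cpp
// Illustration only -- DenseTensor/SelectedRows/Var/Update are invented names,
// not Paddle types. Shows the "handle each supported variable type, otherwise
// throw InvalidArgument describing what was actually received" structure.
#include <stdexcept>
#include <string>
#include <variant>

struct DenseTensor {};   // stands in for framework::LoDTensor
struct SelectedRows {};  // stands in for framework::SelectedRows
using Var = std::variant<DenseTensor, SelectedRows, std::string>;

void Update(const Var& param) {
  if (std::holds_alternative<DenseTensor>(param)) {
    // dense update path
  } else if (std::holds_alternative<SelectedRows>(param)) {
    // sparse update path
  } else {
    throw std::invalid_argument(
        "Variable type not supported by lamb_op. Expect LoDTensor or "
        "SelectedRows, but got variant alternative #" +
        std::to_string(param.index()));
  }
}

int main() {
  Update(Var{DenseTensor{}});             // supported: no throw
  try {
    Update(Var{std::string("weights")});  // unsupported: throws
  } catch (const std::invalid_argument&) {
  }
  return 0;
}
```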
@@ -27,24 +27,35 @@ class UnStackOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must exist.");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "UnStack");
     int axis = ctx->Attrs().Get<int>("axis");
     int num = ctx->Attrs().Get<int>("num");
     auto x_dim = ctx->GetInputDim("X");
     int rank = x_dim.size();
-    PADDLE_ENFORCE_GE(
-        axis, -rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
-        rank);
-    PADDLE_ENFORCE_LT(
-        axis, rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
-        rank);
+    PADDLE_ENFORCE_GE(axis, -rank,
+                      platform::errors::InvalidArgument(
+                          "The attribute axis is out of range, it must be "
+                          "inside [-rank, rank), where rank = %d",
+                          rank));
+    PADDLE_ENFORCE_LT(axis, rank,
+                      platform::errors::InvalidArgument(
+                          "The attribute axis is out of range, it must be "
+                          "inside [-rank, rank), where rank = %d",
+                          rank));
     if (axis < 0) axis += rank;
     PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
-                      "Number of Outputs(Y) is wrong");
+                      platform::errors::InvalidArgument(
+                          "Number of Outputs(Y) is wrong. Got %d , but it must "
+                          "equal to attribute num which is %d.",
+                          ctx->Outputs("Y").size(), static_cast<size_t>(num)));
     if (x_dim[axis] > 0) {
-      PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
+      PADDLE_ENFORCE_EQ(
+          num, x_dim[axis],
+          platform::errors::InvalidArgument(
+              "The number of attribute num is not equal to the length of the "
+              "%d axis of Input(X). Expect %d but got %d.",
+              axis, x_dim[axis], num));
     }
     auto vec = framework::vectorize<int>(x_dim);
     vec.erase(vec.begin() + axis);
@@ -89,24 +100,29 @@ class UnStackGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
-                      "Number of Inputs(Y@Grad) must be larger than 0");
-    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
-                      "Output(X@Grad) must exist.");
+                      platform::errors::InvalidArgument(
+                          "Number of Inputs(Y@Grad) must be larger than 0"));
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", "X",
+                   "UnStackGrad");
     auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
     for (size_t i = 1; i < input_dims.size(); ++i) {
       PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
-                        "Dims of all Inputs(Y@Grad) must be the same");
+                        platform::errors::InvalidArgument(
+                            "Dims of all Inputs(Y@Grad) must be the same"));
     }
     int axis = ctx->Attrs().Get<int>("axis");
     int rank = input_dims[0].size();
-    PADDLE_ENFORCE_GE(
-        axis, -(rank + 1),
-        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
-    PADDLE_ENFORCE_LT(
-        axis, rank + 1,
-        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
+    PADDLE_ENFORCE_GE(axis, -(rank + 1),
+                      platform::errors::InvalidArgument(
+                          "The attribute axis is out of range, it must be "
+                          "inside [-(rank+1), rank+1), where rank = %d",
+                          rank));
+    PADDLE_ENFORCE_LT(axis, rank + 1,
+                      platform::errors::InvalidArgument(
+                          "The attribute axis is out of range, it must be "
+                          "inside [-(rank+1), rank+1), where rank = %d",
+                          rank));
     if (axis < 0) axis += (rank + 1);
     auto vec = framework::vectorize<int>(input_dims[0]);
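The unstack checks above encode the usual negative-axis convention: a valid axis lies in [-rank, rank), a negative value is normalized by adding rank, and the num attribute must then match the extent of the selected dimension. For the grad op, which stacks the Y gradients back along a new dimension, there are rank+1 valid insertion positions, hence the wider range [-(rank+1), rank+1). A standalone sketch of that validation logic, with the helper name invented for illustration:

```cpp
// Illustration only -- NormalizeUnstackAxis is an invented helper, not the
// operator code. It mirrors the checks: axis in [-rank, rank), negative axis
// wraps by adding rank, and num must equal the extent of the chosen dimension.
#include <stdexcept>
#include <vector>

int NormalizeUnstackAxis(const std::vector<int>& x_dim, int axis, int num) {
  const int rank = static_cast<int>(x_dim.size());
  if (axis < -rank || axis >= rank) {
    throw std::invalid_argument("axis must be inside [-rank, rank)");
  }
  if (axis < 0) axis += rank;  // e.g. axis = -1 selects the last dimension
  if (x_dim[axis] > 0 && num != x_dim[axis]) {
    throw std::invalid_argument("num must equal the length of the axis dim");
  }
  return axis;
}

int main() {
  std::vector<int> x_dim = {2, 3, 4};
  // axis = -2 normalizes to 1, so unstacking yields x_dim[1] = 3 outputs.
  int axis = NormalizeUnstackAxis(x_dim, /*axis=*/-2, /*num=*/3);
  return axis == 1 ? 0 : 1;
}
```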