未验证 提交 9704582e 编写于 作者: T tangwei12 提交者: GitHub

fix op error (#27599)

* fix error

* fix error

* fix error

* merge develop
上级 c68a0313
...@@ -24,10 +24,9 @@ class SplitByrefOp : public framework::OperatorWithKernel { ...@@ -24,10 +24,9 @@ class SplitByrefOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Ids", "SplitByrefOp");
"Input(X) of SplitOp should not be null."); OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "SplitByrefOp");
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
"Outputs(Out) of SplitOp should not be empty.");
auto in_dims = ctx->GetInputDim("X"); auto in_dims = ctx->GetInputDim("X");
auto outs_names = ctx->Outputs("Out"); auto outs_names = ctx->Outputs("Out");
size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num")); size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
...@@ -51,9 +50,10 @@ class SplitByrefOp : public framework::OperatorWithKernel { ...@@ -51,9 +50,10 @@ class SplitByrefOp : public framework::OperatorWithKernel {
outs_dims.push_back(dim); outs_dims.push_back(dim);
} }
} else if (sections.size() > 0) { } else if (sections.size() > 0) {
PADDLE_ENFORCE_EQ(sections.size(), outs_number, PADDLE_ENFORCE_EQ(
"tensor split sections size" sections.size(), outs_number,
"should be equal to output size."); platform::errors::InvalidArgument("tensor split sections size"
"should be equal to output size"));
for (size_t i = 0; i < outs_number; ++i) { for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims; auto dim = in_dims;
dim[0] = sections[i]; dim[0] = sections[i];
......
...@@ -52,13 +52,19 @@ class SplitIdsOp : public framework::OperatorWithKernel { ...@@ -52,13 +52,19 @@ class SplitIdsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Ids"), "SplitIdsOp must have input Ids."); OP_INOUT_CHECK(ctx->HasInputs("Ids"), "Input", "Ids", "SplitIdsOp");
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must have output Out."); OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "SplitIdsOp");
auto ids_var_type = ctx->GetInputsVarType("Ids").front(); auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputsDim("Ids"); auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2); PADDLE_ENFORCE_EQ(
ids_dims[0].size(), 2,
platform::errors::InvalidArgument(
"ShapeError: The dimensions of the 'split_ids' must be 2. "
"But received split_ids's dimensions = %d, "
"split_ids's shape = [%s].",
ids_dims[0].size(), ids_dims[0]));
} }
} }
......
...@@ -30,12 +30,17 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -30,12 +30,17 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) { if (!platform::is_cpu_place(place)) {
PADDLE_THROW("SplitIds do not support GPU kernel"); PADDLE_THROW(platform::errors::Unimplemented(
"SplitIds do not support GPU kernel"));
} }
const auto ids_vars = ctx.MultiInputVar("Ids"); const auto ids_vars = ctx.MultiInputVar("Ids");
PADDLE_ENFORCE_GT(ids_vars.size(), 0, "The number of Ids should > 0"); PADDLE_ENFORCE_GT(
ids_vars.size(), 0,
platform::errors::InvalidArgument(
ids_vars.size(), 0, "The number of Ids expected > 0, but got %d",
ids_vars.size()));
auto *ids_var = ids_vars[0]; auto *ids_var = ids_vars[0];
if (ids_var->IsType<framework::LoDTensor>()) { if (ids_var->IsType<framework::LoDTensor>()) {
...@@ -83,9 +88,6 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -83,9 +88,6 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
} else if (ids_var->IsType<framework::SelectedRows>()) { } else if (ids_var->IsType<framework::SelectedRows>()) {
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids"); const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
auto &ids_dims = ids_selected_rows->value().dims(); auto &ids_dims = ids_selected_rows->value().dims();
PADDLE_ENFORCE_EQ(ids_dims[0],
static_cast<int64_t>(ids_selected_rows->rows().size()),
"");
const T *ids_data = ids_selected_rows->value().data<T>(); const T *ids_data = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows(); const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out"); auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
...@@ -114,9 +116,9 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -114,9 +116,9 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
} }
} }
} else { } else {
PADDLE_THROW( PADDLE_THROW(platform::errors::InvalidArgument(
"% should be LoDTensor or SelectedRows, but the received type is %s", "% should be LoDTensor or SelectedRows, but the received type is %s",
ctx.InputNames("Ids")[0], framework::ToTypeName(ids_var->Type())); ctx.InputNames("Ids")[0], framework::ToTypeName(ids_var->Type())));
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册