diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc
index 6678320f9ffa61e3e6c51fd806569c2571d63d69..5b8922505cc089d66f0b444fc65ccec8ed051876 100644
--- a/paddle/fluid/operators/split_op.cc
+++ b/paddle/fluid/operators/split_op.cc
@@ -26,6 +26,52 @@ class SplitOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(X) of SplitOp should not be null."));
+    PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
+                      platform::errors::InvalidArgument(
+                          "Outputs(Out) of SplitOp should not be empty."));
+    auto in_dims = ctx->GetInputDim("X");
+    auto outs_names = ctx->Outputs("Out");
+    size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
+    size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
+    std::vector<int> sections = static_cast<std::vector<int>>(
+        ctx->Attrs().Get<std::vector<int>>("sections"));
+    const size_t outs_number = outs_names.size();
+
+    if (sections.size() > 0) {
+      PADDLE_ENFORCE_EQ(
+          sections.size(), outs_number,
+          platform::errors::InvalidArgument("tensor split sections size "
+                                            "should be equal to output size."));
+    }
+
+    if (ctx->HasInput("AxisTensor")) {
+      auto out_dims = phi::make_ddim(std::vector<int>(in_dims.size(), -1));
+      std::vector<framework::DDim> outs_dims(outs_number, out_dims);
+      ctx->SetOutputsDim("Out", outs_dims);
+      for (size_t i = 0; i < outs_number; ++i) {
+        ctx->ShareLoD("X", "Out", 0, i);
+      }
+      return;
+    }
+
+    bool each_section_is_known =
+        (sections.size() > 0 && !ctx->HasInputs("SectionsTensorList"));
+
+    auto outs_dims = UpdateOutsDims(ctx->IsRuntime(), each_section_is_known,
+                                    in_dims, num, sections, axis, outs_number);
+    ctx->SetOutputsDim("Out", outs_dims);
+    if (axis != 0) {
+      // Only pass LoD when not spliting along the first dim.
+      for (size_t i = 0; i < outs_number; ++i) {
+        ctx->ShareLoD("X", "Out", 0, i);
+      }
+    }
+  }
+
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -125,10 +171,6 @@ Example:
 
 namespace ops = paddle::operators;
 
-DELCARE_INFER_SHAPE_FUNCTOR(split, SplitInferShapeFunctor,
-                            PT_INFER_META(phi::SplitInferMeta));
-
 REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker,
                   ops::SplitGradMaker<paddle::framework::OpDesc>,
-                  ops::SplitGradMaker<paddle::imperative::OpBase>,
-                  SplitInferShapeFunctor);
+                  ops::SplitGradMaker<paddle::imperative::OpBase>);
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index ed95c8ff67785708cfcfd48624d574c4e53392fa..ff58c53ad9b403daa5562aaae7056741990cd7f7 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -508,17 +508,6 @@ void SplitInferMeta(const MetaTensor& x,
                     const Scalar& axis,
                     std::vector<MetaTensor*> out,
                     MetaConfig config) {
-  if (!config.is_runtime) {
-    if (axis.FromTensor() || num_or_sections.FromTensor()) {
-      auto out_dims = phi::make_ddim(std::vector<int>(x.dims().size(), -1));
-      for (auto* item : out) {
-        item->set_dims(out_dims);
-        item->share_lod(x);
-      }
-      return;
-    }
-  }
-
   int axis_value = axis.to<int>();
   int rank = x.dims().size();
   PADDLE_ENFORCE_EQ(
@@ -533,34 +522,27 @@ void SplitInferMeta(const MetaTensor& x,
     axis_value = axis_value + rank;
   }
 
-  std::vector<phi::DDim> out_dims(out.size(), x.dims());
-
   auto input_axis_dim = x.dims().at(axis_value);
   auto num_or_sections_data = num_or_sections.GetData();
+  // step1: get formated sections
+  std::vector<int64_t> sections;
   // num_or_sections is a number
   if (num_or_sections_data.size() == 1) {
-    if (config.is_runtime || input_axis_dim > 0) {
-      int num = num_or_sections_data.at(0);
-      PADDLE_ENFORCE_EQ(
-          input_axis_dim % num,
-          0,
-          phi::errors::InvalidArgument(
-              "The input's size along the split dimension "
-              "must be evenly divisible by Attr(num_or_sections). "
-              "But received Attr(num_or_sections) "
-              "= %d, input(X)'s shape = [%s], Attr(dim) = %d.",
-              num,
-              x.dims(),
-              axis_value));
+    int num = num_or_sections_data.at(0);
 
-      size_t out_axis_dim = input_axis_dim / num;
-      for (auto& out_dim : out_dims) {
-        out_dim[axis_value] = out_axis_dim;
-      }
-    } else {
-      for (auto& out_dim : out_dims) {
-        out_dim[axis_value] = -1;
-      }
+    PADDLE_ENFORCE_EQ(input_axis_dim % num,
+                      0,
+                      phi::errors::InvalidArgument(
+                          "The input's size along the split dimension "
+                          "must be evenly divisible by Attr(num_or_sections). "
+                          "But received Attr(num_or_sections) "
+                          "= %d, input(X)'s shape = [%s], Attr(dim) = %d.",
+                          num,
+                          x.dims(),
+                          axis_value));
+
+    for (int i = 0; i < num; ++i) {
+      sections.push_back(input_axis_dim / num);
     }
   } else {
     // num_or_sections is a sections
@@ -568,9 +550,10 @@ void SplitInferMeta(const MetaTensor& x,
     int unknow_dim_idx = -1;
     int num_of_unknow = 0;
     int sum_of_section = 0;
-    std::vector<int64_t> sections = num_or_sections_data;
 
     for (size_t i = 0; i < num_or_sections_data.size(); ++i) {
+      sections.push_back(num_or_sections_data[i]);
+
       if (num_or_sections_data[i] == unknow_dim_val) {
         num_of_unknow++;
         unknow_dim_idx = i;
@@ -622,22 +605,31 @@ void SplitInferMeta(const MetaTensor& x,
                               x.dims(),
                               axis_value));
     }
-    for (size_t i = 0; i < out_dims.size(); ++i) {
+  }
+
+  // setp2: fill out dims
+  std::vector<phi::DDim> out_dims(sections.size(), x.dims());
+  if (config.is_runtime || input_axis_dim > 0) {
+    for (size_t i = 0; i < sections.size(); ++i) {
       out_dims[i][axis_value] = sections[i];
     }
+  } else {
+    for (size_t i = 0; i < sections.size(); ++i) {
+      out_dims[i][axis_value] = -1;
+    }
   }
 
-  for (size_t i = 0; i < out.size(); ++i) {
+  for (size_t i = 0; i < sections.size(); ++i) {
     if (axis_value != 0) {
       // Only pass LoD when not spliting along the first dim.
-      out.at(i)->set_dtype(x.dtype());
-      out.at(i)->set_dims(out_dims[i]);
-      out.at(i)->set_layout(x.layout());
+      out[i]->set_dtype(x.dtype());
+      out[i]->set_dims(out_dims[i]);
+      out[i]->set_layout(x.layout());
     } else {
-      out.at(i)->set_dtype(x.dtype());
-      out.at(i)->set_dims(out_dims[i]);
-      out.at(i)->set_layout(x.layout());
-      out.at(i)->share_lod(x);
+      out[i]->set_dtype(x.dtype());
+      out[i]->set_dims(out_dims[i]);
+      out[i]->set_layout(x.layout());
+      out[i]->share_lod(x);
     }
   }
 }
diff --git a/paddle/phi/kernels/cpu/split_kernel.cc b/paddle/phi/kernels/cpu/split_kernel.cc
index 4acf9b02028f994c38144d716fdd56c6bbb6afa2..324798effbe56b8b7bdf0c3d31b21cd079a8cf1c 100644
--- a/paddle/phi/kernels/cpu/split_kernel.cc
+++ b/paddle/phi/kernels/cpu/split_kernel.cc
@@ -28,6 +28,23 @@ void SplitKernel(const Context& dev_ctx,
                  const ScalarArray& num_or_sections,
                  const Scalar& axis_scalar,
                  std::vector<DenseTensor*> outs) {
+  // need to infershape output
+  if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
+    std::vector<MetaTensor> out_metas;
+    out_metas.reserve(outs.size());
+    std::vector<MetaTensor*> out_metas_ptr;
+    for (size_t i = 0; i < outs.size(); ++i) {
+      out_metas.push_back(outs[i]);
+      out_metas_ptr.push_back(&out_metas.back());
+    }
+
+    phi::SplitInferMeta(x, num_or_sections, axis_scalar, out_metas_ptr, true);
+
+    for (size_t i = 0; i < out_metas.size(); ++i) {
+      outs[i]->Resize(out_metas[i].dims());
+    }
+  }
+
   std::vector<const DenseTensor*> shape_refer;
   for (size_t j = 0; j < outs.size(); ++j) {
     dev_ctx.template Alloc<T>(outs[j]);
diff --git a/paddle/phi/kernels/gpu/split_kernel.cu b/paddle/phi/kernels/gpu/split_kernel.cu
index d2473d5b0b110a122247c32c779b7a700c3249b1..c28fc3794f092a4cee8d7fc351190c13291892b1 100644
--- a/paddle/phi/kernels/gpu/split_kernel.cu
+++ b/paddle/phi/kernels/gpu/split_kernel.cu
@@ -27,6 +27,23 @@ void SplitKernel(const Context& dev_ctx,
                  const ScalarArray& num_or_sections,
                  const Scalar& axis_scalar,
                  std::vector<DenseTensor*> outs) {
+  // need to infershape output
+  if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
+    std::vector<MetaTensor> out_metas;
+    out_metas.reserve(outs.size());
+    std::vector<MetaTensor*> out_metas_ptr;
+    for (size_t i = 0; i < outs.size(); ++i) {
+      out_metas.push_back(outs[i]);
+      out_metas_ptr.push_back(&out_metas.back());
+    }
+
+    phi::SplitInferMeta(x, num_or_sections, axis_scalar, out_metas_ptr, true);
+
+    for (size_t i = 0; i < out_metas.size(); ++i) {
+      outs[i]->Resize(out_metas[i].dims());
+    }
+  }
+
   std::vector<const DenseTensor*> shape_refer;
   for (size_t j = 0; j < outs.size(); ++j) {
     dev_ctx.template Alloc<T>(outs[j]);