Unverified commit 45385371, authored by zyfncg, committed by GitHub

Fix bug caused by split infershape (#40116)

* fix bug caused by split infershape

* revert infer_shape of split

* revert split
Parent 8dbfc2ae
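
In short, per the hunks below: this commit backs split off the shared phi::SplitInferMeta infershape path. The operator's own InferShape is restored, the SplitInferShapeFunctor registration is removed, and both SplitKernel variants now re-run SplitInferMeta at execution time when the axis or the section sizes come from tensors (AxisTensor / SectionsTensorList). A minimal sketch (not part of the patch; hypothetical helper) of the compile-time fallback this restores: when those values are unknown before execution, every output extent is reported as -1.

```cpp
#include <vector>

// Produce the "all unknown" output shapes that compile-time shape inference
// reports when the split axis or section sizes are only known at run time;
// -1 is the conventional unknown-extent marker used throughout the diff.
std::vector<std::vector<int>> CompileTimeSplitDims(int input_rank,
                                                   size_t num_outputs) {
  std::vector<int> unknown(static_cast<size_t>(input_rank), -1);
  return std::vector<std::vector<int>>(num_outputs, unknown);
}
```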
@@ -26,6 +26,52 @@ class SplitOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(X) of SplitOp should not be null."));
+    PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
+                      platform::errors::InvalidArgument(
+                          "Outputs(Out) of SplitOp should not be empty."));
+    auto in_dims = ctx->GetInputDim("X");
+    auto outs_names = ctx->Outputs("Out");
+    size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
+    size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
+    std::vector<int> sections = static_cast<std::vector<int>>(
+        ctx->Attrs().Get<std::vector<int>>("sections"));
+    const size_t outs_number = outs_names.size();
+
+    if (sections.size() > 0) {
+      PADDLE_ENFORCE_EQ(
+          sections.size(), outs_number,
+          platform::errors::InvalidArgument("tensor split sections size "
+                                            "should be equal to output size."));
+    }
+
+    if (ctx->HasInput("AxisTensor")) {
+      auto out_dims = phi::make_ddim(std::vector<int>(in_dims.size(), -1));
+      std::vector<framework::DDim> outs_dims(outs_number, out_dims);
+      ctx->SetOutputsDim("Out", outs_dims);
+      for (size_t i = 0; i < outs_number; ++i) {
+        ctx->ShareLoD("X", "Out", 0, i);
+      }
+      return;
+    }
+
+    bool each_section_is_known =
+        (sections.size() > 0 && !ctx->HasInputs("SectionsTensorList"));
+
+    auto outs_dims = UpdateOutsDims(ctx->IsRuntime(), each_section_is_known,
+                                    in_dims, num, sections, axis, outs_number);
+    ctx->SetOutputsDim("Out", outs_dims);
+    if (axis != 0) {
+      // Only pass LoD when not spliting along the first dim.
+      for (size_t i = 0; i < outs_number; ++i) {
+        ctx->ShareLoD("X", "Out", 0, i);
+      }
+    }
+  }
+
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
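
The restored InferShape distinguishes three situations: with an AxisTensor input the split axis is unknown before execution, so every output extent becomes -1 and the method returns early; with a non-empty sections attribute, its length must match the number of outputs; otherwise UpdateOutsDims computes the concrete dims. One piece of arithmetic the diff relies on but never spells out is how a single -1 entry in sections is resolved against the input extent. A standalone sketch of that rule, assuming the at-most-one -1 convention used throughout this patch (hypothetical helper, not Paddle API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Resolve a sections vector in which at most one entry is -1: that entry
// absorbs whatever remains of the input extent along the split axis.
std::vector<int64_t> ResolveSections(int64_t input_axis_dim,
                                     std::vector<int64_t> sections) {
  int64_t sum = 0;
  int unknown_idx = -1;
  for (size_t i = 0; i < sections.size(); ++i) {
    if (sections[i] == -1) {
      assert(unknown_idx == -1 && "at most one section may be -1");
      unknown_idx = static_cast<int>(i);
    } else {
      sum += sections[i];
    }
  }
  if (unknown_idx >= 0) {
    sections[unknown_idx] = input_axis_dim - sum;  // infer the missing piece
  } else {
    assert(sum == input_axis_dim && "sections must cover the input extent");
  }
  return sections;
}
```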
@@ -125,10 +171,6 @@ Example:
 
 namespace ops = paddle::operators;
 
-DELCARE_INFER_SHAPE_FUNCTOR(split, SplitInferShapeFunctor,
-                            PT_INFER_META(phi::SplitInferMeta));
-
 REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker,
                   ops::SplitGradMaker<paddle::framework::OpDesc>,
-                  ops::SplitGradMaker<paddle::imperative::OpBase>,
-                  SplitInferShapeFunctor);
+                  ops::SplitGradMaker<paddle::imperative::OpBase>);
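
With SplitInferShapeFunctor gone from REGISTER_OPERATOR, shape inference for split falls back to the virtual SplitOp::InferShape restored above rather than the functor wrapping phi::SplitInferMeta. A hypothetical sketch of that dispatch choice, with placeholder types rather than the real framework classes:

```cpp
// Simplified illustration only: real Paddle operators and contexts are far
// richer than these stand-ins.
struct InferShapeContext;  // opaque placeholder

struct OperatorBase {
  virtual ~OperatorBase() = default;
  virtual void InferShape(InferShapeContext *ctx) const = 0;  // per-op override
};

using InferShapeFn = void (*)(InferShapeContext *);

// If an infer-shape functor were registered, the framework would call it;
// with none registered (as after this commit), the op's own override runs.
void RunInferShape(const OperatorBase *op, InferShapeFn registered_fn,
                   InferShapeContext *ctx) {
  if (registered_fn != nullptr) {
    registered_fn(ctx);
  } else {
    op->InferShape(ctx);
  }
}
```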
@@ -508,17 +508,6 @@ void SplitInferMeta(const MetaTensor& x,
                     const Scalar& axis,
                     std::vector<MetaTensor*> out,
                     MetaConfig config) {
-  if (!config.is_runtime) {
-    if (axis.FromTensor() || num_or_sections.FromTensor()) {
-      auto out_dims = phi::make_ddim(std::vector<int>(x.dims().size(), -1));
-      for (auto* item : out) {
-        item->set_dims(out_dims);
-        item->share_lod(x);
-      }
-      return;
-    }
-  }
-
   int axis_value = axis.to<int>();
   int rank = x.dims().size();
   PADDLE_ENFORCE_EQ(
@@ -533,34 +522,27 @@ void SplitInferMeta(const MetaTensor& x,
     axis_value = axis_value + rank;
   }
 
-  std::vector<phi::DDim> out_dims(out.size(), x.dims());
-
   auto input_axis_dim = x.dims().at(axis_value);
   auto num_or_sections_data = num_or_sections.GetData();
+  // step1: get formated sections
+  std::vector<int64_t> sections;
   // num_or_sections is a number
   if (num_or_sections_data.size() == 1) {
-    if (config.is_runtime || input_axis_dim > 0) {
-      int num = num_or_sections_data.at(0);
-      PADDLE_ENFORCE_EQ(
-          input_axis_dim % num,
-          0,
-          phi::errors::InvalidArgument(
-              "The input's size along the split dimension "
-              "must be evenly divisible by Attr(num_or_sections). "
-              "But received Attr(num_or_sections) "
-              "= %d, input(X)'s shape = [%s], Attr(dim) = %d.",
-              num,
-              x.dims(),
-              axis_value));
-
-      size_t out_axis_dim = input_axis_dim / num;
-      for (auto& out_dim : out_dims) {
-        out_dim[axis_value] = out_axis_dim;
-      }
-    } else {
-      for (auto& out_dim : out_dims) {
-        out_dim[axis_value] = -1;
-      }
-    }
+    int num = num_or_sections_data.at(0);
+
+    PADDLE_ENFORCE_EQ(input_axis_dim % num,
+                      0,
+                      phi::errors::InvalidArgument(
+                          "The input's size along the split dimension "
+                          "must be evenly divisible by Attr(num_or_sections). "
+                          "But received Attr(num_or_sections) "
+                          "= %d, input(X)'s shape = [%s], Attr(dim) = %d.",
+                          num,
+                          x.dims(),
+                          axis_value));
+
+    for (int i = 0; i < num; ++i) {
+      sections.push_back(input_axis_dim / num);
+    }
   } else {
     // num_or_sections is a sections
@@ -568,9 +550,10 @@ void SplitInferMeta(const MetaTensor& x,
     int unknow_dim_idx = -1;
     int num_of_unknow = 0;
     int sum_of_section = 0;
-    std::vector<int64_t> sections = num_or_sections_data;
 
     for (size_t i = 0; i < num_or_sections_data.size(); ++i) {
+      sections.push_back(num_or_sections_data[i]);
       if (num_or_sections_data[i] == unknow_dim_val) {
         num_of_unknow++;
         unknow_dim_idx = i;
@@ -622,22 +605,31 @@ void SplitInferMeta(const MetaTensor& x,
             x.dims(),
             axis_value));
     }
-    for (size_t i = 0; i < out_dims.size(); ++i) {
-      out_dims[i][axis_value] = sections[i];
-    }
   }
 
+  // setp2: fill out dims
+  std::vector<phi::DDim> out_dims(sections.size(), x.dims());
+  if (config.is_runtime || input_axis_dim > 0) {
+    for (size_t i = 0; i < sections.size(); ++i) {
+      out_dims[i][axis_value] = sections[i];
+    }
+  } else {
+    for (size_t i = 0; i < sections.size(); ++i) {
+      out_dims[i][axis_value] = -1;
+    }
+  }
+
-  for (size_t i = 0; i < out.size(); ++i) {
+  for (size_t i = 0; i < sections.size(); ++i) {
     if (axis_value != 0) {
       // Only pass LoD when not spliting along the first dim.
-      out.at(i)->set_dtype(x.dtype());
-      out.at(i)->set_dims(out_dims[i]);
-      out.at(i)->set_layout(x.layout());
+      out[i]->set_dtype(x.dtype());
+      out[i]->set_dims(out_dims[i]);
+      out[i]->set_layout(x.layout());
     } else {
-      out.at(i)->set_dtype(x.dtype());
-      out.at(i)->set_dims(out_dims[i]);
-      out.at(i)->set_layout(x.layout());
-      out.at(i)->share_lod(x);
+      out[i]->set_dtype(x.dtype());
+      out[i]->set_dims(out_dims[i]);
+      out[i]->set_layout(x.layout());
+      out[i]->share_lod(x);
     }
   }
 }
...
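
Seen end to end, the rewritten SplitInferMeta is a two-pass computation: step 1 normalizes num_or_sections into one concrete section per output, and step 2 stamps those sections into copies of the input dims, degrading the split axis to -1 when the input extent is still unknown at compile time. A simplified sketch of step 2 under those assumptions, using plain containers in place of phi::DDim:

```cpp
#include <cstdint>
#include <vector>

// Fill per-output dims from resolved sections. Mirrors the "setp2" block in
// the diff above, but with std::vector standing in for phi::DDim.
std::vector<std::vector<int64_t>> FillSplitOutDims(
    const std::vector<int64_t> &in_dims,   // input shape
    int axis,                              // already normalized to [0, rank)
    const std::vector<int64_t> &sections,  // one entry per output
    bool is_runtime) {
  std::vector<std::vector<int64_t>> out_dims(sections.size(), in_dims);
  for (size_t i = 0; i < sections.size(); ++i) {
    // At compile time an unknown input extent forces the output extent to
    // stay unknown as well.
    out_dims[i][axis] = (is_runtime || in_dims[axis] > 0) ? sections[i] : -1;
  }
  return out_dims;
}
```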
@@ -28,6 +28,23 @@ void SplitKernel(const Context& dev_ctx,
                  const ScalarArray& num_or_sections,
                  const Scalar& axis_scalar,
                  std::vector<DenseTensor*> outs) {
+  // need to infershape output
+  if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
+    std::vector<MetaTensor> out_metas;
+    out_metas.reserve(outs.size());
+    std::vector<MetaTensor*> out_metas_ptr;
+    for (size_t i = 0; i < outs.size(); ++i) {
+      out_metas.push_back(outs[i]);
+      out_metas_ptr.push_back(&out_metas.back());
+    }
+
+    phi::SplitInferMeta(x, num_or_sections, axis_scalar, out_metas_ptr, true);
+
+    for (size_t i = 0; i < out_metas.size(); ++i) {
+      outs[i]->Resize(out_metas[i].dims());
+    }
+  }
+
   std::vector<const DenseTensor*> shape_refer;
   for (size_t j = 0; j < outs.size(); ++j) {
     dev_ctx.template Alloc<T>(outs[j]);
...
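
One detail worth noting in this added block: out_metas_ptr stores pointers into out_metas, so the out_metas.reserve(outs.size()) call is what guarantees the subsequent push_back calls never reallocate and dangle those pointers. The same pattern in isolation, with nothing phi-specific assumed:

```cpp
#include <cassert>
#include <vector>

// Copy values into `storage` and collect stable pointers to the copies.
// `storage` must be empty on entry; reserve() ensures the push_backs below
// never reallocate, which is what keeps the collected pointers valid.
template <typename T>
std::vector<T *> StablePointers(std::vector<T> &storage,
                                const std::vector<T> &values) {
  assert(storage.empty());
  storage.reserve(values.size());
  std::vector<T *> ptrs;
  ptrs.reserve(values.size());
  for (const T &v : values) {
    storage.push_back(v);
    ptrs.push_back(&storage.back());
  }
  return ptrs;
}
```

The GPU kernel below receives the identical change.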
@@ -27,6 +27,23 @@ void SplitKernel(const Context& dev_ctx,
                  const ScalarArray& num_or_sections,
                  const Scalar& axis_scalar,
                  std::vector<DenseTensor*> outs) {
+  // need to infershape output
+  if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
+    std::vector<MetaTensor> out_metas;
+    out_metas.reserve(outs.size());
+    std::vector<MetaTensor*> out_metas_ptr;
+    for (size_t i = 0; i < outs.size(); ++i) {
+      out_metas.push_back(outs[i]);
+      out_metas_ptr.push_back(&out_metas.back());
+    }
+
+    phi::SplitInferMeta(x, num_or_sections, axis_scalar, out_metas_ptr, true);
+
+    for (size_t i = 0; i < out_metas.size(); ++i) {
+      outs[i]->Resize(out_metas[i].dims());
+    }
+  }
+
   std::vector<const DenseTensor*> shape_refer;
   for (size_t j = 0; j < outs.size(); ++j) {
     dev_ctx.template Alloc<T>(outs[j]);
...