From cdf3a4c244fa2bf0a09a5f440e3d2d7453072cc7 Mon Sep 17 00:00:00 2001
From: chengduo
Date: Fri, 21 Sep 2018 18:22:03 +0800
Subject: [PATCH] Fix concat_op InferShape (#13513)

* add ShareLoDs

* refine

* add Is EmptyVarName

* refine Sharedlod
---
 paddle/fluid/framework/op_desc.cc         |  5 +++++
 paddle/fluid/framework/shape_inference.cc | 10 ++++++++++
 paddle/fluid/framework/shape_inference.h  |  2 ++
 paddle/fluid/operators/concat_op.cc       | 16 ++++++++++++++--
 4 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 86f6147cf7..17f942571d 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -54,6 +54,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
                 size_t j = 0) const override {
     PADDLE_ENFORCE_LT(i, Inputs(in).size());
     PADDLE_ENFORCE_LT(j, Outputs(out).size());
+    PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", in, i);
+    PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", out, j);
     auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
     auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
     if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
@@ -63,6 +67,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
     PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
                       "The %d-th output of Output(%s) must be LoDTensor.", j,
                       out);
+
     out_var->SetLoDLevel(in_var->GetLoDLevel());
   }
 
diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc
index ddff2c7c26..89eb00ff65 100644
--- a/paddle/fluid/framework/shape_inference.cc
+++ b/paddle/fluid/framework/shape_inference.cc
@@ -46,6 +46,16 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
   return this->GetRepeatedDims(arg_names[0]);
 }
 
+void InferShapeContext::ShareLoDs(const std::string &in,
+                                  const std::string &out) const {
+  PADDLE_ENFORCE_EQ(Inputs(in).size(), Outputs(out).size(),
+                    "The number of arguments in %s and %s is not equal.", in,
+                    out);
+  for (size_t i = 0; i < Inputs(in).size(); ++i) {
+    ShareLoD(in, out, i, i);
+  }
+}
+
 DDim InferShapeContext::GetInputsElementDim(const std::string &name,
                                             int idx) const {
   const std::vector<std::string> &names = Inputs(name);
diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h
index 5f497cafa0..fd220d961a 100644
--- a/paddle/fluid/framework/shape_inference.h
+++ b/paddle/fluid/framework/shape_inference.h
@@ -56,6 +56,8 @@ class InferShapeContext {
   virtual const std::vector<std::string> &Outputs(
       const std::string &name) const = 0;
 
+  void ShareLoDs(const std::string &in, const std::string &out) const;
+
   virtual void ShareLoD(const std::string &in, const std::string &out,
                         size_t i = 0, size_t j = 0) const = 0;
 
diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc
index bc58612f9d..57817da71a 100644
--- a/paddle/fluid/operators/concat_op.cc
+++ b/paddle/fluid/operators/concat_op.cc
@@ -94,8 +94,20 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
-    ctx->ShareLoD("X", framework::GradVarName("X"));
+    auto in_x = "X";
+    auto out_x_g_n = framework::GradVarName(in_x);
+    ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
+    auto &in_names = ctx->Inputs(in_x);
+    auto &out_names = ctx->Outputs(out_x_g_n);
+    PADDLE_ENFORCE_EQ(
+        in_names.size(), out_names.size(),
+        "The number of arguments in %s[%d] and %s[%d] is not equal.", in_x,
+        in_names.size(), out_x_g_n, out_names.size());
+    for (size_t i = 0; i < in_names.size(); ++i) {
+      if (out_names[i] != framework::kEmptyVarName) {
+        ctx->ShareLoD(in_x, out_x_g_n, i, i);
+      }
+    }
   }
 };
-- 
GitLab
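
For readers outside the Paddle codebase, here is a minimal standalone sketch of the behavior this patch introduces: gradient outputs pruned by the backward pass appear as kEmptyVarName ("@EMPTY@", confirmed by the error messages above), so LoD must be shared slot-by-slot and only for outputs that still exist. The map, variable names, and free function below are invented for illustration; this is not Paddle's actual API, just a model of the per-slot skip logic under those assumptions.

// Toy model (not Paddle code) of per-slot LoD sharing with pruned outputs.
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <vector>

const char kEmptyVarName[] = "@EMPTY@";  // same sentinel Paddle uses

// Hypothetical stand-in for the LoD metadata attached to each variable.
std::map<std::string, int> lod_level = {{"x0", 1}, {"x1", 2}};

// Mirrors the idea of InferShapeContext::ShareLoDs plus the concat_grad
// guard: share LoD element-wise, skipping slots whose output was pruned.
void ShareLoDs(const std::vector<std::string> &in_names,
               const std::vector<std::string> &out_names) {
  assert(in_names.size() == out_names.size());
  for (size_t i = 0; i < in_names.size(); ++i) {
    if (out_names[i] == kEmptyVarName) continue;  // the fix: skip pruned vars
    lod_level[out_names[i]] = lod_level[in_names[i]];
  }
}

int main() {
  // x1's gradient was pruned, so its slot holds the empty sentinel.
  std::vector<std::string> in = {"x0", "x1"};
  std::vector<std::string> out = {"x0@GRAD", kEmptyVarName};
  ShareLoDs(in, out);
  std::cout << "x0@GRAD lod level: " << lod_level["x0@GRAD"] << "\n";  // 1
  return 0;
}

Sharing per slot, rather than calling ShareLoD once on index 0, is what lets ConcatOpGrad tolerate pruned gradient outputs without tripping the new @EMPTY@ enforce added to CompileTimeInferShapeContext::ShareLoD.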