diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 86f6147cf7ac1e82ac2904bbcdcf9697422560ce..17f942571d0141537e992be9ab73847d2a794698 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -54,6 +54,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
                 size_t j = 0) const override {
     PADDLE_ENFORCE_LT(i, Inputs(in).size());
     PADDLE_ENFORCE_LT(j, Outputs(out).size());
+    PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", in, i);
+    PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", out, j);
     auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
     auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
     if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
@@ -63,6 +67,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
     PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
                       "The %d-th output of Output(%s) must be LoDTensor.", j,
                       out);
+
     out_var->SetLoDLevel(in_var->GetLoDLevel());
   }
 
diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc
index ddff2c7c261746ac9986e79cff3da7e0a9654adc..89eb00ff65598eff5f4ba541df107e8da04e1a89 100644
--- a/paddle/fluid/framework/shape_inference.cc
+++ b/paddle/fluid/framework/shape_inference.cc
@@ -46,6 +46,16 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
   return this->GetRepeatedDims(arg_names[0]);
 }
 
+void InferShapeContext::ShareLoDs(const std::string &in,
+                                  const std::string &out) const {
+  PADDLE_ENFORCE_EQ(Inputs(in).size(), Outputs(out).size(),
+                    "The number of arguments in %s and %s is not equal.", in,
+                    out);
+  for (size_t i = 0; i < Inputs(in).size(); ++i) {
+    ShareLoD(in, out, i, i);
+  }
+}
+
 DDim InferShapeContext::GetInputsElementDim(const std::string &name,
                                             int idx) const {
   const std::vector<std::string> &names = Inputs(name);
diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h
index 5f497cafa0f75f7c23d550ef767d55274de7c900..fd220d961af85dd55fe2031409180823d8f178fc 100644
--- a/paddle/fluid/framework/shape_inference.h
+++ b/paddle/fluid/framework/shape_inference.h
@@ -56,6 +56,8 @@ class InferShapeContext {
   virtual const std::vector<std::string> &Outputs(
       const std::string &name) const = 0;
 
+  void ShareLoDs(const std::string &in, const std::string &out) const;
+
   virtual void ShareLoD(const std::string &in, const std::string &out,
                         size_t i = 0, size_t j = 0) const = 0;
 
diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc
index bc58612f9d3a2b433f362787135b6bb23b203f63..57817da71adfd80faad29a48b05ba2f326de6c07 100644
--- a/paddle/fluid/operators/concat_op.cc
+++ b/paddle/fluid/operators/concat_op.cc
@@ -94,8 +94,20 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
-    ctx->ShareLoD("X", framework::GradVarName("X"));
+    auto in_x = "X";
+    auto out_x_g_n = framework::GradVarName(in_x);
+    ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
+    auto &in_names = ctx->Inputs(in_x);
+    auto &out_names = ctx->Outputs(out_x_g_n);
+    PADDLE_ENFORCE_EQ(
+        in_names.size(), out_names.size(),
+        "The number of arguments in %s[%d] and %s[%d] is not equal.", in_x,
+        in_names.size(), out_x_g_n, out_names.size());
+    for (size_t i = 0; i < in_names.size(); ++i) {
+      if (out_names[i] != framework::kEmptyVarName) {
+        ctx->ShareLoD(in_x, out_x_g_n, i, i);
+      }
+    }
   }
 };
 
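For context (not part of the diff above), here is a minimal sketch of how the new `InferShapeContext::ShareLoDs` helper could be called from an operator's `InferShape`. The operator name `MyOpGrad` is hypothetical; only `ShareLoDs`, `SetOutputsDim`, `GetInputsDim`, and `GradVarName` come from the framework code shown above.

```cpp
// Hypothetical gradient operator illustrating the new ShareLoDs() helper.
// This is a sketch, not code from this diff; ConcatOpGrad above keeps an
// explicit per-element ShareLoD() loop so it can skip @EMPTY@ output names.
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class MyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto out_name = framework::GradVarName("X");
    // The gradient outputs keep the shapes of the forward inputs.
    ctx->SetOutputsDim(out_name, ctx->GetInputsDim("X"));
    // ShareLoDs() enforces that Inputs("X") and Outputs(GradVarName("X"))
    // have the same number of arguments, then shares the LoD of the i-th
    // input with the i-th output.
    ctx->ShareLoDs("X", out_name);
  }
};

}  // namespace operators
}  // namespace paddle
```

Because the compile-time `ShareLoD` now rejects `@EMPTY@` arguments, an operator whose gradient outputs may be pruned (such as `ConcatOpGrad` in this diff) keeps the explicit loop with the `kEmptyVarName` guard instead of calling `ShareLoDs()` directly.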