diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index c2d6f124ad292bf46b4e7e9a1dcc2984aae7fcda..a4747e7c7c91b6469f1ade3fafd8eaa7002c3e40 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -52,6 +52,22 @@ class CompileTimeInferShapeContext : public InferShapeContext { const std::vector &Outputs( const std::string &name) const override; + void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, + size_t j = 0) const override { + PADDLE_ENFORCE_LT(i, Inputs(in).size()); + PADDLE_ENFORCE_LT(j, Outputs(out).size()); + auto *in_var = block_.FindVarRecursive(Inputs(in)[i]); + auto *out_var = block_.FindVarRecursive(Outputs(out)[j]); + if (in_var->GetType() != VarDesc::LOD_TENSOR) { + VLOG(3) << "input " << in << "is not LodTensor"; + return; + } + PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR, + "The %d-th output of Output(%s) must be LoDTensor.", j, + out); + in_var->SetLoDLevel(out_var->GetLodLevel()); + } + private: DDim GetDim(const std::string &name) const override; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 222a252dc409bf30d5d6abea95156b41cfcd221a..aa46829fdde82b58a649108bf708901299cd8153 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -351,6 +351,20 @@ class RuntimeInferShapeContext : public InferShapeContext { return op_.Outputs(name); } + void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) const override { + PADDLE_ENFORCE_LT(i, Inputs(in).size()); + PADDLE_ENFORCE_LT(j, Outputs(out).size()); + Variable* in_var = scope_.FindVar(Inputs(in)[i]); + Variable* out_var = scope_.FindVar(Outputs(out)[j]); + if (!in_var->IsType()) return; + PADDLE_ENFORCE(out_var->IsType(), + "The %d-th output of Output(%s) must be LoDTensor.", j, out); + auto in_tensor = in_var->Get(); + auto* out_tensor = out_var->GetMutable(); + out_tensor->set_lod(in_tensor.lod()); + } + private: DDim GetDim(const std::string& name) const override { Variable* var = scope_.FindVar(name); diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index 33a1d0b9b217c5d2a4b0fb63f427529e7988b24e..8169df8e4629e2d02d3dabcd6a8a102ad0077a81 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -28,9 +28,6 @@ void InferShapeContext::SetOutputsDim( SetDims(names, dims); } -void InferShapeContext::ShareLoD(const std::string &in, const std::string &out, - size_t i, size_t j) const {} - std::vector InferShapeContext::GetDims( const std::vector &names) const { std::vector ret; diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index f1f1e44bccd771be81cad7c28efe9b1b885eef6b..6f19900ef1a3e88fe78d457a03c344ea586ab551 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -43,9 +43,8 @@ class InferShapeContext { virtual const std::vector &Outputs( const std::string &name) const = 0; - // TODO(qiao) implement this function - void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, - size_t j = 0) const; + virtual void ShareLoD(const std::string &in, const std::string &out, + size_t i = 0, size_t j = 0) const = 0; protected: virtual framework::DDim GetDim(const std::string &name) const = 0; diff --git a/paddle/operators/sequence_conv_op.cc b/paddle/operators/sequence_conv_op.cc index bdb52265a529f560b4622ee037dcb3160ac90dec..a3f2ed14439572e9723c3057d212bb773b2a4e44 100644 --- a/paddle/operators/sequence_conv_op.cc +++ b/paddle/operators/sequence_conv_op.cc @@ -89,7 +89,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel { } if (ctx->HasOutput(framework::GradVarName("X"))) { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->ShareLoD(framework::GradVarName("X"), "X"); + ctx->ShareLoD("X", framework::GradVarName("X")); } if (ctx->HasOutput(framework::GradVarName("Filter"))) { ctx->SetOutputDim(framework::GradVarName("Filter"),