未验证 提交 ee11f006 编写于 作者: Q Qiao Longfei 提交者: GitHub

add shareLod (#5259)

* add shareLod

* fix sequence_conv grad infershape
上级 360cb183
...@@ -52,6 +52,22 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -52,6 +52,22 @@ class CompileTimeInferShapeContext : public InferShapeContext {
const std::vector<std::string> &Outputs( const std::vector<std::string> &Outputs(
const std::string &name) const override; const std::string &name) const override;
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != VarDesc::LOD_TENSOR) {
VLOG(3) << "input " << in << "is not LodTensor";
return;
}
PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j,
out);
in_var->SetLoDLevel(out_var->GetLodLevel());
}
private: private:
DDim GetDim(const std::string &name) const override; DDim GetDim(const std::string &name) const override;
......
...@@ -351,6 +351,20 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -351,6 +351,20 @@ class RuntimeInferShapeContext : public InferShapeContext {
return op_.Outputs(name); return op_.Outputs(name);
} }
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
}
private: private:
DDim GetDim(const std::string& name) const override { DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name); Variable* var = scope_.FindVar(name);
......
...@@ -28,9 +28,6 @@ void InferShapeContext::SetOutputsDim( ...@@ -28,9 +28,6 @@ void InferShapeContext::SetOutputsDim(
SetDims(names, dims); SetDims(names, dims);
} }
void InferShapeContext::ShareLoD(const std::string &in, const std::string &out,
size_t i, size_t j) const {}
std::vector<framework::DDim> InferShapeContext::GetDims( std::vector<framework::DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const { const std::vector<std::string> &names) const {
std::vector<framework::DDim> ret; std::vector<framework::DDim> ret;
......
...@@ -43,9 +43,8 @@ class InferShapeContext { ...@@ -43,9 +43,8 @@ class InferShapeContext {
virtual const std::vector<std::string> &Outputs( virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0; const std::string &name) const = 0;
// TODO(qiao) implement this function virtual void ShareLoD(const std::string &in, const std::string &out,
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, size_t i = 0, size_t j = 0) const = 0;
size_t j = 0) const;
protected: protected:
virtual framework::DDim GetDim(const std::string &name) const = 0; virtual framework::DDim GetDim(const std::string &name) const = 0;
......
...@@ -89,7 +89,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel { ...@@ -89,7 +89,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
} }
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD(framework::GradVarName("X"), "X"); ctx->ShareLoD("X", framework::GradVarName("X"));
} }
if (ctx->HasOutput(framework::GradVarName("Filter"))) { if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), ctx->SetOutputDim(framework::GradVarName("Filter"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册