未验证 提交 ee11f006 编写于 作者: Q Qiao Longfei 提交者: GitHub

add shareLod (#5259)

* add shareLod

* fix sequence_conv grad infershape
上级 360cb183
......@@ -52,6 +52,22 @@ class CompileTimeInferShapeContext : public InferShapeContext {
const std::vector<std::string> &Outputs(
const std::string &name) const override;
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != VarDesc::LOD_TENSOR) {
VLOG(3) << "input " << in << "is not LodTensor";
return;
}
PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j,
out);
in_var->SetLoDLevel(out_var->GetLodLevel());
}
private:
DDim GetDim(const std::string &name) const override;
......
......@@ -351,6 +351,20 @@ class RuntimeInferShapeContext : public InferShapeContext {
return op_.Outputs(name);
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
}
private:
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
......
......@@ -28,9 +28,6 @@ void InferShapeContext::SetOutputsDim(
SetDims(names, dims);
}
void InferShapeContext::ShareLoD(const std::string &in, const std::string &out,
size_t i, size_t j) const {}
std::vector<framework::DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const {
std::vector<framework::DDim> ret;
......
......@@ -43,9 +43,8 @@ class InferShapeContext {
virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0;
// TODO(qiao) implement this function
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const;
virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
protected:
virtual framework::DDim GetDim(const std::string &name) const = 0;
......
......@@ -89,7 +89,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD(framework::GradVarName("X"), "X");
ctx->ShareLoD("X", framework::GradVarName("X"));
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册