From 837dd47a968d41629979b8c59b5e00896484dec2 Mon Sep 17 00:00:00 2001 From: ShenLiang <2282912238@qq.com> Date: Mon, 18 May 2020 20:12:14 +0800 Subject: [PATCH] Add lod in gather/scatter (#24613) * add lod msg in gather and scatter_op, test=develop --- paddle/fluid/operators/gather_op.cc | 1 + paddle/fluid/operators/scatter_op.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 9b07628373..e3cabb9861 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -43,6 +43,7 @@ class GatherOp : public framework::OperatorWithKernel { framework::DDim output_dims(ctx->GetInputDim("X")); output_dims[0] = batch_size; ctx->SetOutputDim("Out", output_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } protected: diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index 8781c5039d..78442cdf64 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -51,6 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "Updates and Ids should have same batch-size.")); ctx->SetOutputDim("Out", ref_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } protected: -- GitLab