diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 9b0762837303a26c50bca762a790720dd6e687ad..e3cabb986133b378b98e62de4c25dd1bcf8309e1 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -43,6 +43,7 @@ class GatherOp : public framework::OperatorWithKernel { framework::DDim output_dims(ctx->GetInputDim("X")); output_dims[0] = batch_size; ctx->SetOutputDim("Out", output_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } protected: diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index 8781c5039dddc1a6677219f8ce1acc54bbe082ef..78442cdf641af7f566d112773949fafeeac3e687 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -51,6 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "Updates and Ids should have same batch-size.")); ctx->SetOutputDim("Out", ref_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } protected: