未验证 提交 837dd47a 编写于 作者: S ShenLiang 提交者: GitHub

Add lod in gather/scatter (#24613)

* add lod msg in gather and scatter_op, test=develop
上级 85ff7974
...@@ -43,6 +43,7 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -43,6 +43,7 @@ class GatherOp : public framework::OperatorWithKernel {
framework::DDim output_dims(ctx->GetInputDim("X")); framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size; output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims); ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} }
protected: protected:
......
...@@ -51,6 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel { ...@@ -51,6 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Updates and Ids should have same batch-size.")); "Updates and Ids should have same batch-size."));
ctx->SetOutputDim("Out", ref_dims); ctx->SetOutputDim("Out", ref_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} }
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册