未验证 提交 ae8b5f11 编写于 作者: G guofei 提交者: GitHub

Change ShareDataWith() to TensorCopy() in ref_by_trainer_id (#22717)

As the title
上级 71ab0458
...@@ -40,7 +40,8 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> { ...@@ -40,7 +40,8 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
} }
PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size()); PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size());
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
out->ShareDataWith(*(in_list[trainer_id])); framework::TensorCopy(*(in_list[trainer_id]), in_list[trainer_id]->place(),
out);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册