未验证 提交 5d0a1fd0 编写于 作者: R RichardWooSJTU 提交者: GitHub

modify interface of allreduce (#56254)

上级 d3f98088
......@@ -265,11 +265,6 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*in);
out_tensor.push_back(*out);
distributed::AllreduceOptions opts;
switch (red_type) {
case kRedSum:
......@@ -293,7 +288,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
"Invalid reduce type: %d", red_type));
}
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
auto task = pg->AllReduce(out, *in, opts, false, true);
task->Wait();
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册