未验证 提交 5d0a1fd0 编写于 作者: R RichardWooSJTU 提交者: GitHub

modify interface of allreduce (#56254)

上级 d3f98088
...@@ -265,11 +265,6 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> { ...@@ -265,11 +265,6 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
if (map->has(rid)) { if (map->has(rid)) {
// Use ProcessGroup // Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid); distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*in);
out_tensor.push_back(*out);
distributed::AllreduceOptions opts; distributed::AllreduceOptions opts;
switch (red_type) { switch (red_type) {
case kRedSum: case kRedSum:
...@@ -293,7 +288,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> { ...@@ -293,7 +288,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
"Invalid reduce type: %d", red_type)); "Invalid reduce type: %d", red_type));
} }
auto task = pg->AllReduce(in_tensor, out_tensor, opts); auto task = pg->AllReduce(out, *in, opts, false, true);
task->Wait(); task->Wait();
return; return;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册