未验证 提交 8fc46209 编写于 作者: K Kai Song 提交者: GitHub

[CustomDevice]fix ProcessGroupCustom::AllGather (#51157)

上级 017452e9
......@@ -243,11 +243,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
numel > 0
? paddle::distributed::GetPartialTensor(in_tensor, offset, numel)
: in_tensor;
phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor,
in_tensor_maybe_partial,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
phi::distributed::CommStaticCheck::GatherLikeShape(
*out_tensor,
in_tensor_maybe_partial,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor_maybe_partial};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
......@@ -276,7 +278,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
int64_t offset,
int64_t numel,
bool sync_op) {
return AllGather(out_tensor, in_tensor, offset, numel, sync_op);
return AllGather(out_tensor, in_tensor, offset, numel, sync_op, false);
}
// TODO(sunyilun): methods below will be removed later
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册