未验证 提交 8fc46209 编写于 作者: K Kai Song 提交者: GitHub

[CustomDevice]fix ProcessGroupCustom::AllGather (#51157)

上级 017452e9
...@@ -243,11 +243,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather( ...@@ -243,11 +243,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
numel > 0 numel > 0
? paddle::distributed::GetPartialTensor(in_tensor, offset, numel) ? paddle::distributed::GetPartialTensor(in_tensor, offset, numel)
: in_tensor; : in_tensor;
phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor, phi::distributed::CommStaticCheck::GatherLikeShape(
in_tensor_maybe_partial, *out_tensor,
/*dst_rank*/ rank_, in_tensor_maybe_partial,
/*cur_rank*/ rank_, /*dst_rank*/ rank_,
size_); /*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor_maybe_partial}; std::vector<phi::DenseTensor> in_wrapper{in_tensor_maybe_partial};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
...@@ -276,7 +278,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather( ...@@ -276,7 +278,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
int64_t offset, int64_t offset,
int64_t numel, int64_t numel,
bool sync_op) { bool sync_op) {
return AllGather(out_tensor, in_tensor, offset, numel, sync_op); return AllGather(out_tensor, in_tensor, offset, numel, sync_op, false);
} }
// TODO(sunyilun): methods below will be removed later // TODO(sunyilun): methods below will be removed later
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册