From 8fc462099e4aba705bc432bc45c2d5adc8ceda61 Mon Sep 17 00:00:00 2001 From: Kai Song <50285351+USTCKAY@users.noreply.github.com> Date: Tue, 7 Mar 2023 15:00:38 +0800 Subject: [PATCH] [CustomDevice]fix ProcessGroupCustom::AllGather (#51157) --- .../distributed/collective/process_group_custom.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index ab77d245f96..070b1f21709 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -243,11 +243,13 @@ std::shared_ptr ProcessGroupCustom::AllGather( numel > 0 ? paddle::distributed::GetPartialTensor(in_tensor, offset, numel) : in_tensor; - phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor, - in_tensor_maybe_partial, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_); + phi::distributed::CommStaticCheck::GatherLikeShape( + *out_tensor, + in_tensor_maybe_partial, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::CUSTOM); std::vector in_wrapper{in_tensor_maybe_partial}; std::vector out_wrapper{*out_tensor}; @@ -276,7 +278,7 @@ std::shared_ptr ProcessGroupCustom::AllGather( int64_t offset, int64_t numel, bool sync_op) { - return AllGather(out_tensor, in_tensor, offset, numel, sync_op); + return AllGather(out_tensor, in_tensor, offset, numel, sync_op, false); } // TODO(sunyilun): methods below will be removed later -- GitLab