diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index ab77d245f966aedce68c46ff96c419333e6fa83a..070b1f217094bc401ffcedd443cca6fa6f402c2e 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -243,11 +243,13 @@ std::shared_ptr ProcessGroupCustom::AllGather( numel > 0 ? paddle::distributed::GetPartialTensor(in_tensor, offset, numel) : in_tensor; - phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor, - in_tensor_maybe_partial, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_); + phi::distributed::CommStaticCheck::GatherLikeShape( + *out_tensor, + in_tensor_maybe_partial, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::CUSTOM); std::vector in_wrapper{in_tensor_maybe_partial}; std::vector out_wrapper{*out_tensor}; @@ -276,7 +278,7 @@ std::shared_ptr ProcessGroupCustom::AllGather( int64_t offset, int64_t numel, bool sync_op) { - return AllGather(out_tensor, in_tensor, offset, numel, sync_op); + return AllGather(out_tensor, in_tensor, offset, numel, sync_op, false); } // TODO(sunyilun): methods below will be removed later