From ad76d37e8ace4e0a0d74a719c76945ffe7f9edb5 Mon Sep 17 00:00:00 2001 From: QingshuChen Date: Wed, 29 Mar 2023 13:01:43 +0800 Subject: [PATCH] fix bkcl_all_gather and c_embedding_grad bug for xpu (#51785) --- .../collective/process_group_bkcl.cc | 243 ++++++++++-------- .../collective/c_embedding_op_xpu.cc | 13 +- 2 files changed, 148 insertions(+), 108 deletions(-) diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index 1b6f512ec8a..47dd2241c2c 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/platform/device/xpu/xpu_info.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/check/static_check.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" @@ -82,10 +83,12 @@ ProcessGroupBKCL::ProcessGroupBKCL( : ProcessGroupWithStream(rank, size, gid), store_(store) {} void ProcessGroupBKCL::GroupStart() { + VLOG(3) << "bkcl_group_start"; PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start()); } void ProcessGroupBKCL::GroupEnd() { + VLOG(3) << "bkcl_group_end"; PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end()); } @@ -112,13 +115,15 @@ std::shared_ptr ProcessGroupBKCL::Recv( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_recv(comm, - output->data(), - output->numel(), - src_rank, - platform::ToBKCLDataType( - framework::TransToProtoVarType(output->type())), - stream); + VLOG(3) << "bkcl_recv"; + int r = bkcl_recv(comm, + output->data(), + output->numel(), + src_rank, + platform::ToBKCLDataType( + framework::TransToProtoVarType(output->type())), + stream); + return r; }, CommType::RECV, sync_op, @@ -143,13 +148,15 @@ std::shared_ptr ProcessGroupBKCL::Send( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_send(comm, - input.data(), - input.numel(), - dst_rank, - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - stream); + VLOG(3) << "bkcl_send"; + int r = bkcl_send(comm, + input.data(), + input.numel(), + dst_rank, + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + stream); + return r; }, CommType::SEND, sync_op, @@ -269,15 +276,17 @@ std::shared_ptr ProcessGroupBKCL::AllReduce( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_all_reduce( - comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - ToBKCLRedType(opts.reduce_op), - stream); + VLOG(3) << "bkcl_all_reduce"; + int r = + bkcl_all_reduce(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + ToBKCLRedType(opts.reduce_op), + stream); + return r; }, CommType::ALLREDUCE, sync_op, @@ -298,14 +307,17 @@ std::shared_ptr ProcessGroupBKCL::Broadcast( BKCLContext_t comm, const XPUStream& stream) { int root = opts.source_rank + opts.source_root; - return bkcl_broadcast(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - root, - stream); + VLOG(3) << "bkcl_broadcast"; + int r = + bkcl_broadcast(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + root, + stream); + return r; }, CommType::BROADCAST, sync_op, @@ -315,10 +327,18 @@ std::shared_ptr ProcessGroupBKCL::Broadcast( std::shared_ptr ProcessGroupBKCL::AllGather( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, - int64_t offset, // for compatibility, no use now - int64_t numel, // for compatibility, no use now + int64_t offset, + int64_t numel, bool sync_op, bool use_calc_stream) { + const phi::DenseTensor& in_tensor_maybe_partial = + numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor; + phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor, + in_tensor_maybe_partial, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::XPU); return Collective( out_tensor, in_tensor, @@ -326,14 +346,16 @@ std::shared_ptr ProcessGroupBKCL::AllGather( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_all_gather( - comm, - input.data(), - input.numel(), - output->data(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - stream); + VLOG(3) << "bkcl_all_gather"; + int r = + bkcl_all_gather(comm, + in_tensor_maybe_partial.data(), + in_tensor_maybe_partial.numel(), + output->data(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + stream); + return r; }, CommType::ALLGATHER, sync_op, @@ -353,15 +375,17 @@ std::shared_ptr ProcessGroupBKCL::Reduce( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_reduce(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - ToBKCLRedType(opts.reduce_op), - opts.root_rank, - stream); + VLOG(3) << "bkcl_reduce"; + int r = bkcl_reduce(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + ToBKCLRedType(opts.reduce_op), + opts.root_rank, + stream); + return r; }, CommType::REDUCE, sync_op, @@ -381,7 +405,8 @@ std::shared_ptr ProcessGroupBKCL::ReduceScatter( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_reduce_scatter( + VLOG(3) << "bkcl_reduce_scatter"; + int r = bkcl_reduce_scatter( comm, input.data(), output->data(), @@ -390,6 +415,7 @@ std::shared_ptr ProcessGroupBKCL::ReduceScatter( framework::TransToProtoVarType(input.type())), ToBKCLRedType(opts.reduce_op), stream); + return r; }, CommType::REDUCE_SCATTER, sync_op, @@ -465,15 +491,18 @@ std::shared_ptr ProcessGroupBKCL::AllReduce( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_all_reduce( - comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - ToBKCLRedType(opts.reduce_op), - stream); + VLOG(3) << "bkcl_all_reduce"; + + int r = + bkcl_all_reduce(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + ToBKCLRedType(opts.reduce_op), + stream); + return r; }, CommType::ALLREDUCE, /*sync_op*/ true, @@ -506,15 +535,17 @@ std::shared_ptr ProcessGroupBKCL::AllReduce( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_all_reduce( - comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - ToBKCLRedType(opts.reduce_op), - stream); + VLOG(3) << "bkcl_all_reduce"; + int r = + bkcl_all_reduce(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + ToBKCLRedType(opts.reduce_op), + stream); + return r; }, CommType::ALLREDUCE, sync_op, @@ -549,14 +580,17 @@ std::shared_ptr ProcessGroupBKCL::Broadcast( const XPUStream& stream) { const auto root = opts.source_rank * in_tensors.size() + opts.source_root; - return bkcl_broadcast(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - root, - stream); + VLOG(3) << "bkcl_broadcast"; + int r = + bkcl_broadcast(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + root, + stream); + return r; }, CommType::BROADCAST, /*sync_op*/ true, @@ -592,14 +626,17 @@ std::shared_ptr ProcessGroupBKCL::Broadcast( const XPUStream& stream) { const auto root = opts.source_rank * in_tensors.size() + opts.source_root; - return bkcl_broadcast(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - root, - stream); + VLOG(3) << "bkcl_broadcast"; + int r = + bkcl_broadcast(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + root, + stream); + return r; }, CommType::BROADCAST, sync_op, @@ -634,14 +671,16 @@ std::shared_ptr ProcessGroupBKCL::AllGather( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_all_gather( - comm, - input.data(), - input.numel(), - output->data(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - stream); + VLOG(3) << "bkcl_all_gather"; + int r = + bkcl_all_gather(comm, + input.data(), + input.numel(), + output->data(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + stream); + return r; }, CommType::ALLGATHER, /*sync_op*/ true, @@ -673,14 +712,16 @@ std::shared_ptr ProcessGroupBKCL::AllGather( const phi::DenseTensor& input, BKCLContext_t comm, const XPUStream& stream) { - return bkcl_all_gather( - comm, - input.data(), - input.numel(), - output->data(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - stream); + VLOG(3) << "bkcl_all_gather"; + int r = + bkcl_all_gather(comm, + input.data(), + input.numel(), + output->data(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + stream); + return r; }, CommType::ALLGATHER, sync_op, diff --git a/paddle/fluid/operators/collective/c_embedding_op_xpu.cc b/paddle/fluid/operators/collective/c_embedding_op_xpu.cc index e3f54ebfbeb..8590ff25730 100644 --- a/paddle/fluid/operators/collective/c_embedding_op_xpu.cc +++ b/paddle/fluid/operators/collective/c_embedding_op_xpu.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -83,8 +80,10 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel { auto table_grad_t = context.Output(framework::GradVarName("W")); - T* table_grad_data = - table_grad_t->mutable_data(table_t->dims(), context.GetPlace()); + auto& dev_ctx = context.template device_context(); + table_grad_t->Resize(table_t->dims()); + dev_ctx.template Alloc(table_grad_t, table_t->dtype()); + T* table_grad_data = static_cast(table_grad_t->data()); size_t table_t_mem_size = table_t->numel() * phi::SizeOf(table_grad_t->dtype()); @@ -98,9 +97,8 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel { << ", table_grad_t memory_size:" << table_grad_t_mem_size << ", start_index:" << start_idx; - auto& dev_ctx = context.template device_context(); int r = xpu::constant( - dev_ctx.x_context(), table_grad_data, table_grad_t_mem_size, (T)0); + dev_ctx.x_context(), table_grad_data, table_grad_t->numel(), (T)0); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); const T* d_output_data = d_output_t->data(); @@ -132,6 +130,7 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel { PADDLE_THROW(platform::errors::Unavailable( "XPU c_embedding ids only support int32 or int64.")); } + PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad"); } }; -- GitLab