diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index ea3d1c3ca574f79ecd73e19cd1391a7becee2d9c..9b21c1d0d12e7fdc49f451cd2ce6b629598ce55e 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -187,7 +187,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective( std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& outputs, Fn fn, - CommType op_type) { + CommType op_type, + bool sync_op, + bool use_calc_stream) { const auto places = GetPlaceList(inputs); const auto key = GetKeyFromPlaces(places); @@ -199,20 +201,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective( } auto& ccl_comms = places_to_customcomm_[key]; - SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + if (!use_calc_stream) { + SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + } auto task = CreateTask(places, rank_, op_type, inputs); task->SetOutputs(outputs); for (size_t i = 0; i < inputs.size(); ++i) { phi::DeviceGuard guard(places[i]); - const auto& ccl_stream = places_to_ctx_[key][i]->stream(); + const auto& ccl_stream = + use_calc_stream ? 
reinterpret_cast<phi::CustomContext*>( + phi::DeviceContextPool::Instance().Get(places[i])) + ->stream() + : places_to_ctx_[key][i]->stream(); phi::stream::Stream stream(places[i], ccl_stream); fn(inputs[i], outputs[i], ccl_comms[i]->GetCustomCCLComm(), stream); } - for (size_t i = 0; i < inputs.size(); ++i) { - phi::DeviceGuard guard(places[i]); - task->control_events_[i].Record(*places_to_ctx_[key][i]); + if (!use_calc_stream) { + for (size_t i = 0; i < inputs.size(); ++i) { + phi::DeviceGuard guard(places[i]); + task->control_events_[i].Record(*places_to_ctx_[key][i]); + } } return task; } @@ -280,7 +290,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather( comm, stream); }, - CommType::ALLGATHER); + CommType::ALLGATHER, + sync_op, + use_calc_stream); } std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather( @@ -322,7 +334,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather( comm, stream); }, - CommType::ALLGATHER); + CommType::ALLGATHER, + false, + false); } std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce( @@ -333,7 +347,36 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce( bool use_calc_stream) { std::vector<phi::DenseTensor> in_wrapper{in_tensor}; std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; - return AllReduce(in_wrapper, out_wrapper, opts); + PADDLE_ENFORCE_EQ( + CheckTensorsInCustomPlace(in_wrapper, device_type_), + true, + platform::errors::InvalidArgument( + "All inputs should be in CustomPlace(%s).", device_type_)); + PADDLE_ENFORCE_EQ( + CheckTensorsInCustomPlace(out_wrapper, device_type_), + true, + platform::errors::InvalidArgument( + "All outputs should be in CustomPlace(%s).", device_type_)); + return Collective( + in_wrapper, + out_wrapper, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + phi::ccl::CCLComm comm, + const phi::stream::Stream& stream) { + return phi::DeviceManager::CCLAllReduce( + device_type_, + input.data(), + output.data(), + input.numel(), + phi::ccl::ToCCLDataType(input.dtype()), + ToCustomCCLRedType(opts.reduce_op), + comm, + stream); + }, + CommType::ALLREDUCE, + sync_op, + 
use_calc_stream); } std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce( @@ -342,9 +385,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce( const AllreduceOptions& opts, bool sync_op // for compatibility, no use now ) { - std::vector<phi::DenseTensor> in_wrapper{in_tensor}; - std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; - return AllReduce(in_wrapper, out_wrapper, opts); + return AllReduce(out_tensor, in_tensor, opts, sync_op, false); } std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce( @@ -378,7 +419,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce( comm, stream); }, - CommType::ALLREDUCE); + CommType::ALLREDUCE, + false, + false); } std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast( @@ -389,7 +432,47 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast( bool use_calc_stream) { std::vector<phi::DenseTensor> in_wrapper{in_tensor}; std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; - return Broadcast(in_wrapper, out_wrapper, opts); + PADDLE_ENFORCE_EQ( + CheckTensorsInCustomPlace(in_wrapper, device_type_), + true, + platform::errors::InvalidArgument( + "All inputs should be in CustomPlace(%s).", device_type_)); + PADDLE_ENFORCE_EQ( + CheckTensorsInCustomPlace(out_wrapper, device_type_), + true, + platform::errors::InvalidArgument( + "All outputs should be in CustomPlace(%s).", device_type_)); + return Collective( + in_wrapper, + out_wrapper, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + phi::ccl::CCLComm comm, + const phi::stream::Stream& stream) { + int root = opts.source_rank * in_wrapper.size() + opts.source_root; + if (rank_ == root) { + return phi::DeviceManager::CCLBroadcast( + device_type_, + input.data(), + input.numel(), + phi::ccl::ToCCLDataType(input.dtype()), + root, + comm, + stream); + } else { + return phi::DeviceManager::CCLBroadcast( + device_type_, + output.data(), + output.numel(), + phi::ccl::ToCCLDataType(output.dtype()), + root, + comm, + stream); + } + }, + CommType::BROADCAST, + sync_op, + use_calc_stream); } std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast( @@ -397,9 +480,7 @@ std::shared_ptr<ProcessGroup::Task>
ProcessGroupCustom::Broadcast( const phi::DenseTensor& in_tensor, const BroadcastOptions& opts, bool sync_op) { - std::vector<phi::DenseTensor> in_wrapper{in_tensor}; - std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; - return Broadcast(in_wrapper, out_wrapper, opts); + return Broadcast(out_tensor, in_tensor, opts, sync_op, false); } std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier( @@ -489,7 +570,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast( stream); } }, - CommType::BROADCAST); + CommType::BROADCAST, + false, + false); } std::shared_ptr<ProcessGroup::Task> diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index ac072a0897c254aa332770043b141ca4eca70877..85fd1dcd0fcb8c79331e292bc427b4f4f7190779 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -176,7 +176,9 @@ class ProcessGroupCustom : public ProcessGroupWithStream { std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT Fn fn, - CommType op_type); + CommType op_type, + bool sync_op, + bool use_calc_stream); void CreateCustomManagerCache(const std::string& places_key, const std::vector<Place>& places); diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index babc4d86b363d04caf45b263e1b767cd803d6612..e4c927183f687e10b344f41d13bce52bf56b8a3d 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/custom_device_common_op_registry.h" #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/operators/collective/c_concat_op.h" +#include "paddle/fluid/operators/collective/c_identity_op.h" #include "paddle/fluid/operators/load_combine_op.h" #include "paddle/fluid/operators/run_program_op.h" #include "paddle/fluid/operators/save_combine_op.h" @@ -589,6 +590,21 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { paddle::platform::CustomDeviceContext, paddle::platform::float16>) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + c_identity, + device_type, + paddle::operators:: + CIdentityOpKernel, + paddle::operators:: + CIdentityOpKernel, + paddle::operators:: + CIdentityOpKernel, + paddle::operators:: + CIdentityOpKernel, + paddle::operators::CIdentityOpKernel< + paddle::platform::float16, + paddle::platform::CustomDeviceContext>) {} + #endif }