From 3d1741b79403fe5424ebae6b5d55f33f8bae2362 Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Mon, 28 Dec 2020 13:20:02 +0800 Subject: [PATCH] [Kunlun] bug fix of PR2: Support MultiDevicePass and BKCL in parallel executor (#29926) --- .../multi_devices_graph_pass/multi_devices_graph_pass.h | 3 +++ paddle/fluid/framework/parallel_executor.cc | 8 +++++--- paddle/fluid/platform/device_context.h | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h index 42d22bfe6d4..97d3a40874b 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h @@ -39,10 +39,13 @@ class Graph; namespace paddle { namespace platform { +#if defined(PADDLE_WITH_NCCL) class NCCLContextMap; class NCCLCommunicator; +#elif defined(PADDLE_WITH_XPU_BKCL) class BKCLContextMap; class BKCLCommunicator; +#endif } namespace framework { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 8f38a56e98f..947a3c9455f 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -968,9 +968,6 @@ void ParallelExecutor::BCastParamsToDevices( continue; } auto &dims = main_tensor.dims(); - - VLOG(1) << "bcast var=" << var; - if (paddle::platform::is_gpu_place(main_tensor.place())) { #if defined(PADDLE_WITH_NCCL) std::vector buffers; @@ -1013,6 +1010,11 @@ void ParallelExecutor::BCastParamsToDevices( std::vector buffers; buffers.reserve(member_->places_.size()); size_t numel = main_tensor.numel(); + // TODO(liuyuhui): BKCL only support parameters using float type, + // other parameters need to be strongly converted to float before + // broadcasting, + // but broadcast is equivalent to no type of operation, does not affect + // correctness. BKCLDataType data_type = BKCL_FLOAT; // BKCLDataType data_type = platform::ToBKCLDataType(main_tensor.type()); for (size_t i = 0; i < member_->places_.size(); ++i) { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 8e5363fafa3..4e79e645aaa 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -123,7 +123,7 @@ class XPUDeviceContext : public DeviceContext { void Wait() const override; #ifdef PADDLE_WITH_XPU_BKCL - /*! \brief Return nccl context. */ + /*! \brief Return bkcl context. */ BKCLContext_t bkcl_context() const { return bkcl_context_; } /*! \brief Set bkcl context. */ -- GitLab