diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h index 42d22bfe6d40f8db2d44009b2d06929e4fef364a..97d3a40874b31c7de80da7b5fefddd6542c96d3e 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h @@ -39,10 +39,13 @@ class Graph; namespace paddle { namespace platform { +#if defined(PADDLE_WITH_NCCL) class NCCLContextMap; class NCCLCommunicator; +#elif defined(PADDLE_WITH_XPU_BKCL) class BKCLContextMap; class BKCLCommunicator; +#endif } namespace framework { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 8f38a56e98f491601bc39fc885be3677582eadcf..947a3c9455f1c71f59b8f129ea800d44282cbe61 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -968,9 +968,6 @@ void ParallelExecutor::BCastParamsToDevices( continue; } auto &dims = main_tensor.dims(); - - VLOG(1) << "bcast var=" << var; - if (paddle::platform::is_gpu_place(main_tensor.place())) { #if defined(PADDLE_WITH_NCCL) std::vector buffers; @@ -1013,6 +1010,11 @@ void ParallelExecutor::BCastParamsToDevices( std::vector buffers; buffers.reserve(member_->places_.size()); size_t numel = main_tensor.numel(); + // TODO(liuyuhui): BKCL only supports parameters of float type, so + // parameters of other types need to be cast to float before + // broadcasting. + // Broadcast performs no type-dependent operation, so this does not affect + // correctness. 
BKCLDataType data_type = BKCL_FLOAT; // BKCLDataType data_type = platform::ToBKCLDataType(main_tensor.type()); for (size_t i = 0; i < member_->places_.size(); ++i) { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 8e5363fafa3761566204560ecc410301545b15a1..4e79e645aaae12c563f5ceb82fdd85ec6416aac5 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -123,7 +123,7 @@ class XPUDeviceContext : public DeviceContext { void Wait() const override; #ifdef PADDLE_WITH_XPU_BKCL - /*! \brief Return nccl context. */ + /*! \brief Return bkcl context. */ BKCLContext_t bkcl_context() const { return bkcl_context_; } /*! \brief Set bkcl context. */