未验证 提交 3d1741b7 编写于 作者: L liuyuhui 提交者: GitHub

[Kunlun] bug fix of PR2: Support MultiDevicePass and BKCL in parallel executor (#29926)

上级 332da133
...@@ -39,10 +39,13 @@ class Graph; ...@@ -39,10 +39,13 @@ class Graph;
namespace paddle { namespace paddle {
namespace platform { namespace platform {
#if defined(PADDLE_WITH_NCCL)
class NCCLContextMap; class NCCLContextMap;
class NCCLCommunicator; class NCCLCommunicator;
#elif defined(PADDLE_WITH_XPU_BKCL)
class BKCLContextMap; class BKCLContextMap;
class BKCLCommunicator; class BKCLCommunicator;
#endif
} }
namespace framework { namespace framework {
......
...@@ -968,9 +968,6 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -968,9 +968,6 @@ void ParallelExecutor::BCastParamsToDevices(
continue; continue;
} }
auto &dims = main_tensor.dims(); auto &dims = main_tensor.dims();
VLOG(1) << "bcast var=" << var;
if (paddle::platform::is_gpu_place(main_tensor.place())) { if (paddle::platform::is_gpu_place(main_tensor.place())) {
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
std::vector<void *> buffers; std::vector<void *> buffers;
...@@ -1013,6 +1010,11 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -1013,6 +1010,11 @@ void ParallelExecutor::BCastParamsToDevices(
std::vector<void *> buffers; std::vector<void *> buffers;
buffers.reserve(member_->places_.size()); buffers.reserve(member_->places_.size());
size_t numel = main_tensor.numel(); size_t numel = main_tensor.numel();
// TODO(liuyuhui): BKCL only supports parameters of float type, so
// parameters of other types have to be force-cast to float before
// broadcasting.
// Broadcast does not perform any type-dependent operation on the data,
// so this does not affect correctness.
BKCLDataType data_type = BKCL_FLOAT; BKCLDataType data_type = BKCL_FLOAT;
// BKCLDataType data_type = platform::ToBKCLDataType(main_tensor.type()); // BKCLDataType data_type = platform::ToBKCLDataType(main_tensor.type());
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
......
...@@ -123,7 +123,7 @@ class XPUDeviceContext : public DeviceContext { ...@@ -123,7 +123,7 @@ class XPUDeviceContext : public DeviceContext {
void Wait() const override; void Wait() const override;
#ifdef PADDLE_WITH_XPU_BKCL #ifdef PADDLE_WITH_XPU_BKCL
/*! \brief Return nccl context. */ /*! \brief Return bkcl context. */
BKCLContext_t bkcl_context() const { return bkcl_context_; } BKCLContext_t bkcl_context() const { return bkcl_context_; }
/*! \brief Set bkcl context. */ /*! \brief Set bkcl context. */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册