Unverified · Commit 51cc73f0 · authored by xiongkun · committed by GitHub

Integrate GLOOParallelContext to support Multi-CPU Core for Dygraph DataParallel (#35154)

* can pass the fake test

* add files

* modify CMake to pass the Windows CI

* for CI pass

* WITH_GLOO=ON

* to pass the coverage test

* add cpuonly testcase

* add

* disable nccl when compiling with cuda

* change python version in cpuonly

* add backend argument

* add required gpu

* add required:gpu
Parent 692ac3e5
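In user code, the change surfaces as a new `backend` argument on `paddle.distributed.init_parallel_env`. A minimal sketch of CPU-only data-parallel training with the gloo backend (model, sizes, and launch flow are illustrative, not part of this commit):

```python
# Launch one process per CPU rank, e.g. with paddle.distributed.launch
# or the spawn-style harness added at the bottom of this commit.
import paddle
import paddle.distributed as dist

dist.init_parallel_env(backend='gloo')  # 'gloo' selects the CPU-only path

model = paddle.DataParallel(paddle.nn.Linear(8, 8))
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

x = paddle.randn([4, 8])
loss = model(x).mean()
loss.backward()  # gradients are all-reduced across ranks via gloo
opt.step()
opt.clear_grad()
```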
@@ -29,6 +29,12 @@ if(NOT WIN32)
endif()
cc_library(data_loader SRCS data_loader.cc DEPS enforce)
endif(NOT WIN32)
+if(WITH_GLOO)
+  cc_library(imperative_gloo_context SRCS gloo_context.cc DEPS collective_helper device_context tensor var_type_traits)
+  if ( WIN32 OR (NOT (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL) ))
+    cc_library(reducer SRCS reducer.cc DEPS layer)
+  endif()
+endif()
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/gloo_context.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace imperative {
void GLOOParallelContext::Init() {
  // PADDLE_THROW(platform::errors::OutOfRange(
  //     "Still not implement Init"));
  VLOG(4) << "Start GLOOParallelContext initialization";
  auto gloo_wrapper = framework::GlooWrapper::GetInstance();
  gloo_wrapper->SetSize(strategy_.nranks_);
  gloo_wrapper->SetRank(strategy_.local_rank_);
  gloo_wrapper->SetPrefix("");
  gloo_wrapper->SetIface("lo");
  auto addr = paddle::string::Split(strategy_.trainer_endpoints_[0], ':');
  VLOG(4) << "Server is " << strategy_.trainer_endpoints_[0];
  std::string host = addr[0];
  int port = std::stoi(addr[1]);
  gloo_wrapper->SetHttpStore(host, port, "worker");
  gloo_wrapper->Init();
  device_ = std::unique_ptr<platform::CPUDeviceContext>(
      new platform::CPUDeviceContext(platform::CPUPlace()));
}
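The rendezvous address comes from splitting the first trainer endpoint on `':'`. In Python terms the parsing amounts to the following (toy illustration; the endpoint value is an assumed example):

```python
endpoint = "127.0.0.1:6170"     # stands in for strategy_.trainer_endpoints_[0]
host, port = endpoint.split(":")
print(host, int(port))          # -> feeds SetHttpStore("127.0.0.1", 6170, "worker")
```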
void GLOOParallelContext::InitWithRingID(int ring_id) {
  PADDLE_THROW(
      platform::errors::OutOfRange("InitWithRingID is not implemented yet"));
}
#define GLOO_CASE(type, T, gw) \
case type: { \
VLOG(4) << "Use the gloo all reduce to sync. SRC:" << src_tensor; \
std::vector<T> send_vector##T; \
framework::TensorToVector<T>(src_tensor, &send_vector##T); \
auto recv_vector##T = gw->AllReduce<T>(send_vector##T); \
framework::TensorFromVector<T>(recv_vector##T, dst_tensor); \
VLOG(4) << "DST:" << *dst_tensor; \
break; \
}
void GLOOParallelContext::AllReduceByStream(const framework::Variable &src,
                                            framework::Variable *dst,
                                            int ring_id, bool use_calc_stream) {
  // AllReduce(src, dst, strategy_, ring_id, use_calc_stream);
  auto src_tensor = src.Get<framework::LoDTensor>();
  auto *dst_tensor = dst->GetMutable<framework::LoDTensor>();
  auto gloo_wrapper = framework::GlooWrapper::GetInstance();
  dst_tensor->Resize(src_tensor.dims());
  switch (src_tensor.type()) {
    GLOO_CASE(framework::proto::VarType::FP32, float, gloo_wrapper);
    GLOO_CASE(framework::proto::VarType::FP64, double, gloo_wrapper);
    GLOO_CASE(framework::proto::VarType::INT32, int, gloo_wrapper);
    GLOO_CASE(framework::proto::VarType::INT64, int64_t, gloo_wrapper);
    default: {
      PADDLE_THROW(
          platform::errors::InvalidArgument("Invalid datatype for allreduce"));
    }
  }
  gloo_wrapper->Barrier();
}
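For intuition: the CPU path copies the LoDTensor into a `std::vector<T>`, runs a gloo sum all-reduce on it, and copies the result back, so every rank ends up holding the elementwise sum of all ranks' tensors. A toy Python sketch of that semantics (not Paddle API, just the reduction):

```python
import numpy as np

def gloo_allreduce_sum(rank_tensors):
    # Stand-in for gloo_wrapper->AllReduce<T>(): every rank receives
    # the elementwise sum of the tensors contributed by all ranks.
    total = np.sum(np.stack(rank_tensors), axis=0)
    return [total.copy() for _ in rank_tensors]

grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]  # two "ranks"
print(gloo_allreduce_sum(grads))                      # both see [4., 6.]
```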
paddle::platform::DeviceContext *GLOOParallelContext::GetDeviceContext(
int ring_id) {
// return the CPUDeviceContext
return device_.get();
}
void GLOOParallelContext::WaitCompute(int ring_id) {
  // do nothing, since the CPU backend needs no stream synchronization
  return;
}

void GLOOParallelContext::WaitComm(int ring_id) {
  // do nothing, since the CPU backend needs no stream synchronization
  return;
}

void GLOOParallelContext::SynchronizeCompute() {
  // do nothing, since the CPU backend needs no stream synchronization
  return;
}
} // namespace imperative
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace imperative {
class GLOOParallelContext : public ParallelContext {
 public:
  explicit GLOOParallelContext(const ParallelStrategy& strategy,
                               const platform::Place& place)
      : ParallelContext(strategy, place) {}

  ~GLOOParallelContext() override = default;

  void Init() override;

  void InitWithRingID(int ring_id) override;

  void AllReduceByStream(const framework::Variable& src,
                         framework::Variable* dst, int ring_id,
                         bool use_calc_stream) override;

  paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override;

  void WaitCompute(int ring_id) override;

  void WaitComm(int ring_id) override;

  void SynchronizeCompute() override;

 private:
  std::unique_ptr<platform::CPUDeviceContext> device_;
};
} // namespace imperative
} // namespace paddle
@@ -28,7 +28,7 @@ namespace paddle {
namespace imperative {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO)
// div the nranks
void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
framework::Tensor *tensor =
@@ -42,9 +42,12 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
DivNRanks(tensor, nranks, context);
#endif
} else if (platform::is_cpu_place(tensor->place())) {
VLOG(4) << "before div 2" << *tensor;
VLOG(4) << "NDiv for cpu devices : rank = " << nranks;
framework::VisitDataTypeSmall(
dtype_, DivNRanksForAllReduce<platform::CPUDeviceContext>(
tensor, nranks, context));
VLOG(4) << "after div 2" << *tensor;
} else if (platform::is_xpu_place(tensor->place())) {
#ifdef PADDLE_WITH_XPU_BKCL
// TODO(liuyuhui) support xpu about div nranks in the future
@@ -764,8 +767,8 @@ void Reducer::MarkGroupReady(size_t group_index) {
for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
++next_group_) {
-    auto &group = groups_[next_group_];
-    const int run_order = next_group_ % nrings_;
+    UNUSED auto &group = groups_[next_group_];
+    UNUSED const int run_order = next_group_ % nrings_;
// For CUDA or XPU, compute_stream --> comm_stream.
// For CPU, do nothing.
@@ -792,11 +795,12 @@ void Reducer::MarkGroupReady(size_t group_index) {
cv_.notify_all();
}
});
-#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
+#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) || \
+    defined(PADDLE_WITH_GLOO)
FusedAllReduceSchedule(run_order, group, next_group_);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Not compiled with BKCL or NCCL."));
"Not compiled with BKCL or NCCL or GLOO."));
#endif
}
}
@@ -974,7 +978,8 @@ void Reducer::FinalizeBackward() {
if (find_unused_vars_each_step_) {
// TODO(liuyuhui) support xpu about Tensorcopy/TensorFromVector/TensorToVector
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_GLOO)
ProcessUnusedDenseVars();
#endif
// Initialize local used vars
......
@@ -49,7 +49,7 @@ namespace paddle {
namespace imperative {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO)
template <typename T>
struct DivNRanksFunctor {
......
@@ -69,6 +69,8 @@ endif()
if(WITH_GLOO)
set(PYBIND_DEPS ${PYBIND_DEPS} gloo_context)
set(PYBIND_SRCS ${PYBIND_SRCS} gloo_context_py.cc)
+  set(PYBIND_DEPS ${PYBIND_DEPS} imperative_gloo_context)
+  set(PYBIND_DEPS ${PYBIND_DEPS} reducer)
endif(WITH_GLOO)
if (WITH_CRYPTO)
......
@@ -35,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/bkcl_context.h"
#include "paddle/fluid/imperative/data_loader.h"
#include "paddle/fluid/imperative/gloo_context.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/nccl_context.h"
@@ -2017,7 +2018,7 @@ void BindImperative(py::module *m_ptr) {
py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO)
py::class_<imperative::ParallelContext,
std::shared_ptr<imperative::ParallelContext>>(m,
"ParallelContext");
@@ -2062,6 +2063,20 @@ void BindImperative(py::module *m_ptr) {
&imperative::BKCLParallelContext::InitWithRingID,
py::arg("ring_id"));
#endif
+#if defined(PADDLE_WITH_GLOO)
+  // xiongkun
+  py::class_<imperative::GLOOParallelContext, imperative::ParallelContext,
+             std::shared_ptr<imperative::GLOOParallelContext>>(
+      m, "GLOOParallelContext")
+      .def(py::init<const imperative::ParallelStrategy &,
+                    const platform::CPUPlace &>())
+      .def("init", [](imperative::GLOOParallelContext &self) { self.Init(); })
+      .def("init_with_ring_id",
+           &imperative::GLOOParallelContext::InitWithRingID,
+           py::arg("ring_id"));
+#endif
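These bindings are normally driven by `init_parallel_env`, but the context can also be constructed directly. A hedged sketch (attribute names follow how parallel.py fills `core.ParallelStrategy`; the two-rank endpoints are assumed values):

```python
from paddle.fluid import core

strategy = core.ParallelStrategy()
strategy.nranks = 2            # total ranks
strategy.local_rank = 0        # this process's rank
strategy.trainer_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]
strategy.current_endpoint = "127.0.0.1:6170"

ctx = core.GLOOParallelContext(strategy, core.CPUPlace())
ctx.init()  # rendezvous through the HTTP store at trainer_endpoints[0]
```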
m.def("pylayer_apply",
[](const platform::CPUPlace &place, const py::object &cls,
const py::args args, const py::kwargs kwargs) {
......
@@ -55,19 +55,51 @@ def _start_kv_server(port, http_server_d, size):
http_server.stop()
-def init_parallel_env():
+def _check_backend(backend):
+    if backend not in ['nccl', 'gloo', 'bkcl', 'auto']:
+        raise ValueError(
+            "paddle.distributed initialize error, "
+            "backend argument can only be one of 'nccl', 'gloo', 'bkcl', 'auto', but got %s"
+            % backend)
+
+    if backend == 'nccl' and not core.is_compiled_with_cuda():
+        raise ValueError(
+            "paddle.distributed initialize error, "
+            "your paddle is not compiled with cuda but you assign 'nccl' as backend."
+        )
+
+    if backend == 'bkcl' and not core.is_compiled_with_xpu():
+        raise ValueError(
+            "paddle.distributed initialize error, "
+            "your paddle is not compiled with xpu but you assign 'bkcl' as backend."
+        )
+
+    if backend in ['auto', 'nccl', 'bkcl'] and (core.is_compiled_with_cuda() or
+                                                core.is_compiled_with_xpu()):
+        # backend passes 'auto' and CUDA or XPU is usable, so use the default
+        # logic and return False
+        return False
+    else:
+        return True
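To make the return contract concrete, a sketch of `_check_backend`'s behavior under different builds (outcomes derived from the conditions above; not output from the test suite):

```python
# On a CUDA build:
#   _check_backend('auto')  -> False  # default logic, prefers nccl
#   _check_backend('gloo')  -> True   # explicit CPU-only request
#   _check_backend('bkcl')  -> raises ValueError (not an XPU build)
# On a CPU-only build:
#   _check_backend('auto')  -> True   # falls back to the gloo path
#   _check_backend('nccl')  -> raises ValueError
```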
+def init_parallel_env(backend='auto'):
"""
Initialize parallel training environment in dynamic graph mode.
.. note::
Now initialize both `NCCL` and `GLOO` contexts for communication.
Args:
backend (string): A string represents the backend used by DataParallel,
should be one of 'gloo'(for cpu), 'nccl'(for cuda), 'bkcl'(for xpu), 'auto'(auto detect).
The auto detection prefer 'nccl', 'bkcl' than 'gloo'.
Returns:
None
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
@@ -120,13 +152,14 @@ def init_parallel_env():
"Currently not a parallel execution environment, `paddle.distributed.init_parallel_env` will not do anything."
)
return
-    # 1. gpu xpu check, must be gpu or xpu
-    if not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu():
+    # NOTE(xiongkun): cpu-only gloo support; the backend argument
+    #                 enables cpu-only gloo parallel training
+    is_cpu_only = _check_backend(backend)
+    # 1. gpu xpu check, must be gpu or xpu,
+    if not (is_cpu_only or core.is_compiled_with_cuda() or
+            core.is_compiled_with_xpu()):
        raise NotImplementedError(
-            "Cannot initialize parallel environment in CPU-only version, now only "
-            "supports initializing the GPU and XPU parallel environment. Please recompile "
-            "or reinstall paddle with GPU or XPU support.")
+            "If you want to use CPU-only version, please use 'gloo' as backend")
# 2. check env
def _check_var_exists(var_name):
@@ -136,9 +169,9 @@ def init_parallel_env():
"environment variable %s is needed, but not set." %
var_name)
-    if core.is_compiled_with_cuda():
+    if not is_cpu_only and core.is_compiled_with_cuda():
        _check_var_exists("FLAGS_selected_gpus")
-    elif core.is_compiled_with_xpu():
+    elif not is_cpu_only and core.is_compiled_with_xpu():
        _check_var_exists('FLAGS_selected_xpus')
    _check_var_exists("PADDLE_TRAINER_ID")
@@ -148,7 +181,7 @@ def init_parallel_env():
    # 3: init gloo context (step 1: http server start)
    init_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0"))
-    if init_gloo:
+    if is_cpu_only or init_gloo:
        ep_rank_0 = parallel_env.trainer_endpoints[0].split(":")
        manager = Manager()
        # global dict to store status
@@ -180,14 +213,19 @@ def init_parallel_env():
# directly, if they want to switch default place,
# they need to call a function to change default place,
# here just set correctly place to users
-    if core.is_compiled_with_cuda():
+    if is_cpu_only:
+        place = core.CPUPlace()
+    elif core.is_compiled_with_cuda():
        place = core.CUDAPlace(parallel_env.device_id)
    elif core.is_compiled_with_xpu():
        place = core.XPUPlace(parallel_env.device_id)
-    _set_expected_place(place)
+    _set_expected_place(place)
# init nccl or bkcl context
-    if core.is_compiled_with_cuda():
+    if is_cpu_only:
+        parallel_helper._set_parallel_ctx(
+            core.GLOOParallelContext(strategy, place))
+    elif core.is_compiled_with_cuda():
        parallel_helper._set_parallel_ctx(
            core.NCCLParallelContext(strategy, place))
    elif core.is_compiled_with_xpu():
@@ -196,18 +234,22 @@ def init_parallel_env():
other_endpoints = strategy.trainer_endpoints[:]
other_endpoints.remove(strategy.current_endpoint)
-    if strategy.local_rank == 0:
+    if not is_cpu_only and strategy.local_rank == 0:
wait_server_ready(other_endpoints)
parallel_helper._init_parallel_ctx()
# 5: init gloo context (step 2: gloo init)
    # dividing init_gloo into two parts because nccl and gloo
    # are separately looking for free ports, which sometimes
    # leads to port conflicts.
-    if init_gloo:
-        wait_server_ready([parallel_env.trainer_endpoints[0]])
+    if is_cpu_only and parallel_env.rank == 0:
+        # compared with init_gloo, we don't need to
+        # init gloo, because we do this in _init_parallel_ctx;
+        http_server_d["running"] = False
+        http_server.join()
+    elif init_gloo:
+        wait_server_ready([parallel_env.trainer_endpoints[0]])
gloo_strategy = core.GlooParallelStrategy()
gloo_strategy.rank = parallel_env.rank
gloo_strategy.rank_num = parallel_env.world_size
......
@@ -184,6 +184,10 @@ endif()
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_hybrid_parallel)
+if (NOT WITH_GLOO)
+    LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel_cpuonly)
+endif()
if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
LIST(REMOVE_ITEM TEST_OPS test_rank_attention_op) # TODO(shenliang03): rank_attention_op support CPU device in future
......
@@ -16,6 +16,7 @@ from __future__ import division
from __future__ import print_function
import unittest
+import os
import paddle
import numpy as np
......@@ -65,7 +66,8 @@ class SimpleNet(fluid.Layer):
class TestDistTraning(unittest.TestCase):
def test_multiple_gpus(self):
-        dist.init_parallel_env()
+        backend = os.environ.get('PADDLE_DISTRI_BACKEND', 'auto')
+        dist.init_parallel_env(backend)
self.trainer_id = dist.get_rank()
model_a = SimpleNet(self.trainer_id)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import time
import paddle.fluid as fluid
import copy
import os
import subprocess
from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc
def get_cluster_from_args(selected_gpus):
    cluster_node_ips = '127.0.0.1'
    node_ip = '127.0.0.1'

    node_ips = [x.strip() for x in cluster_node_ips.split(',')]

    node_ips.index(node_ip)

    free_ports = None

    free_ports = find_free_ports(len(selected_gpus))
    if free_ports is not None:
        free_ports = list(free_ports)

    trainer_endpoints = []
    for ip in node_ips:
        trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports])
    return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)
def get_gpus(selected_gpus):
    selected_gpus = [x.strip() for x in selected_gpus.split(',')]
    return selected_gpus
def start_local_trainers(cluster,
                         pod,
                         training_script,
                         training_script_args,
                         log_dir=None):
    current_env = copy.copy(os.environ.copy())
    # paddle broadcast ncclUniqueId use socket, and
    # proxy maybe make trainers unreachable, so delete them.
    # if we set them to "", grpc will log error message "bad uri"
    # so just delete them.
    current_env.pop("http_proxy", None)
    current_env.pop("https_proxy", None)

    procs = []
    for t in pod.trainers:
        proc_env = {
            "PADDLE_TRAINER_ID": "%d" % t.rank,
            "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
            "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
            "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()),
            "PADDLE_DISTRI_BACKEND":
            "gloo",  # make init_parallel_env get 'gloo' argument.
        }

        current_env.update(proc_env)

        print("trainer proc env:{}".format(current_env))

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            cmd = "python -m coverage run --branch -p " + training_script
        else:
            cmd = "python -u " + training_script

        print("start trainer proc:{} env:{}".format(cmd, proc_env))

        fn = None

        proc = subprocess.Popen(cmd.split(" "), env=current_env)

        tp = TrainerProc()
        tp.proc = proc
        tp.rank = t.rank
        tp.log_fn = fn
        tp.cmd = cmd

        procs.append(tp)

    return procs
class TestMultipleGpus(unittest.TestCase):
    def run_mnist_2gpu(self, target_file_name):
        # if not fluid.core.is_compiled_with_cuda(
        # ) or fluid.core.get_cuda_device_count() == 0:
        #     return

        selected_gpus = get_gpus('0,1')
        cluster = None
        pod = None

        cluster, pod = get_cluster_from_args(selected_gpus)

        procs = start_local_trainers(
            cluster,
            pod,
            training_script=target_file_name,
            training_script_args=[])

        while True:
            alive = watch_local_trainers(procs, cluster.trainers_nranks())

            if not alive:
                print("Local procs complete, POD info:{}".format(pod))
                break
            time.sleep(3)
class TestDataParallelGradientCheck(TestMultipleGpus):
    def test_multiple_gpus_dynamic(self):
        self.run_mnist_2gpu('parallel_dygraph_gradient_check.py')


if __name__ == "__main__":
    unittest.main()
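The harness is script-agnostic, so further CPU-only cases can reuse it. A hypothetical extra case (subclass and script name are illustrative, not part of this commit):

```python
class TestDataParallelCpuOnlyExtra(TestMultipleGpus):
    def test_another_script_cpu_only(self):
        # hypothetical: point the same gloo-backed launcher at another script
        self.run_mnist_2gpu('parallel_dygraph_mnist.py')
```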