Unverified commit b6e7f8e9, authored by: X xiongkun, committed by: GitHub

User specified backend (#35745)

Parent 921c0917
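Note: this commit lets the user pick the collective communication backend explicitly (gloo for CPU-only training, nccl for GPU, bkcl for XPU, or auto) instead of deducing it solely from compile flags. As a hedged illustration of the resulting user-facing API (it mirrors the test_cpuonly_spawn.py test added below; a Paddle build with WITH_GLOO=ON is assumed), a CPU-only data-parallel run through spawn could look like this:

# Hedged sketch (mirrors the new test_cpuonly_spawn.py below); assumes WITH_GLOO=ON.
import paddle
import paddle.nn as nn
import paddle.distributed as dist

def train():
    dist.init_parallel_env()                       # backend is picked up from PADDLE_DISTRI_BACKEND
    layer = paddle.DataParallel(nn.Linear(10, 1))  # tiny data-parallel model
    out = layer(paddle.randn([4, 10], 'float32'))
    out.mean().backward()

if __name__ == '__main__':
    # backend can be 'gloo' | 'nccl' | 'bkcl' | 'auto'; here two CPU processes
    dist.spawn(train, backend='gloo', nprocs=2)

From the command line, the same choice is exposed through the new flag, e.g. `python -m paddle.distributed.launch --backend=gloo train.py` (the script name here is hypothetical).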
@@ -238,6 +238,24 @@ class GlooWrapper {
return ret;
}
// TODO(xiongkun03): support all-gather for arrays of
// numbers with different lengths.
// AllgathervOptions may work for that case; needs some survey.
template <typename T>
void AllGatherVector(T* input_ptr, T* output_ptr,
size_t element_num) { // NOLINT
CHECK_EQ(is_initialized_, true);
#ifdef PADDLE_WITH_GLOO
gloo::AllgatherOptions opts(context_);
opts.setInput(input_ptr, element_num);
opts.setOutput(output_ptr, element_num * size_);
gloo::allgather(opts);
#else
LOG(WARNING) << "AllGather does nothing when WITH_GLOO=OFF";
#endif
}
protected:
bool is_initialized_ = false;
#ifdef PADDLE_WITH_GLOO
......
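For reference, the new AllGatherVector concatenates a fixed-size contribution from every rank: each rank passes element_num elements and receives element_num * size_ elements ordered by rank. A plain-Python sketch of that layout (illustrative only, not Paddle API):

# Sketch of the layout produced by an all-gather over equal-sized inputs.
def allgather_vector(per_rank_inputs):
    """per_rank_inputs: list indexed by rank; every entry has the same length."""
    element_num = len(per_rank_inputs[0])
    assert all(len(chunk) == element_num for chunk in per_rank_inputs)
    output = []
    for chunk in per_rank_inputs:   # rank 0 first, then rank 1, ...
        output.extend(chunk)
    return output                   # length == element_num * world_size

print(allgather_vector([[1, 2], [3, 4], [5, 6]]))  # -> [1, 2, 3, 4, 5, 6]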
@@ -18,6 +18,7 @@
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
@@ -67,8 +68,36 @@ void GLOOParallelContext::AllReduceByStream(const framework::Variable &src,
framework::Variable *dst,
int ring_id, bool use_calc_stream) {
// AllReduce(src, dst, strategy_, ring_id, use_calc_stream);
if (src.IsType<framework::LoDTensor>()) {
if (!dst->IsType<framework::LoDTensor>()) {
dst->Clear();
}
AllReduce(src.Get<framework::LoDTensor>(),
dst->GetMutable<framework::LoDTensor>());
} else if (src.IsType<framework::SelectedRows>()) {
if (&src != dst) {
if (!dst->IsType<framework::SelectedRows>()) {
dst->Clear();
}
AllReduce(src.Get<framework::SelectedRows>(),
dst->GetMutable<framework::SelectedRows>());
} else {
// SelectedRows cannot be allreduce in-place
framework::Variable tmp_dst;
AllReduce(src.Get<framework::SelectedRows>(),
tmp_dst.GetMutable<framework::SelectedRows>());
*dst = std::move(tmp_dst);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported variable type %s for imperative allreduce, only "
"LoDTensor and SelectedRows are supported.",
platform::demangle(framework::ToTypeName(src.Type()))));
}
}
void GLOOParallelContext::AllReduce(const framework::Tensor &src_tensor,
framework::Tensor *dst_tensor) {
auto gloo_wrapper = framework::GlooWrapper::GetInstance();
dst_tensor->Resize(src_tensor.dims());
switch (src_tensor.type()) {
@@ -84,6 +113,88 @@ void GLOOParallelContext::AllReduceByStream(const framework::Variable &src,
gloo_wrapper->Barrier();
}
#define GLOO_ALL_GATHER_CASE(type, T, gw) \
case type: { \
const auto *src_tensor_ptr = src_tensor.data<T>(); \
gw->AllGatherVector<T>(const_cast<T *>(src_tensor_ptr), \
reinterpret_cast<T *>(dst_tensor_ptr), \
value_sendcount); \
break; \
}
void GLOOParallelContext::AllReduce(const framework::SelectedRows &src,
framework::SelectedRows *dst) {
// auto ;
// int local_rank = strategy_.local_rank_;
int nranks = strategy_.nranks_;
VLOG(3) << "SelectedRows AllReduce start";
const auto &src_tensor = src.value();
const auto &place = src_tensor.place();
auto dtype = src_tensor.type();
// 1. Gather rows number from all workers. Here use ncclAllGather to do this,
// but we can use other ways to implement it in the future
const auto &src_rows = src.rows();
auto gloo_wrapper = framework::GlooWrapper::GetInstance();
size_t local_row_num = src_rows.size();
std::vector<size_t> rows_num_vector =
gloo_wrapper->AllGather<size_t>(local_row_num);
const auto *cpu_rows_num_ptr = rows_num_vector.data();
auto rows_num = std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + nranks,
static_cast<int64_t>(0));
dst->set_height(src.height());
VLOG(3) << "Gather rows: " << string::join_strings(rows_num_vector, ',')
<< ", total rows number: " << rows_num
<< ", height: " << src.height();
auto *dst_rows = dst->mutable_rows();
dst_rows->resize(rows_num);
auto *dst_rows_ptr = dst_rows->MutableData(place);
const int64_t *src_rows_ptr = src_rows.Data(place);
// VLOG(3) << "Selected Rows of src:" << string::join_strings(dst_rows, ',')
auto *dst_tensor = dst->mutable_value();
auto dims = src_tensor.dims();
dims[0] = rows_num;
auto feature_size = framework::product(dims) / dims[0];
dst_tensor->Resize(dims);
if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + nranks,
[&](size_t row) { return row == cpu_rows_num_ptr[0]; })) {
// During sparse communication, the number of rows on each card is the same.
// Because the gloo wrapper utility class currently doesn't support
// broadcast, we only deal with the equal-rows case.
VLOG(3) << "Use the gloo all reduce to sync. SRC:" << src_tensor;
// framework::SerializeToStream(VLOG(4), src);
VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce";
auto value_sendcount = cpu_rows_num_ptr[0] * feature_size;
auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype);
gloo_wrapper->AllGatherVector<int64_t>(const_cast<int64_t *>(src_rows_ptr),
static_cast<int64_t *>(dst_rows_ptr),
rows_num_vector[0]);
switch (dtype) {
GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP32, float,
gloo_wrapper);
GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP64, double,
gloo_wrapper);
GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT32, int, gloo_wrapper);
GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT64, int64_t,
gloo_wrapper);
default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid datatype for allreduce"));
}
}
VLOG(3) << "Selected Row DST:" << *dst_tensor;
VLOG(3) << "Selected Rows of DST:"
<< string::join_strings(std::vector<int64_t>(*dst_rows), ',');
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The number of each card is not the same, gloo only support the-same"
"batch division"));
}
}
paddle::platform::DeviceContext *GLOOParallelContext::GetDeviceContext(
int ring_id) {
// return the CPUDeviceContext
......
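The SelectedRows path above sidesteps the missing gloo broadcast by all-gathering both the row indices and the value rows, which only works when every rank contributes the same number of rows. A rough Python sketch of that idea (plain lists stand in for the gloo wrapper and tensors; illustrative only):

# Rough sketch of the SelectedRows "allreduce by allgather" shown above.
def sparse_allreduce(per_rank_rows, per_rank_values):
    """per_rank_rows[i]: row indices on rank i; per_rank_values[i]: one value row per index."""
    rows_num = [len(rows) for rows in per_rank_rows]      # step 1: all-gather the row counts
    if len(set(rows_num)) != 1:
        raise ValueError("the gloo path only supports equal row counts per rank")
    dst_rows, dst_values = [], []
    for rows, values in zip(per_rank_rows, per_rank_values):
        dst_rows.extend(rows)                             # step 2: all-gather the row indices
        dst_values.extend(values)                         # step 3: all-gather the value rows
    return dst_rows, dst_values                           # duplicate row ids are kept, as in the C++ code

rows, vals = sparse_allreduce([[0, 2], [1, 2]], [[[1.0], [2.0]], [[3.0], [4.0]]])
print(rows, vals)  # -> [0, 2, 1, 2] [[1.0], [2.0], [3.0], [4.0]]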
@@ -16,6 +16,9 @@
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/parallel_context.h" #include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -52,6 +55,11 @@ class GLOOParallelContext : public ParallelContext { ...@@ -52,6 +55,11 @@ class GLOOParallelContext : public ParallelContext {
void SynchronizeCompute() override; void SynchronizeCompute() override;
private:
void AllReduce(const framework::Tensor& src, framework::Tensor* dst);
void AllReduce(const framework::SelectedRows& src,
framework::SelectedRows* dst);
private:
std::unique_ptr<platform::CPUDeviceContext> device_;
};
......
@@ -103,7 +103,12 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
type=str,
default="log",
help="The path for each process's log. Default --log_dir=log/")
base_group.add_argument(
"--backend",
type=str,
default="auto",
help="Specifize the backend, can be gloo|nccl|bkcl|auto. Default value is auto which perfers nccl or bkcl."
)
base_group.add_argument(
"--nproc_per_node",
type=int,
@@ -230,8 +235,21 @@ def get_cluster_from_args(args, device_mode, devices_per_proc):
devices_per_proc)
def cpuonly_check(args):
if args.ips and len(args.ips.split(',')) > 1:
raise RuntimeError(
"CPUONLY launch only support single trainer, that is len(ips)=1, but got %s."
% args.ips)
if args.run_mode:
assert args.run_mode == 'cpuonly', "CPUONLY launch only support run mode is CPUONLY"
if args.servers:
raise RuntimeError("CPUONLY launch can't have --servers as arguments.")
return True
def launch_collective(args):
# parse arguments, used for cloud-single-machine and local
if args.backend == 'gloo': cpuonly_check(args)
(device_mode, devices_per_proc) = launch_utils.get_device_proc_info(args)
trainers_num = cloud_utils.get_trainers_num()
logger.debug("parsed from args trainerss_num:{} mode:{} devices:{}".format(
@@ -265,6 +283,7 @@ def launch_collective(args):
global_envs["PADDLE_WITH_GLOO"] = str(os.getenv("PADDLE_WITH_GLOO", "0"))
global_envs["PADDLE_GLOO_RENDEZVOUS"] = "3"
global_envs["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir
global_envs["PADDLE_DISTRI_BACKEND"] = args.backend
procs = start_local_trainers(
cluster,
@@ -349,9 +368,12 @@ def which_distributed_mode(args):
if fluid.core.is_compiled_with_cuda():
accelerators = fluid.core.get_cuda_device_count()
args.backend = 'nccl'
elif fluid.core.is_compiled_with_npu():
args.backend = 'unknown'
accelerators = fluid.core.get_npu_device_count()
elif fluid.core.is_compiled_with_xpu():
args.backend = 'bkcl'
accelerators = fluid.core.get_xpu_device_count()
else:
accelerators = 0
@@ -372,10 +394,14 @@ def which_distributed_mode(args):
else:
if not fluid.core.is_compiled_with_cuda(
) and not fluid.core.is_compiled_with_xpu():
if args.servers:
logger.warning(
"Not found distinct arguments and not compiled with cuda or xpu. \
But found args.servers not empty, default use ps mode")
return DistributeMode.PS
else:
args.backend = "gloo"
return DistributeMode.COLLECTIVE
else:
logger.warning(
"Not found distinct arguments and compiled with cuda or xpu. Default use collective mode"
@@ -556,7 +582,20 @@ def launch():
logger = get_logger()
_print_arguments(args)
if args.backend == 'auto':
distribute_mode = which_distributed_mode(args)
assert args.backend in [
'gloo', 'nccl', 'bkcl', 'unknown'
] # which_distributed_mode must modify args.backend
else:
assert args.run_mode == 'collective' or args.run_mode == None, "When backend is not 'auto', run mode must be collective"
check_backend(args.backend)
distribute_mode = DistributeMode.COLLECTIVE
block_windows_and_macos(
args.backend) # raise error when using gloo on windows or macos
if args.backend == 'gloo':
logger.warning("launch start with CPUONLY mode")
if enable_elastic(args, distribute_mode):
launch_elastic(args, distribute_mode)
......
@@ -22,6 +22,7 @@ import subprocess
import tempfile
import shutil
from contextlib import closing
import multiprocessing
import socket
import warnings
import six
@@ -30,6 +31,7 @@ import struct
import paddle
import paddle.fluid as fluid
from distutils.util import strtobool
import paddle.utils.cpp_extension.extension_utils as utils
logger = logging.getLogger("root")
logger.propagate = False
@@ -669,29 +671,31 @@ def get_xpus(xpus):
return res_xpus
def get_device_mode(backend):
if fluid.core.is_compiled_with_npu() and \
fluid.core.get_npu_device_count() > 0:
print("launch train in ascend npu mode!")
return DeviceMode.ASCEND_NPU
if backend == 'nccl' and \
fluid.core.get_cuda_device_count() > 0:
print("launch train in GPU mode!")
return DeviceMode.GPU
if backend == 'bkcl' and fluid.core.get_xpu_device_count() > 0:
print("launch train in XPU mode")
return DeviceMode.XPU
if backend == 'gloo':
print("launch train in CPU mode")
return DeviceMode.CPU
raise RuntimeError("Don't supported devices")
def get_device_proc_info(args):
# device_mode
device_mode = get_device_mode(args.backend)
# devices
devices_per_proc = []
@@ -722,6 +726,9 @@ def get_device_proc_info(args):
else:
devices_per_proc = xpus
elif device_mode == DeviceMode.CPU:
if hasattr(args, "paddle_cpuonly") and args.nproc_per_node is None:
#NOTE (xiongkun03) set it to cpu core number
args.nproc_per_node = multiprocessing.cpu_count()
if args.nproc_per_node is None:
devices_per_proc = [0]
else:
@@ -1237,3 +1244,45 @@ class ParameterServerLauncher(object):
tp.cmd = cmd
self.procs["heter_worker"].append(tp)
def check_backend(backend):
if backend not in ['nccl', 'gloo', 'bkcl', 'auto']:
raise ValueError(
"paddle.distributed initialize error, "
"backend argument can only be one of 'nccl', 'gloo', 'bkcl', 'auto', but got %s"
% backend)
if backend == 'nccl' and not fluid.core.is_compiled_with_cuda():
raise ValueError(
"paddle.distributed initialize error, "
"your paddle is not compiled with cuda but you assign 'nccl' as backend."
)
if backend == 'bkcl' and not fluid.core.is_compiled_with_xpu():
raise ValueError(
"paddle.distributed initialize error, "
"your paddle is not compiled with xpu but you assign 'bkcl' as backend."
)
def block_windows_and_macos(backend):
if backend != 'gloo': return
if utils.OS_NAME.startswith('darwin'):  # macOS, block
raise ValueError(
"You are trying to use gloo on macOS, but it is currently not supported"
)
if utils.IS_WINDOWS:  # Windows, block
raise ValueError(
"You are trying to use gloo on Windows, but it is currently not supported"
)
def get_backend_by_compile_flag():
if fluid.core.is_compiled_with_cuda():
return 'nccl'
if fluid.core.is_compiled_with_xpu():
return 'bkcl'
return 'gloo'
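A hedged usage sketch of the three helpers added above (they are imported from paddle.distributed.fleet.launch_utils elsewhere in this diff):

# Illustrative only: how the new helpers fit together when resolving a backend.
from paddle.distributed.fleet.launch_utils import (
    check_backend, block_windows_and_macos, get_backend_by_compile_flag)

backend = 'auto'
if backend == 'auto':
    backend = get_backend_by_compile_flag()  # 'nccl', 'bkcl', or 'gloo'
check_backend(backend)                       # raises ValueError for an invalid/unsupported choice
block_windows_and_macos(backend)             # gloo is rejected on Windows and macOS
print("resolved backend:", backend)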
@@ -26,6 +26,7 @@ from paddle import compat as cpt
from paddle.fluid import core
from paddle.fluid.framework import _set_expected_place
from paddle.fluid.dygraph import parallel_helper
from paddle.distributed.fleet.launch_utils import check_backend
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed.fleet.base.private_helper_function import wait_server_ready # noqa: F401
@@ -55,25 +56,8 @@ def _start_kv_server(port, http_server_d, size):
http_server.stop()
def _is_cpuonly(backend):
check_backend(backend)
if backend in ['auto', 'nccl', 'bkcl'] and (core.is_compiled_with_cuda() or
core.is_compiled_with_xpu()):
# passes 'auto' and can use cuda or xpu, use the default logics. so return False
@@ -82,7 +66,7 @@ def _check_backend(backend):
return True
def init_parallel_env():
"""
Initialize parallel training environment in dynamic graph mode.
@@ -154,7 +138,8 @@ def init_parallel_env():
return
# NOTE(xiongkun): support cpu gloo only, add this environment variable to
# enable cpu-only gloo parallel training
backend = os.environ.get('PADDLE_DISTRI_BACKEND', 'auto')
is_cpu_only = _is_cpuonly(backend)
# 1. gpu xpu check, must be gpu or xpu,
if not (is_cpu_only or core.is_compiled_with_cuda() or
core.is_compiled_with_xpu()):
......
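With this change init_parallel_env() no longer takes a backend argument; it reads the PADDLE_DISTRI_BACKEND environment variable, which launch/spawn set for each trainer process. A hedged manual example (normally the launcher provides the full trainer environment):

# Illustrative only: launch/spawn normally export PADDLE_DISTRI_BACKEND and the trainer env vars.
import os
import paddle.distributed as dist

os.environ['PADDLE_DISTRI_BACKEND'] = 'gloo'  # force CPU-only gloo training
dist.init_parallel_env()                      # backend deduced from the env var, not an argument
print("rank:", dist.get_rank())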
@@ -24,8 +24,10 @@ import warnings
from paddle.distributed.utils import _print_arguments
from paddle.distributed.utils import _prepare_trainer_env
from paddle.distributed.utils import get_host_name_ip
from paddle.distributed.cloud_utils import get_cluster_and_pod, _get_trainers_num
from paddle.distributed.fleet.launch import get_cluster_from_args
from paddle.distributed.fleet.cloud_utils import use_paddlecloud
from paddle.distributed.fleet.launch_utils import DeviceMode, check_backend, block_windows_and_macos
from paddle.device import get_device
# deprecated module import
@@ -71,7 +73,9 @@ def _py_supported_check():
def _options_valid_check(options):
# `print_config` is kept as a debug option, not shown to users
supported_options = [
'start_method', 'ips', 'gpus', 'xpus', 'print_config', 'backend'
]
deprecated_options = [
'selected_devices', 'started_port', 'cluster_node_ips', 'node_ip',
'use_paddlecloud'
@@ -95,6 +99,22 @@ def _get_default_nprocs():
return core.get_cuda_device_count()
elif 'xpu' in device:
return core.get_xpu_device_count()
elif 'cpu' in device:
return multiprocessing.cpu_count()
else:
raise RuntimeError(
"`paddle.distributed.spawn` does not support parallel training on device `{}` now.".
format(device))
def _get_default_backend():
device = get_device()
if 'gpu' in device:
return 'nccl'
elif 'xpu' in device:
return 'bkcl'
elif 'cpu' in device:
return 'gloo'
else:
raise RuntimeError(
"`paddle.distributed.spawn` does not support parallel training on device `{}` now.".
@@ -112,6 +132,16 @@ def _get_node_ip(ips):
def _get_subprocess_env_list(nprocs, options):
# NOTE (xiongkun03) Why put backend deduction here?
# Because _get_subprocess_env_list is used by many testcases.
# So for compatibility, we put backend deduction here.
# logic for handling the backend option
if 'backend' not in options or options['backend'] == 'auto':
options['backend'] = _get_default_backend()
check_backend(options['backend'])
block_windows_and_macos(options['backend'])
# construct processes env list
processes_env_list = []
@@ -133,7 +163,7 @@ def _get_subprocess_env_list(nprocs, options):
# if we set FLAGS_selected_gpus or FLAGS_selected_xpus to be `0,1,2,3`, it may cause error
# when using `ParallelEnv`
# NOTE(chenweihang): use absolute gpu or xpu card id
if options['backend'] == 'nccl':
args.selected_devices = options.get('gpus', None)
if args.selected_devices is None:
args.selected_devices = options.get('selected_devices', None)
@@ -168,7 +198,7 @@ def _get_subprocess_env_list(nprocs, options):
"CUDA_VISIBLE_DEVICES (%s)." %
(card_id, ",".join(env_devices_list)))
elif options['backend'] == 'bkcl':
args.selected_devices = options.get('xpus', None)
if args.selected_devices is None:
args.selected_devices = options.get('selected_devices', None)
@@ -202,6 +232,23 @@ def _get_subprocess_env_list(nprocs, options):
raise ValueError("The selected xpu card %s cannot found in "
"XPU_VISIBLE_DEVICES (%s)." %
(card_id, ",".join(env_devices_list)))
elif options['backend'] == 'gloo':
# TODO check gpu / xpu flag must not exist
warnings.warn(
"Your model will be trained under CPUONLY mode by using GLOO,"
"because CPUPlace is specified manually or your installed PaddlePaddle only support CPU Device."
)
args.paddle_cpuonly = True
args.selected_devices = None
args.ips = args.cluster_node_ips
assert options.get(
'use_paddlecloud',
None) is None, "CPUONLY spawn doesn't support use paddle cloud"
assert len(
args.cluster_node_ips.split(',')
) <= 1, "CPUONLY spawn only support single trainer, that is len(ips)=1, but got %s."
assert _get_trainers_num(
) == 1, "CPUONLY spawn doesn't support multi-trainer"
# set other inner args
args.node_ip = options.get('node_ip', None)
@@ -215,11 +262,17 @@ def _get_subprocess_env_list(nprocs, options):
args.use_paddlecloud = use_paddlecloud()
# get cluster and pod config
if options['backend'] == 'gloo':
devices_per_proc = [x for x in range(0, nprocs)]
cluster, pod = get_cluster_from_args(args, DeviceMode.CPU,
devices_per_proc)
else:
cluster, pod = get_cluster_and_pod(args)
# prepare subprocess env list
for trainer in pod.trainers:
processes_env_list.append(
_prepare_trainer_env(cluster, trainer, options['backend']))
# [Debug] print config
args.print_config = options.get('print_config', False)
@@ -236,27 +289,35 @@ def _remove_risky_env():
os.environ.pop("https_proxy", None)
def _set_trainer_env(env_dict, backend):
# NOTE(chenweihang): [ Why need set FLAGS_selected_gpus or FLAGS_selected_xpus here? ]
# When the child process starts, it will inherit the configuration of the
# main process and set the FLAGS once, but the environment variable has
# not been set at this time, which leads to the FLAGS_selected_gpus or FLAGS_selected_xpus
# is keep same with mainprocess(usually empty), so manually update the flags here
# NOTE(xiongkun): why put backend here? because if gloo, we shouldn't set FLAGS_selectedXXX
#
if backend == 'nccl':
set_flags({'FLAGS_selected_gpus': env_dict['FLAGS_selected_gpus']})
elif backend == 'bkcl':
set_flags({'FLAGS_selected_xpus': env_dict['FLAGS_selected_xpus']})
else:
#NOTE(xiongkun) why not raise Error ?
# So far, we added support for CPU parallel, and will be applied when paddle is not
# compiled with cuda or xp. just do nothing.
pass
for var_name in env_dict:
os.environ[var_name] = env_dict[var_name]
def _func_wrapper(func, args, error_queue, return_queue, env_dict, backend):
try:
# config subprocess environment variables
_remove_risky_env()
_set_trainer_env(env_dict, backend)
# execute function
result = func(*args)
# record function return value
@@ -487,7 +548,8 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
return_queue = mp.SimpleQueue()
process = mp.Process(
target=_func_wrapper,
args=(func, args, error_queue, return_queue, procs_env_list[i],
options['backend']))
process.daemon = daemon
process.start()
error_queues.append(error_queue)
......
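To summarize the spawn-side changes above: the backend is resolved once in _get_subprocess_env_list and then decides which device flags each subprocess receives. A condensed sketch of that decision (not the real implementation; the helper name below is hypothetical):

# Condensed sketch of spawn's backend handling (hypothetical helper, not Paddle code).
def resolve_spawn_backend(options, device):
    """device: value of paddle.get_device(), e.g. 'gpu:0', 'xpu:0', or 'cpu'."""
    backend = options.get('backend', 'auto')
    if backend == 'auto':                 # deduce from the current device
        if 'gpu' in device:
            backend = 'nccl'
        elif 'xpu' in device:
            backend = 'bkcl'
        elif 'cpu' in device:
            backend = 'gloo'
        else:
            raise RuntimeError("unsupported device for spawn: {}".format(device))
    # nccl -> FLAGS_selected_gpus, bkcl -> FLAGS_selected_xpus, gloo -> no device flags
    return backend

print(resolve_spawn_backend({}, 'cpu'))                     # -> gloo
print(resolve_spawn_backend({'backend': 'nccl'}, 'gpu:0'))  # -> nccl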
@@ -25,6 +25,7 @@ import subprocess
from contextlib import closing
import socket
from paddle.fluid import core
from paddle.distributed.fleet.launch_utils import get_backend_by_compile_flag
from distutils.util import strtobool
from paddle.fluid.layer_helper import LayerHelper
@@ -613,8 +614,10 @@ def find_free_ports(num):
return None
def _prepare_trainer_env(cluster, trainer, backend=None):
if backend is None:
backend = get_backend_by_compile_flag() # for compatibility
if backend == 'bkcl':
proc_env = {
"FLAGS_selected_xpus":
"%s" % ",".join([str(g) for g in trainer.gpus]),
@@ -623,7 +626,7 @@ def _prepare_trainer_env(cluster, trainer):
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
elif backend == 'nccl':
proc_env = {
"FLAGS_selected_gpus":
"%s" % ",".join([str(g) for g in trainer.gpus]),
@@ -632,6 +635,19 @@ def _prepare_trainer_env(cluster, trainer):
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
elif backend == 'gloo':
# NOTE (xiongkun) default fall back into cpu only
proc_env = {
"PADDLE_TRAINER_ID": "%d" % trainer.rank,
"PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint,
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()),
"PADDLE_DISTRI_BACKEND":
backend, # only add here, other will be auto
}
else:
raise ValueError("backend must be one of 'gloo, nccl, bkcl'")
return proc_env
......
@@ -200,8 +200,14 @@ endif()
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_hybrid_parallel)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_transformer_gloo) # NOTE: @xiongkun03, cpu is too slow, fix it in next PR
if (NOT WITH_GLOO)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel_cpuonly)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_unused_variables_gloo)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_over_height_gloo)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_gloo)
endif()
if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
@@ -491,6 +497,10 @@ if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_dataset)
endif()
if (NOT WITH_GLOO)
LIST(REMOVE_ITEM TEST_OPS test_cpuonly_spawn)
endif()
if(NOT WITH_GPU OR WIN32 OR APPLE)
list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass)
endif()
@@ -654,6 +664,9 @@ if(WITH_DISTRIBUTE)
endforeach(TEST_OP)
# solve it later.
bash_test_modules(test_fleet_launch_ps START_BASH test_fleet_launch_ps.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}" PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR} )
if (WITH_GLOO)
bash_test_modules(test_cpuonly_launch START_BASH test_cpuonly_launch.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}" PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR} )
endif()
bash_test_modules(test_new_group START_BASH test_new_group.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}+20" PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR} )
endif(NOT APPLE)
endif()
@@ -1070,3 +1083,8 @@ set_tests_properties(test_inplace_addto_strategy PROPERTIES TIMEOUT 120)
set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400)
set_tests_properties(test_tensordot PROPERTIES TIMEOUT 1000)
set_tests_properties(test_tensordot PROPERTIES LABELS "RUN_TYPE=NIGHTLY")
if (WITH_GLOO)
set_tests_properties(test_parallel_dygraph_unused_variables_gloo PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_sparse_embedding_gloo PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_sparse_embedding_over_height_gloo PROPERTIES TIMEOUT 120)
endif()
@@ -66,8 +66,7 @@ class SimpleNet(fluid.Layer):
class TestDistTraning(unittest.TestCase):
def test_multiple_gpus(self):
dist.init_parallel_env()
self.trainer_id = dist.get_rank()
model_a = SimpleNet(self.trainer_id)
......
@@ -324,6 +324,7 @@ class TestSeResNeXt(TestParallelDyGraphRunnerBase):
bs = len(data)
dy_x_data = np.array([x[0].reshape(3, 224, 224)
for x in data]).astype('float32')
dy_x_data = dy_x_data / 255.0
y_data = np.array([x[1] for x in data]).astype('int64').reshape(bs, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
......
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function test_launch_cpuonly(){
python -m paddle.distributed.launch --nproc_per_node=4 --backend=gloo \
parallel_dygraph_gradient_check.py 2>ut.elog
if grep -q "ABORT" ut.elog; then
echo "test cpu only failed"
exit -1
else
if grep -q "CPUONLY" ut.elog; then
echo "test_launch_cpuonly successfully"
else
echo "test_launch_cpuonly failed"
exit -1
fi
fi
}
function test_launch_error_case1(){
python -m paddle.distributed.launch --nproc_per_node=4 --backend=random_str \
parallel_dygraph_gradient_check.py 2>ut.elog
if grep -q "ValueError" ut.elog; then
echo "test_launch_error_case1 successfully"
else
exit -1
fi
}
test_launch_cpuonly
test_launch_error_case1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train(print_result=False):
# 1. initialize parallel environment
dist.init_parallel_env()
# 2. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())
# 3. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
if print_result is True:
print("loss:", loss.numpy())
loss.backward()
print("Grad is", layer._linear1.weight.grad)
adam.step()
adam.clear_grad()
class TestSpawn(unittest.TestCase):
def test_spawn(self):
dist.spawn(train, backend='gloo', nprocs=4)
def test_wrong_backend(self):
try:
dist.spawn(train, backend='something', nprocs=4)
except ValueError as e:
self.assertEqual(type(e), ValueError)
if __name__ == '__main__':
unittest.main()
@@ -209,7 +209,11 @@ class TestDistRunnerBase(object):
def get_data():
origin_batch = next(reader_generator)
if paddle.distributed.get_world_size(
) == 1 and args.update_method == 'gloo': # Gloo single mode
return origin_batch
elif args.update_method != "local" and args.use_reader_alloc:
new_batch = []
for offset, item in enumerate(origin_batch):
if offset % 2 == args.trainer_id:
@@ -506,7 +510,10 @@ class TestParallelDyGraphRunnerBase(object):
"train_one_loop should be implemented by the child classes.")
def _get_data(self, batch, args):
if paddle.distributed.get_world_size(
) == 1 and args.update_method == 'gloo': # Gloo single mode
return batch
elif args.update_method != "local":
new_batch = []
for offset, item in enumerate(batch):
if offset % 2 == args.trainer_id:
@@ -518,14 +525,16 @@ class TestParallelDyGraphRunnerBase(object):
def run_trainer(self, args):
seed = 90
if args.update_method == 'gloo':
place = fluid.CPUPlace()
elif fluid.core.is_compiled_with_cuda():
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
elif fluid.core.is_compiled_with_xpu():
device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
place = fluid.XPUPlace(device_id)
else:
assert ("Only support CUDAPlace or XPUPlace or CPU(Gloo) for now.")
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
@@ -554,6 +563,16 @@ class TestParallelDyGraphRunnerBase(object):
model = dygraph.parallel.DataParallel(
model, strategy, find_unused_parameters=True)
print_to_err(type(self).__name__, "model built in dygraph")
elif args.update_method == "gloo":
paddle.distributed.init_parallel_env()
if not args.find_unused_parameters:
model = dygraph.parallel.DataParallel(
model, find_unused_parameters=False)
else:
model = dygraph.parallel.DataParallel(
model, find_unused_parameters=True)
out_losses = []
print_to_err(type(self).__name__, "begin to run dygraph training")
for step_id, data in enumerate(train_reader()):
@@ -588,12 +607,12 @@ class TestParallelDyGraphRunnerBase(object):
args.trainer_id = paddle.distributed.get_rank()
# 3. init parallel env
if args.update_method in ["nccl2", "gloo"]:
paddle.distributed.init_parallel_env()
# 4. train model
model, train_reader, opt = self.get_model()
if args.update_method in ["nccl2", "gloo"]:
if args.find_unused_parameters:
model = paddle.DataParallel(model, find_unused_parameters=True)
else:
@@ -668,7 +687,9 @@ def runtime_main(test_class):
'--update_method',
type=str,
default="local",
choices=[
"pserver", "nccl2", "bkcl", "local", "nccl2_reduce_layer", "gloo"
])
parser.add_argument('--trainer_id', type=int, required=False, default=0)
parser.add_argument('--trainers', type=int, required=False, default=1)
parser.add_argument('--nccl_comm_num', type=int, required=False, default=1)
@@ -685,6 +706,7 @@ def runtime_main(test_class):
'--current_endpoint', type=str, required=False, default="")
parser.add_argument('--sync_mode', action='store_true')
parser.add_argument('--use_cuda', action='store_true')
parser.add_argument('--use_cpu', action='store_true')
parser.add_argument('--use_xpu', action='store_true')
parser.add_argument('--use_dgc', action='store_true')
parser.add_argument('--accumulate_gradient', action='store_true')
@@ -713,6 +735,9 @@ def runtime_main(test_class):
args = parser.parse_args()
if args.update_method == 'gloo':
paddle.set_device("cpu")
model = test_class()
if args.role == "pserver" and args.update_method == "pserver":
model.run_pserver(args)
@@ -770,6 +795,7 @@ class TestDistBase(unittest.TestCase):
self._use_reader_alloc = True
self._nccl2_mode = False
self._bkcl_mode = False
self._gloo_mode = False # now, support gloo backend
self._pipeline_mode = False
self._mp_mode = False
# FIXME(typhoonzero): I added this stupid argument to enable
@@ -875,7 +901,7 @@ class TestDistBase(unittest.TestCase):
batch_size=DEFAULT_BATCH_SIZE,
batch_merge_repeat=1,
log_name="",
devices="1"):
cmd = self._python_interp
@@ -947,6 +973,21 @@ class TestDistBase(unittest.TestCase):
return pickle.loads(local_out)
def _run_local_gloo(self,
model,
envs,
check_error_log=False,
batch_size=DEFAULT_BATCH_SIZE,
batch_merge_repeat=1,
log_name="",
devices="0"):
saved_endpoints = self._ps_endpoints
self._ps_endpoints = self._ps_endpoints.split(',')[0]
result = self._run_cluster_gloo(model, envs, 'gloo', check_error_log,
log_name)
self._ps_endpoints = saved_endpoints
return result
def _run_cluster(self, model, envs, check_error_log, log_name):
# Run dist train to compare with local results
ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(
@@ -1037,6 +1078,62 @@ class TestDistBase(unittest.TestCase):
return pickle.loads(tr0_out), pickle.loads(tr1_out)
def _get_gloo_trainer_cmd(self, model, ep, update_method, trainer_id,
trainer_num):
env = {}
tr_cmd = "%s -u"
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
tr_cmd += " -m coverage run --branch -p"
tr_cmd += " %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method %s --lr %f"
tr_cmd = tr_cmd % \
(self._python_interp, model, self._ps_endpoints,
trainer_id, ep, update_method, self._lr)
if self._use_reduce:
tr_cmd += " --use_reduce"
if self._use_reader_alloc:
tr_cmd += " --use_reader_alloc"
#assert self._use_reduce == False, "gloo not support _use_reduce"
#assert self._use_reader_alloc == False, "gloo not support _use_reduce"
if self._save_model:
tr_cmd += " --save_model"
self.__use_cuda = False
self.__use_xpu = False
assert self.__use_cuda == False, "gloo not support use cuda"
assert self.__use_xpu == False, "gloo not support use xpu"
tr_cmd += " --use_cpu"
env.update({
"PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
"PADDLE_TRAINER_ID": "{}".format(trainer_id),
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": ep,
"PADDLE_CURRENT_ENDPOINT": ep,
"PADDLE_DISTRI_BACKEND": "gloo",
"GLOG_v": "2",
})
assert self._use_dgc == False, "gloo not support use dgc"
if self._accumulate_gradient:
tr_cmd += " --accumulate_gradient"
if self._find_unused_parameters:
tr_cmd += " --find_unused_parameters"
assert self._pipeline_mode == False, "gloo not support use pipeline"
if self._enable_backward_deps: # build strategy, save it
tr_cmd += " --enable_backward_deps"
if self._fuse_all_reduce is not None:
tr_cmd += " --fuse_all_reduce {}".format(self._fuse_all_reduce)
assert self._use_fleet_api == False, "gloo not support use fleet api"
assert self._use_fleet_api_20 == False, "gloo not support use fleet api"
return tr_cmd, env
def _get_nccl2_trainer_cmd(self, model, ep, update_method, trainer_id,
trainer_num):
env = {}
@@ -1123,6 +1220,57 @@ class TestDistBase(unittest.TestCase):
return tr_cmd, env
def _run_cluster_gloo(self, model, envs, update_method, check_error_log,
log_name):
assert update_method == "gloo", "_run_cluster_gloo must have update_method: gloo, but get %s" % update_method
assert not self._use_hallreduce, "_run_cluster_gloo must have _use_hallreduce = false"
worker_endpoints = self._ps_endpoints.split(",")
trainer_num = len(worker_endpoints)
procs = []
pipes = []
for i in range(0, trainer_num):
tr_cmd, tr_env = self._get_gloo_trainer_cmd(
model, worker_endpoints[i], update_method, i, trainer_num)
tr_env.update(envs)
tr_env["GLOG_vmodule"] = 'gloo_context=4'
tr_env["GLOG_v"] = '3'
print("use_hallreduce:{} tr_cmd:{}, env: {}".format(
self._use_hallreduce, tr_cmd, tr_env))
tr_pipe = open(log_name + "_tr{}_err.log".format(i), "wb")
print_to_err(
type(self).__name__,
"going to start process {} with nccl2".format(i))
tr_proc = subprocess.Popen(
tr_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr_pipe,
env=tr_env)
procs.append(tr_proc)
pipes.append(tr_pipe)
outs = []
for i in range(0, trainer_num):
tr_out, tr_err = procs[i].communicate()
outs.append(tr_out)
pipes[i].close()
sys.stderr.write('trainer {} stderr: {}\n'.format(i, tr_err))
if trainer_num == 1:
if check_error_log: print("outs[0]:", outs[0])
return pickle.loads(outs[0])
else:
if check_error_log:
print("outs[0]:", outs[0])
print("outs[1]:", outs[1])
return pickle.loads(outs[0]), pickle.loads(outs[1])
def _run_cluster_nccl2(self, model, envs, update_method, check_error_log,
log_name):
if self._use_hallreduce:
@@ -1262,7 +1410,12 @@ class TestDistBase(unittest.TestCase):
required_envs = self._get_required_envs(check_error_log, need_envs)
if self._gloo_mode:
local_losses \
= self._run_local_gloo(model_file, required_envs,
check_error_log, log_name=log_name)
else:
local_losses \
= self._run_local(model_file, required_envs,
check_error_log, log_name=log_name)
@@ -1288,6 +1441,14 @@ class TestDistBase(unittest.TestCase):
update_method='bkcl',
check_error_log=check_error_log,
log_name=log_name)
elif self._gloo_mode:
# gloo mode, cpu only parallel train @xiongkun03
tr0_losses, tr1_losses = self._run_cluster_gloo(
model_file,
required_envs,
update_method='gloo',
check_error_log=check_error_log,
log_name=log_name)
elif self._pipeline_mode:
tr0_losses, tr1_losses = self._run_pipeline(
......
@@ -49,6 +49,51 @@ def get_gpus(selected_gpus):
return selected_gpus
def start_local_trainers_cpu(trainer_endpoints,
training_script,
training_script_args,
log_dir=None):
current_env = copy.copy(os.environ.copy())
current_env.pop("http_proxy", None)
current_env.pop("https_proxy", None)
procs = []
n_rank = len(trainer_endpoints)
print(trainer_endpoints)
for rank_id, endpoint in enumerate(trainer_endpoints):
proc_env = {
"PADDLE_DISTRI_BACKEND": "gloo",
"PADDLE_TRAINER_ID": "%d" % rank_id,
"PADDLE_CURRENT_ENDPOINT": "%s" % endpoint,
"PADDLE_TRAINERS_NUM": "%d" % n_rank,
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints)
}
current_env.update(proc_env)
print("trainer proc env:{}".format(current_env))
assert os.getenv('WITH_COVERAGE',
'OFF') == 'OFF', "Gloo don't support WITH_COVERAGE."
cmd = "python -u " + training_script
print("start trainer proc:{} env:{}".format(cmd, proc_env))
fn = None
proc = subprocess.Popen(cmd.split(" "), env=current_env)
tp = TrainerProc()
tp.proc = proc
tp.rank = rank_id
tp.log_fn = fn
tp.cmd = cmd
procs.append(tp)
return procs
def start_local_trainers(cluster,
pod,
training_script,
@@ -116,6 +161,26 @@ class TestMultipleGpus(unittest.TestCase):
training_script=target_file_name,
training_script_args=[])
while True:
alive = watch_local_trainers(procs, cluster.trainers_endpoints())
if not alive:
print("Local procs complete, POD info:{}".format(pod))
break
time.sleep(3)
class TestMultipleWithGloo(unittest.TestCase):
def run_mnist_2cpu(self, target_file_name):
cluster, pod = get_cluster_from_args(
[0, 1]) #tmp use. for getting trainer_nranks()
procs = start_local_trainers_cpu(
cluster.trainers_endpoints(),
training_script=target_file_name,
training_script_args=[])
while True:
alive = watch_local_trainers(procs, cluster.trainers_nranks())
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_sparse_embedding import TestSparseEmbedding
from parallel_dygraph_sparse_embedding_fp64 import TestSparseEmbeddingFP64
flag_name = os.path.splitext(__file__)[0]
class TestParallelDygraphSparseEmdedding_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._gloo_mode = True
self._dygraph = True
def test_sparse_embedding(self):
self.check_with_place(
"parallel_dygraph_sparse_embedding.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphSparseEmdeddingFP64_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._gloo_mode = True
self._dygraph = True
def test_sparse_embedding_fp64(self):
self.check_with_place(
"parallel_dygraph_sparse_embedding_fp64.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_sparse_embedding_over_height import TestSparseEmbeddingOverHeight
flag_name = os.path.splitext(__file__)[0]
class TestParallelDygraphSparseEmdeddingOverHeight_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._gloo_mode = True
self._dygraph = True
def test_sparse_embedding(self):
self.check_with_place(
"parallel_dygraph_sparse_embedding_over_height.py",
delta=1e-7,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_transformer import TestTransformer
flag_name = os.path.splitext(__file__)[0]
class TestParallelDygraphTransformer_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._gloo_mode = True
self._dygraph = True
def test_transformer(self):
self.check_with_place(
"parallel_dygraph_transformer.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphTransformerAccGrad_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._gloo_mode = True
self._dygraph = True
self._accumulate_gradient = True
self._find_unused_parameters = False
def test_transformer(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_transformer.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_unused_variables import TestSparseEmbeddingUnusedVars
flag_name = os.path.splitext(__file__)[0]
class TestParallelDygraphUnusedVar_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._gloo_mode = True
self._dygraph = True
def test_net(self):
self.check_with_place(
"parallel_dygraph_unused_variables.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphNoVar_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._gloo_mode = True
self._dygraph = True
def test_net(self):
self.check_with_place(
"parallel_dygraph_none_var.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphSharedUnusedVariables_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._gloo_mode = True
self._dygraph = True
def test_mnist(self):
self.check_with_place(
"parallel_dygraph_shared_unused_var.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
@@ -24,6 +24,7 @@ from paddle.distributed.spawn import _get_subprocess_env_list, _options_valid_ch
from paddle.fluid import core
from paddle.fluid.dygraph import parallel_helper
import multiprocessing
# NOTE(chenweihang): Coverage CI is currently not able to count python3
# unittest, so the unittests here covers some cases that will only be
@@ -89,8 +90,8 @@ class TestSpawnAssistMethod(unittest.TestCase):
def test_get_default_nprocs(self):
paddle.set_device('cpu')
nprocs = _get_default_nprocs()
self.assertEqual(nprocs, multiprocessing.cpu_count())
paddle.set_device('gpu')
nprocs = _get_default_nprocs()
......