From 9ebf05b003ab910bac2636496ef89d43927b7e60 Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Fri, 5 Mar 2021 11:18:27 +0800 Subject: [PATCH] [Kunlun]Multi xpu dygraph performance optimization , add distributed.spawn support for multi xpu and some bug-fixes (#31130) --- paddle/fluid/imperative/reducer.cc | 96 ++++++++---- paddle/fluid/imperative/reducer.h | 11 +- .../operators/collective/c_comm_init_op.cc | 8 +- .../operators/collective/gen_bkcl_id_op.cc | 2 +- .../pybind/global_value_getter_setter.cc | 7 + paddle/fluid/pybind/pybind.cc | 6 + python/paddle/distributed/cloud_utils.py | 20 +-- .../paddle/distributed/fleet/launch_utils.py | 2 +- python/paddle/distributed/spawn.py | 141 ++++++++++++------ python/paddle/distributed/utils.py | 26 +++- python/paddle/fluid/compiler.py | 4 - .../test_spawn_and_init_parallel_env.py | 10 +- 12 files changed, 224 insertions(+), 109 deletions(-) diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index f8740940d0..5dd7e2d821 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -301,6 +301,10 @@ Reducer::Reducer(const std::vector> &vars, VLOG(3) << "Start construct the Reducer ..."; nrings_ = parallel_ctx->GetNRings(); nranks_ = parallel_ctx->GetNRanks(); +#ifdef PADDLE_WITH_XPU_BKCL + comm_pool_.reset(new ::ThreadPool(1)); + comm_op_count_ = 0; +#endif // initialize groups InitializeGroups(group_indices); for (size_t global_var_index = 0; global_var_index < vars_.size(); @@ -634,6 +638,8 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) { } } +// TODO(liuyuhui): If BKCL support non-blocking communication, it should be +// fixed as same as multi gpus card trainging. void Reducer::MarkGroupReady(size_t group_index) { if (group_index > next_group_) { VLOG(3) << "It will adjust the order of group in next batch automatically"; @@ -651,45 +657,71 @@ void Reducer::MarkGroupReady(size_t group_index) { // so we expose WaitCompute() interface and call // it here. parallel_ctx_->WaitCompute(run_order); - - if (group.is_sparse_) { - if (group.sparse_contents_ != nullptr) { - VLOG(3) << "sparse group [" << next_group_ - << "] start allreduce in ring[" << run_order << "]"; - group.DivNRanks(*parallel_ctx_->GetDeviceContext(run_order), nranks_); - parallel_ctx_->AllReduceByStream( - *group.sparse_contents_, group.sparse_contents_, run_order, false); - } else { - VLOG(3) << "The sparse group[" << next_group_ - << "] has no var to allreduce"; +#ifdef PADDLE_WITH_XPU_BKCL + { + std::lock_guard lock(mutex_); + comm_op_count_ += 1; // lock + } + // TODO(liuyuhui): Add try catch to deal with exception later, + // otherwise the main thread will continue to run when an exception is + // thrown in comm_pool_. + comm_pool_->enqueue([&] { + auto dev_id = BOOST_GET_CONST(platform::XPUPlace, place_).device; + platform::SetXPUDeviceId(dev_id); + FusedAllReduceSchedule(run_order, group); + { + std::lock_guard lock(mutex_); + comm_op_count_ -= 1; // lock + cv_.notify_all(); } - } else { - VLOG(3) << "dense group [" << next_group_ << "] start allreduce in ring[" + }); +#elif defined(PADDLE_WITH_NCCL) + FusedAllReduceSchedule(run_order, group); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Not compiled with BKCL or NCCL.")); +#endif + } +} + +void Reducer::FusedAllReduceSchedule(int run_order, Group &group) { + if (group.is_sparse_) { + if (group.sparse_contents_ != nullptr) { + VLOG(3) << "sparse group [" << next_group_ << "] start allreduce in ring[" << run_order << "]"; - // Select common commstream to concat tensors - // group.dense_tensors ---> group.dense_contents_ - group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order)); + group.DivNRanks(*parallel_ctx_->GetDeviceContext(run_order), nranks_); + parallel_ctx_->AllReduceByStream( + *group.sparse_contents_, group.sparse_contents_, run_order, false); + } else { + VLOG(3) << "The sparse group[" << next_group_ + << "] has no var to allreduce"; + } + } else { + VLOG(3) << "dense group [" << next_group_ << "] start allreduce in ring[" + << run_order << "]"; + // Select common commstream to concat tensors + // group.dense_tensors ---> group.dense_contents_ + group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order)); // NOTE(liuyuhui): ConcatTensors use communication stream, but BKCL only support // default stream for communicating, so there exist some problems in // synchronization. And need to add a WaitComm there. -// TODO(liuyuhui): If BKCL support events, it should be fixed as non-blocking -// communication. +// TODO(liuyuhui): If BKCL support non-blocking communication, it should be +// fixed as multi gpus card trainging. #ifdef PADDLE_WITH_XPU_BKCL - if (platform::is_xpu_place(group.dense_tensors_[0].place())) { - parallel_ctx_->WaitComm(run_order); - } + if (platform::is_xpu_place(group.dense_tensors_[0].place())) { + parallel_ctx_->WaitComm(run_order); + } #endif - group.DivNRanks(*parallel_ctx_->GetDeviceContext(run_order), nranks_); + group.DivNRanks(*parallel_ctx_->GetDeviceContext(run_order), nranks_); - // Start allreduce - parallel_ctx_->AllReduceByStream( - group.dense_contents_, &(group.dense_contents_), run_order, false); + // Start allreduce + parallel_ctx_->AllReduceByStream( + group.dense_contents_, &(group.dense_contents_), run_order, false); - // Select common commstream to split tensors - // group.dense_contents_ ---> group.dense_tensors - group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order)); - } + // Select common commstream to split tensors + // group.dense_contents_ ---> group.dense_tensors + group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order)); } } @@ -717,6 +749,12 @@ std::vector> Reducer::RebuildGruops() { void Reducer::FinalizeBackward() { all_group_ready_ = false; +#ifdef PADDLE_WITH_XPU_BKCL + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return comm_op_count_ == 0; }); + } +#endif // Must prevent compute_stream_ starting until all comm streams have finished for (int i = 0; i < nrings_; ++i) { parallel_ctx_->WaitComm(i); diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h index f352ad17fd..b2680d0dea 100644 --- a/paddle/fluid/imperative/reducer.h +++ b/paddle/fluid/imperative/reducer.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once - +#include #include #include #include @@ -153,6 +153,8 @@ class Reducer { void MarkGroupReady(size_t group_index); + void FusedAllReduceSchedule(int run_order, Group& group); // NOLINT + void FinalizeBackward(); std::vector> RebuildGruops(); @@ -187,6 +189,13 @@ class Reducer { bool has_marked_unused_vars_{false}; bool find_unused_vars_{false}; bool all_group_ready_{false}; +#ifdef PADDLE_WITH_XPU_BKCL + // comm_pool_ is used for scheduling allreduce in multi Kunlun cards training. + std::unique_ptr<::ThreadPool> comm_pool_{nullptr}; + uint32_t comm_op_count_; + std::mutex mutex_; + std::condition_variable cv_; +#endif }; std::vector> AssignGroupBySize( diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc index 3464bff486..f451086167 100644 --- a/paddle/fluid/operators/collective/c_comm_init_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_op.cc @@ -68,10 +68,10 @@ class CCommInitOp : public framework::OperatorBase { nccl_id, nranks, rank_id, device_id, rid); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should compile with GPU.")); + "PaddlePaddle should be compiled with GPU.")); #endif } else if (is_xpu_place(place)) { -#if defined(PADDLE_WITH_BKCL) +#if defined(PADDLE_WITH_XPU_BKCL) BKCLUniqueId* bkcl_id = var->GetMutable(); int nranks = Attr("nranks"); @@ -81,7 +81,7 @@ class CCommInitOp : public framework::OperatorBase { rid, 0, platform::errors::OutOfRange( "Ring id must equal 0 in multi Kunlun cards training, but got %d", - ring_id)); + rid)); int device_id = BOOST_GET_CONST(platform::XPUPlace, place).device; if (Attr("device_id") >= 0) { device_id = Attr("device_id"); @@ -90,7 +90,7 @@ class CCommInitOp : public framework::OperatorBase { bkcl_id, nranks, rank_id, device_id, rid); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should compile with XPU.")); + "PaddlePaddle should be compiled with XPU.")); #endif } else { PADDLE_THROW(platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/operators/collective/gen_bkcl_id_op.cc b/paddle/fluid/operators/collective/gen_bkcl_id_op.cc index f14271e367..7067bfb314 100644 --- a/paddle/fluid/operators/collective/gen_bkcl_id_op.cc +++ b/paddle/fluid/operators/collective/gen_bkcl_id_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc index 1732cf5bfd..6074d191ad 100644 --- a/paddle/fluid/pybind/global_value_getter_setter.cc +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -87,6 +87,10 @@ DECLARE_uint64(reallocate_gpu_memory_in_mb); // others DECLARE_bool(sync_nccl_allreduce); #endif +#ifdef PADDLE_WITH_XPU +// device management +DECLARE_string(selected_xpus); +#endif #ifdef PADDLE_WITH_DISTRIBUTE DECLARE_int32(rpc_send_thread_num); DECLARE_int32(rpc_get_thread_num); @@ -365,6 +369,9 @@ static void RegisterGlobalVarGetterSetter() { FLAGS_reallocate_gpu_memory_in_mb, FLAGS_enable_cublas_tensor_op_math, FLAGS_selected_gpus, FLAGS_sync_nccl_allreduce); #endif +#ifdef PADDLE_WITH_XPU + REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_selected_xpus); +#endif #ifdef PADDLE_WITH_DITRIBUTE REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_send_thread_num, FLAGS_rpc_get_thread_num, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 2e5cd3473c..c8ca3bf2c8 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2497,6 +2497,12 @@ All parameter, weight, gradient are variables in Paddle. [](BuildStrategy &self, int nccl_comm_num) { self.nccl_comm_num_ = nccl_comm_num; }) + .def_property( + "bkcl_comm_num", + [](const BuildStrategy &self) { return self.bkcl_comm_num_; }, + [](BuildStrategy &self, int bkcl_comm_num) { + self.bkcl_comm_num_ = bkcl_comm_num; + }) .def_property("use_hierarchical_allreduce", [](const BuildStrategy &self) { return self.use_hierarchical_allreduce_; diff --git a/python/paddle/distributed/cloud_utils.py b/python/paddle/distributed/cloud_utils.py index ae603a0e60..962ba62b15 100644 --- a/python/paddle/distributed/cloud_utils.py +++ b/python/paddle/distributed/cloud_utils.py @@ -17,9 +17,9 @@ import paddle from paddle.distributed.utils import get_cluster, logger, get_gpus, get_cluster_from_args -def get_cloud_cluster(args_node_ips, args_node_ip, args_port, selected_gpus): +def get_cloud_cluster(args_node_ips, args_node_ip, args_port, selected_devices): """ - args_node_ips:string, args_node_ip:string, args_port: int, selected_gpus:list + args_node_ips:string, args_node_ip:string, args_port: int, selected_devices:list """ #you can automatically get ip info while using paddlecloud multi nodes mode. node_ips = os.getenv("PADDLE_TRAINERS") @@ -60,7 +60,7 @@ paddlecloud environment.".format(args_node_ips, node_ips)) paddle_port = int(os.getenv("PADDLE_PORT", "")) if paddle_ports_num >= len( - selected_gpus) and paddle_port != args_port: + selected_devices) and paddle_port != args_port: logger.warning("Use Cloud specified port:{}.".format( paddle_port)) started_port = paddle_port @@ -72,7 +72,7 @@ paddlecloud environment.".format(args_node_ips, node_ips)) if started_port is None: started_port = 6170 ports = [ - x for x in range(started_port, started_port + len(selected_gpus)) + x for x in range(started_port, started_port + len(selected_devices)) ] trainer_endpoints = [] for ip in node_ips: @@ -90,7 +90,7 @@ paddlecloud environment.".format(args_node_ips, node_ips)) .format(node_ips, node_ip, node_rank, trainer_endpoints)) cluster, pod = get_cluster(node_ips, node_ip, trainer_endpoints, - selected_gpus) + selected_devices) return cluster, cluster.pods[node_rank] @@ -100,20 +100,20 @@ def _get_trainers_num(): def get_cluster_and_pod(args): # parse arguments, used for cloud-single-machine and local - selected_gpus = get_gpus(args.selected_gpus) + selected_devices = get_gpus(args.selected_devices) trainers_num = _get_trainers_num() - logger.debug("parsed from args trainerss_num:{} selected_gpus:{}".format( - trainers_num, selected_gpus)) + logger.debug("parsed from args trainerss_num:{} selected_devices:{}".format( + trainers_num, selected_devices)) cluster = None pod = None if args.use_paddlecloud and trainers_num != 1: cluster, pod = get_cloud_cluster(args.cluster_node_ips, args.node_ip, - args.started_port, selected_gpus) + args.started_port, selected_devices) logger.info("get cluster from cloud:{}".format(cluster)) else: - cluster, pod = get_cluster_from_args(args, selected_gpus) + cluster, pod = get_cluster_from_args(args, selected_devices) logger.info("get cluster from args:{}".format(cluster)) return cluster, pod diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index b4f1f93149..c5cb1ec94a 100644 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -280,7 +280,7 @@ def get_cluster(node_ips, node_ip, trainer_endpoints, device_mode, if isinstance(devices_per_proc[i], (list, tuple)): trainer.gpus.extend(devices_per_proc[i]) else: - trainer.gpus.extend(devices_per_proc[i]) + trainer.gpus.append(devices_per_proc[i]) trainer.endpoint = "%s" % (cur_node_endpoints[i]) trainer.rank = trainer_rank trainer_rank += 1 diff --git a/python/paddle/distributed/spawn.py b/python/paddle/distributed/spawn.py index 911fed416c..56e59ac88e 100644 --- a/python/paddle/distributed/spawn.py +++ b/python/paddle/distributed/spawn.py @@ -50,10 +50,10 @@ class ParallelEnvArgs(object): self.print_config = True # It's for gpu training and the training process will run - # on the selected_gpus, each process is bound to a single GPU. + # on the selected_devices, each process is bound to a single GPU. # And if it's not set, this module will use all the gpu cards # for training. - self.selected_gpus = None + self.selected_devices = None def _py_supported_check(): @@ -67,9 +67,9 @@ def _py_supported_check(): def _options_valid_check(options): # `print_config` keeped as a debug options, not show to users - supported_options = ['start_method', 'ips', 'gpus', 'print_config'] + supported_options = ['start_method', 'ips', 'gpus', 'xpus', 'print_config'] deprecated_options = [ - 'selected_gpus', 'started_port', 'cluster_node_ips', 'node_ip', + 'selected_devices', 'started_port', 'cluster_node_ips', 'node_ip', 'use_paddlecloud' ] for key in options: @@ -109,47 +109,83 @@ def _get_subprocess_env_list(nprocs, options): if args.cluster_node_ips is None: args.cluster_node_ips = "127.0.0.1" - # deal with `gpus` - # set default selected gpus + # deal with `gpus` or `xpus` + # set default selected devices(gpus or xpus) # e.g. if the nprocs is 4, the selected gpus is "0,1,2,3" - # NOTE(chenweihang): [ why not use FLAGS_selected_gpus directly? ] - # because the FLAGS_selected_gpus may be used in other place, - # if we set FLAGS_selected_gpus to be `0,1,2,3`, it may cause error + # NOTE(chenweihang): [ why not use FLAGS_selected_gpus or FLAGS_selected_xpus directly? ] + # because the FLAGS_selected_gpus or FLAGS_selected_xpus may be used in other place, + # if we set FLAGS_selected_gpus or FLAGS_selected_xpus to be `0,1,2,3`, it may cause error # when using `ParallelEnv` - # NOTE(chenweihang): use absolute gpu card id - args.selected_gpus = options.get('gpus', None) - if args.selected_gpus is None: - args.selected_gpus = options.get('selected_gpus', None) - env_devices = os.getenv("CUDA_VISIBLE_DEVICES", None) - if env_devices is None or env_devices == "": - env_devices_list = [ - str(x) for x in six.moves.range(core.get_cuda_device_count()) - ] - else: - env_devices_list = env_devices.split(',') - if args.selected_gpus is None: - if len(env_devices_list) < nprocs: - raise RuntimeError( - "the number of visible devices(%d) is less than the number " - "of spawn processes(%d), please ensure that the correct " - "`nprocs` argument is passed or the environment variable " - "`CUDA_VISIBLE_DEVICES` is correctly configured." % - (len(env_devices_list), nprocs)) - args.selected_gpus = ",".join( - [str(env_devices_list[x]) for x in range(0, nprocs)]) - else: - selected_gpu_list = args.selected_gpus.split(',') - if len(selected_gpu_list) != nprocs: - raise ValueError( - "The number of selected gpus(%s) is not equal to " - "the number of spawn processes(%d), please ensure that the " - "correct `nprocs` and `gpus` arguments are passed." % - (len(selected_gpu_list), nprocs)) - for card_id in selected_gpu_list: - if card_id not in env_devices_list: - raise ValueError("The selected gpu card %s cannot found in " - "CUDA_VISIBLE_DEVICES (%s)." % - (card_id, ",".join(env_devices_list))) + # NOTE(chenweihang): use absolute gpu or xpu card id + if core.is_compiled_with_cuda(): + args.selected_devices = options.get('gpus', None) + if args.selected_devices is None: + args.selected_devices = options.get('selected_devices', None) + env_devices = os.getenv("CUDA_VISIBLE_DEVICES", None) + if env_devices is None or env_devices == "": + env_devices_list = [ + str(x) for x in six.moves.range(core.get_cuda_device_count()) + ] + else: + env_devices_list = env_devices.split(',') + if args.selected_devices is None: + if len(env_devices_list) < nprocs: + raise RuntimeError( + "the number of visible devices(%d) is less than the number " + "of spawn processes(%d), please ensure that the correct " + "`nprocs` argument is passed or the environment variable " + "`CUDA_VISIBLE_DEVICES` is correctly configured." % + (len(env_devices_list), nprocs)) + args.selected_devices = ",".join( + [str(env_devices_list[x]) for x in range(0, nprocs)]) + else: + selected_device_list = args.selected_devices.split(',') + if len(selected_device_list) != nprocs: + raise ValueError( + "The number of selected devices(%s) is not equal to " + "the number of spawn processes(%d), please ensure that the " + "correct `nprocs` and `gpus` arguments are passed." % + (len(selected_device_list), nprocs)) + for card_id in selected_device_list: + if card_id not in env_devices_list: + raise ValueError("The selected gpu card %s cannot found in " + "CUDA_VISIBLE_DEVICES (%s)." % + (card_id, ",".join(env_devices_list))) + + elif core.is_compiled_with_xpu(): + args.selected_devices = options.get('xpus', None) + if args.selected_devices is None: + args.selected_devices = options.get('selected_devices', None) + env_devices = os.getenv("XPU_VISIBLE_DEVICES", None) + if env_devices is None or env_devices == "": + env_devices_list = [ + str(x) for x in six.moves.range(core.get_xpu_device_count()) + ] + else: + env_devices_list = env_devices.split(',') + if args.selected_devices is None: + if len(env_devices_list) < nprocs: + raise RuntimeError( + "the number of visible devices(%d) is less than the number " + "of spawn processes(%d), please ensure that the correct " + "`nprocs` argument is passed or the environment variable " + "`XPU_VISIBLE_DEVICES` is correctly configured." % + (len(env_devices_list), nprocs)) + args.selected_devices = ",".join( + [str(env_devices_list[x]) for x in range(0, nprocs)]) + else: + selected_device_list = args.selected_devices.split(',') + if len(selected_device_list) != nprocs: + raise ValueError( + "The number of selected devices(%s) is not equal to " + "the number of spawn processes(%d), please ensure that the " + "correct `nprocs` and `xpus` arguments are passed." % + (len(selected_device_list), nprocs)) + for card_id in selected_device_list: + if card_id not in env_devices_list: + raise ValueError("The selected xpu card %s cannot found in " + "XPU_VISIBLE_DEVICES (%s)." % + (card_id, ",".join(env_devices_list))) # set other inner args args.node_ip = options.get('node_ip', None) @@ -185,12 +221,17 @@ def _remove_risky_env(): def _set_trainer_env(env_dict): - # NOTE(chenweihang): [ Why need set FLAGS_selected_gpus here? ] + # NOTE(chenweihang): [ Why need set FLAGS_selected_gpus or FLAGS_selected_xpus here? ] # When the child process starts, it will inherit the configuration of the # main process and set the FLAGS once, but the environment variable has - # not been set at this time, which leads to the FLAGS_selected_gpus + # not been set at this time, which leads to the FLAGS_selected_gpus or FLAGS_selected_xpus # is keep same with mainprocess(usually empty), so manually update the flags here - set_flags({'FLAGS_selected_gpus': env_dict['FLAGS_selected_gpus']}) + if core.is_compiled_with_cuda(): + set_flags({'FLAGS_selected_gpus': env_dict['FLAGS_selected_gpus']}) + elif core.is_compiled_with_xpu(): + set_flags({'FLAGS_selected_xpus': env_dict['FLAGS_selected_xpus']}) + else: + raise ValueError("PaddlePaddle should be compiled with XPU or CUDA.") for var_name in env_dict: os.environ[var_name] = env_dict[var_name] @@ -407,8 +448,14 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options): if device == 'cpu': # TODO: not supports cpu parallel now nprocs = _cpu_num() - else: + elif device == 'gpu': nprocs = core.get_cuda_device_count() + elif device == 'xpu': + nprocs = core.get_xpu_device_count() + else: + raise ValueError( + "`device` should be a string of `cpu`, 'gpu' or 'xpu', but got {}". + format(device)) # NOTE(chenweihang): [ why need get cluster info before run? ] # when using `paddle.distributed.spawn` start parallel training, diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py index 54efce052e..f40a7b31b8 100644 --- a/python/paddle/distributed/utils.py +++ b/python/paddle/distributed/utils.py @@ -24,6 +24,7 @@ import six import subprocess from contextlib import closing import socket +from paddle.fluid import core logger = logging.getLogger("root") logger.propagate = False @@ -401,13 +402,24 @@ def find_free_ports(num): def _prepare_trainer_env(cluster, trainer): - proc_env = { - "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in trainer.gpus]), - "PADDLE_TRAINER_ID": "%d" % trainer.rank, - "PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint, - "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), - "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) - } + if core.is_compiled_with_xpu(): + proc_env = { + "FLAGS_selected_xpus": + "%s" % ",".join([str(g) for g in trainer.gpus]), + "PADDLE_TRAINER_ID": "%d" % trainer.rank, + "PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint, + "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) + } + elif core.is_compiled_with_cuda(): + proc_env = { + "FLAGS_selected_gpus": + "%s" % ",".join([str(g) for g in trainer.gpus]), + "PADDLE_TRAINER_ID": "%d" % trainer.rank, + "PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint, + "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) + } return proc_env diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index a04d58ff25..2698f1a00d 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -360,10 +360,6 @@ class CompiledProgram(object): else: self._exec_strategy.num_threads = len(places) * 2 - if self._exec_strategy._use_device == DeviceType.XPU: - assert self._exec_strategy.num_threads == 1, \ - "Currently only single thread is supported in Kunlun XPU." - if self._build_strategy.num_trainers > 1: assert self._is_data_parallel, \ "If you use multi-trainer to train the model, you should use "\ diff --git a/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py b/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py index 53efa186d1..6efab81a26 100644 --- a/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py +++ b/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py @@ -59,10 +59,10 @@ class TestSpawnAssistMethod(unittest.TestCase): with self.assertRaises(RuntimeError): _get_subprocess_env_list(nprocs=100, options=dict()) - def test_selected_gpus_error(self): + def test_selected_devices_error(self): with self.assertRaises(ValueError): options = dict() - options['selected_gpus'] = "100,101" + options['selected_devices'] = "100,101" _get_subprocess_env_list(nprocs=2, options=options) def test_get_correct_env(self): @@ -72,15 +72,15 @@ class TestSpawnAssistMethod(unittest.TestCase): self.assertEqual(env_dict['PADDLE_TRAINER_ID'], '0') self.assertEqual(env_dict['PADDLE_TRAINERS_NUM'], '1') - def test_nprocs_not_equal_to_selected_gpus(self): + def test_nprocs_not_equal_to_selected_devices(self): with self.assertRaises(ValueError): options = dict() - options['selected_gpus'] = "100,101,102" + options['selected_devices'] = "100,101,102" _get_subprocess_env_list(nprocs=2, options=options) def test_options_valid_check(self): options = dict() - options['selected_gpus'] = "100,101,102" + options['selected_devices'] = "100,101,102" _options_valid_check(options) with self.assertRaises(ValueError): -- GitLab