diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt
index 93b18e7e553b5e1d80fdd70dc9c6df02e04d0adb..71e16fc1651593aed54515c8aaa490f652006c34 100644
--- a/paddle/fluid/imperative/CMakeLists.txt
+++ b/paddle/fluid/imperative/CMakeLists.txt
@@ -29,6 +29,12 @@ if(NOT WIN32)
   endif()
   cc_library(data_loader SRCS data_loader.cc DEPS enforce)
 endif(NOT WIN32)
+if(WITH_GLOO)
+  cc_library(imperative_gloo_context SRCS gloo_context.cc DEPS collective_helper device_context tensor var_type_traits)
+  if(WIN32 OR (NOT (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL)))
+    cc_library(reducer SRCS reducer.cc DEPS layer)
+  endif()
+endif()
 
 cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function)
diff --git a/paddle/fluid/imperative/gloo_context.cc b/paddle/fluid/imperative/gloo_context.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d7df6ec3c116416e81975ec556e5d73b11313e09
--- /dev/null
+++ b/paddle/fluid/imperative/gloo_context.cc
@@ -0,0 +1,109 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/imperative/gloo_context.h"
+#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/string/split.h"
+
+namespace paddle {
+namespace framework {
+class Variable;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace imperative {
+
+void GLOOParallelContext::Init() {
+  // PADDLE_THROW(platform::errors::OutOfRange(
+  //   "Still not implement Init"));
+  VLOG(4) << "Start GLOOParallelContext initialization";
+  auto gloo_wrapper = framework::GlooWrapper::GetInstance();
+  gloo_wrapper->SetSize(strategy_.nranks_);
+  gloo_wrapper->SetRank(strategy_.local_rank_);
+  gloo_wrapper->SetPrefix("");
+  gloo_wrapper->SetIface("lo");
+  auto addr = paddle::string::Split(strategy_.trainer_endpoints_[0], ':');
+  VLOG(4) << "Server is " << strategy_.trainer_endpoints_[0];
+  std::string host = addr[0];
+  int port = std::stoi(addr[1]);
+  gloo_wrapper->SetHttpStore(host, port, "worker");
+  gloo_wrapper->Init();
+  device_ = std::unique_ptr<platform::CPUDeviceContext>(
+      new platform::CPUDeviceContext(platform::CPUPlace()));
+}
+
+void GLOOParallelContext::InitWithRingID(int ring_id) {
+  PADDLE_THROW(
+      platform::errors::OutOfRange("InitWithRingID is not implemented yet"));
+}
+
+#define GLOO_CASE(type, T, gw)                                         \
+  case type: {                                                         \
+    VLOG(4) << "Use the gloo all reduce to sync. SRC:" << src_tensor;  \
+    std::vector<T> send_vector##T;                                     \
+    framework::TensorToVector<T>(src_tensor, &send_vector##T);         \
+    auto recv_vector##T = gw->AllReduce<T>(send_vector##T);            \
+    framework::TensorFromVector<T>(recv_vector##T, dst_tensor);        \
+    VLOG(4) << "DST:" << *dst_tensor;                                  \
+    break;                                                             \
+  }
+
+void GLOOParallelContext::AllReduceByStream(const framework::Variable &src,
+                                            framework::Variable *dst,
+                                            int ring_id,
+                                            bool use_calc_stream) {
+  // AllReduce(src, dst, strategy_, ring_id, use_calc_stream);
+  auto src_tensor = src.Get<framework::LoDTensor>();
+  auto *dst_tensor = dst->GetMutable<framework::LoDTensor>();
+  auto gloo_wrapper = framework::GlooWrapper::GetInstance();
+  dst_tensor->Resize(src_tensor.dims());
+  switch (src_tensor.type()) {
+    GLOO_CASE(framework::proto::VarType::FP32, float, gloo_wrapper);
+    GLOO_CASE(framework::proto::VarType::FP64, double, gloo_wrapper);
+    GLOO_CASE(framework::proto::VarType::INT32, int, gloo_wrapper);
+    GLOO_CASE(framework::proto::VarType::INT64, int64_t, gloo_wrapper);
+    default: {
+      PADDLE_THROW(
+          platform::errors::InvalidArgument("Invalid datatype for allreduce"));
+    }
+  }
+  gloo_wrapper->Barrier();
+}
+
+paddle::platform::DeviceContext *GLOOParallelContext::GetDeviceContext(
+    int ring_id) {
+  // return the CPUDeviceContext
+  return device_.get();
+}
+
+void GLOOParallelContext::WaitCompute(int ring_id) {
+  // do nothing because the CPU backend does not need stream synchronization
+  return;
+}
+
+void GLOOParallelContext::WaitComm(int ring_id) {
+  // do nothing because the CPU backend does not need stream synchronization
+  return;
+}
+
+void GLOOParallelContext::SynchronizeCompute() {
+  // do nothing because the CPU backend does not need stream synchronization
+  return;
+}
+
+}  // namespace imperative
+}  // namespace paddle
diff --git a/paddle/fluid/imperative/gloo_context.h b/paddle/fluid/imperative/gloo_context.h
new file mode 100644
index 0000000000000000000000000000000000000000..f54dc1a406a92f643e02699f545ecd20b27ebc20
--- /dev/null
+++ b/paddle/fluid/imperative/gloo_context.h
@@ -0,0 +1,60 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/fluid/imperative/parallel_context.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace framework {
+class Variable;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace imperative {
+
+class GLOOParallelContext : public ParallelContext {
+ public:
+  explicit GLOOParallelContext(const ParallelStrategy& strategy,
+                               const platform::Place& place)
+      : ParallelContext(strategy, place) {}
+
+  ~GLOOParallelContext() override = default;
+
+  void Init() override;
+
+  void InitWithRingID(int ring_id) override;
+
+  void AllReduceByStream(const framework::Variable& src,
+                         framework::Variable* dst, int ring_id,
+                         bool use_calc_stream) override;
+
+  paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override;
+
+  void WaitCompute(int ring_id) override;
+
+  void WaitComm(int ring_id) override;
+
+  void SynchronizeCompute() override;
+
+ private:
+  std::unique_ptr<platform::CPUDeviceContext> device_;
+};
+
+}  // namespace imperative
+}  // namespace paddle
diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc
index 4166edf32573ca0394680f408a0f11e17d2ca372..5c426bc677e1fdba3b34c04cf6b4e390f66c688a 100644
--- a/paddle/fluid/imperative/reducer.cc
+++ b/paddle/fluid/imperative/reducer.cc
@@ -28,7 +28,7 @@ namespace paddle {
 namespace imperative {
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO)
 // div the nranks
 void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
   framework::Tensor *tensor =
@@ -42,9 +42,12 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
     DivNRanks(tensor, nranks, context);
 #endif
   } else if (platform::is_cpu_place(tensor->place())) {
+    VLOG(4) << "before div 2" << *tensor;
+    VLOG(4) << "NDiv for cpu devices : rank = " << nranks;
     framework::VisitDataTypeSmall(
         dtype_, DivNRanksForAllReduce<platform::CPUDeviceContext>(
                     tensor, nranks, context));
+    VLOG(4) << "after div 2" << *tensor;
   } else if (platform::is_xpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_XPU_BKCL
     // TODO(liuyuhui) support xpu about div nranks in the future
@@ -764,8 +767,8 @@ void Reducer::MarkGroupReady(size_t group_index) {
 
   for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
        ++next_group_) {
-    auto &group = groups_[next_group_];
-    const int run_order = next_group_ % nrings_;
+    UNUSED auto &group = groups_[next_group_];
+    UNUSED const int run_order = next_group_ % nrings_;
 
     // For CUDA or XPU, compute_stream --> comm_stream.
     // For CPU, do nothing.
@@ -792,11 +795,12 @@ void Reducer::MarkGroupReady(size_t group_index) {
           cv_.notify_all();
         }
       });
-#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
+#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) || \
+    defined(PADDLE_WITH_GLOO)
     FusedAllReduceSchedule(run_order, group, next_group_);
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
-        "Not compiled with BKCL or NCCL."));
+        "Not compiled with BKCL or NCCL or GLOO."));
 #endif
   }
 }
@@ -974,7 +978,8 @@ void Reducer::FinalizeBackward() {
   if (find_unused_vars_each_step_) {
     // TODO(liuyuhui) support xpu about Tensorcopy/TensorFromVector/TensorToVector
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_GLOO)
     ProcessUnusedDenseVars();
 #endif
     // Initialize local used vars
diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h
index 3cc40f7b1306a13c641c8a71b0e284ac203a5fc8..b5a7dd149f09fefeac7c8a80e6d541534573a3bf 100644
--- a/paddle/fluid/imperative/reducer.h
+++ b/paddle/fluid/imperative/reducer.h
@@ -49,7 +49,7 @@ namespace paddle {
 namespace imperative {
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO)
 
 template <typename T>
 struct DivNRanksFunctor {
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index ca0ed68a13f2fa1b9f0233ccad00f51c9e9d4592..a4ad9333163378c3575d5b73d9e4a615bed0051c 100644
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -69,6 +69,8 @@ endif()
 if(WITH_GLOO)
   set(PYBIND_DEPS ${PYBIND_DEPS} gloo_context)
   set(PYBIND_SRCS ${PYBIND_SRCS} gloo_context_py.cc)
+  set(PYBIND_DEPS ${PYBIND_DEPS} imperative_gloo_context)
+  set(PYBIND_DEPS ${PYBIND_DEPS} reducer)
 endif(WITH_GLOO)
 
 if (WITH_CRYPTO)
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 67f9f8b203defd7396caa9e348cb4e63f9482128..422d6de6f33b85dbcbe99b648b728749e6184a39 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -35,6 +35,7 @@ limitations under the License. */
 #include "paddle/fluid/imperative/basic_engine.h"
 #include "paddle/fluid/imperative/bkcl_context.h"
 #include "paddle/fluid/imperative/data_loader.h"
+#include "paddle/fluid/imperative/gloo_context.h"
 #include "paddle/fluid/imperative/hooks.h"
 #include "paddle/fluid/imperative/layer.h"
 #include "paddle/fluid/imperative/nccl_context.h"
@@ -2017,7 +2018,7 @@ void BindImperative(py::module *m_ptr) {
           py::call_guard<py::gil_scoped_release>());
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO)
   py::class_<imperative::ParallelContext,
              std::shared_ptr<imperative::ParallelContext>>(m,
                                                            "ParallelContext");
@@ -2062,6 +2063,20 @@ void BindImperative(py::module *m_ptr) {
            &imperative::BKCLParallelContext::InitWithRingID,
            py::arg("ring_id"));
 #endif
+
+#if defined(PADDLE_WITH_GLOO)
+  // xiongkun
+  py::class_<imperative::GLOOParallelContext, imperative::ParallelContext,
+             std::shared_ptr<imperative::GLOOParallelContext>>(
+      m, "GLOOParallelContext")
+      .def(py::init<const imperative::ParallelStrategy &,
+                    const platform::CPUPlace &>())
+      .def("init", [](imperative::GLOOParallelContext &self) { self.Init(); })
+      .def("init_with_ring_id",
+           &imperative::GLOOParallelContext::InitWithRingID,
+           py::arg("ring_id"));
+#endif
+
   m.def("pylayer_apply",
         [](const platform::CPUPlace &place, const py::object &cls,
            const py::args args, const py::kwargs kwargs) {
diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py
index 1a4fe21afa577fa2e6a1f82528f45acdd977d6b7..7789b17429c4eb23a10ab947d47b9d49966272ad 100644
--- a/python/paddle/distributed/parallel.py
+++ b/python/paddle/distributed/parallel.py
@@ -55,19 +55,51 @@ def _start_kv_server(port, http_server_d, size):
     http_server.stop()
 
 
-def init_parallel_env():
+def _check_backend(backend):
+    if backend not in ['nccl', 'gloo', 'bkcl', 'auto']:
+        raise ValueError(
+            "paddle.distributed initialize error, "
+            "backend argument can only be one of 'nccl', 'gloo', 'bkcl', 'auto', but got %s"
+            % backend)
+
+    if backend == 'nccl' and not core.is_compiled_with_cuda():
+        raise ValueError(
+            "paddle.distributed initialize error, "
+            "your paddle is not compiled with cuda but you assign 'nccl' as backend."
+        )
+
+    if backend == 'bkcl' and not core.is_compiled_with_xpu():
+        raise ValueError(
+            "paddle.distributed initialize error, "
+            "your paddle is not compiled with xpu but you assign 'bkcl' as backend."
+        )
+
+    if backend in ['auto', 'nccl', 'bkcl'] and (core.is_compiled_with_cuda() or
+                                                core.is_compiled_with_xpu()):
+        # the backend allows cuda/xpu and one of them is available, so fall
+        # back to the default (GPU/XPU) logic and return False
+        return False
+    else:
+        return True
+
+
+def init_parallel_env(backend='auto'):
     """
     Initialize parallel training environment in dynamic graph mode.
 
     .. note::
         Now initialize both `NCCL` and `GLOO` contexts for communication.
 
+    Args:
+        backend (string): The backend used by DataParallel. It should be one of
+            'gloo' (for cpu), 'nccl' (for cuda), 'bkcl' (for xpu) or 'auto' (auto detect).
+            Auto detection prefers 'nccl' and 'bkcl' over 'gloo'.
+
     Returns:
         None
 
     Examples:
         .. code-block:: python
-
+            # required: gpu
             import paddle
             import paddle.nn as nn
             import paddle.optimizer as opt
@@ -120,13 +152,14 @@ def init_parallel_env():
             "Currently not a parallel execution environment, `paddle.distributed.init_parallel_env` will not do anything."
         )
         return
-
-    # 1. gpu xpu check, must be gpu or xpu
-    if not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu():
+    # NOTE(xiongkun): CPU-only gloo parallel training is selected through the
+    # backend argument; _check_backend returns True for the cpu-only case.
+    is_cpu_only = _check_backend(backend)
+    # 1. gpu/xpu check: must be gpu or xpu unless running in cpu-only (gloo) mode
+    if not (is_cpu_only or core.is_compiled_with_cuda() or
+            core.is_compiled_with_xpu()):
         raise NotImplementedError(
-            "Cannot initialize parallel environment in CPU-only version, now only "
-            "supports initializing the GPU and XPU parallel environment. Please recompile "
-            "or reinstall paddle with GPU or XPU support.")
+            "If you want to use CPU-only version, please use 'gloo' as backend")
 
     # 2. check env
     def _check_var_exists(var_name):
@@ -136,9 +169,9 @@ def init_parallel_env():
                 "environment variable %s is needed, but not set." % var_name)
 
-    if core.is_compiled_with_cuda():
+    if not is_cpu_only and core.is_compiled_with_cuda():
         _check_var_exists("FLAGS_selected_gpus")
-    elif core.is_compiled_with_xpu():
+    elif not is_cpu_only and core.is_compiled_with_xpu():
         _check_var_exists('FLAGS_selected_xpus')
 
     _check_var_exists("PADDLE_TRAINER_ID")
@@ -148,7 +181,7 @@ def init_parallel_env():
 
     # 3: init gloo context (step 1: httpsever start)
     init_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0"))
-    if init_gloo:
+    if is_cpu_only or init_gloo:
         ep_rank_0 = parallel_env.trainer_endpoints[0].split(":")
         manager = Manager()
         # glboal dict to store status
@@ -180,14 +213,19 @@ def init_parallel_env():
     # directly, if they want to switch default place,
     # they need to call a function to change default place,
     # here just set correctly place to users
-    if core.is_compiled_with_cuda():
+    if is_cpu_only:
+        place = core.CPUPlace()
+    elif core.is_compiled_with_cuda():
         place = core.CUDAPlace(parallel_env.device_id)
     elif core.is_compiled_with_xpu():
         place = core.XPUPlace(parallel_env.device_id)
-    _set_expected_place(place)
+    _set_expected_place(place)
 
     # init nccl or bkcl context
-    if core.is_compiled_with_cuda():
+    if is_cpu_only:
+        parallel_helper._set_parallel_ctx(
+            core.GLOOParallelContext(strategy, place))
+    elif core.is_compiled_with_cuda():
         parallel_helper._set_parallel_ctx(
             core.NCCLParallelContext(strategy, place))
     elif core.is_compiled_with_xpu():
@@ -196,18 +234,22 @@ def init_parallel_env():
     other_endpoints = strategy.trainer_endpoints[:]
     other_endpoints.remove(strategy.current_endpoint)
-    if strategy.local_rank == 0:
+    if not is_cpu_only and strategy.local_rank == 0:
         wait_server_ready(other_endpoints)
 
     parallel_helper._init_parallel_ctx()
-
     # 5: init gloo context (step 2: gloo init)
     # dividing init_gloo into two part beacause nccl and gloo
     # are separately looking for free ports which sometimes
     # leads to port-conflict.
-    if init_gloo:
-        wait_server_ready([parallel_env.trainer_endpoints[0]])
+    if is_cpu_only and parallel_env.rank == 0:
+        # compared to the init_gloo path, we do not need to init gloo here,
+        # because that is already done in _init_parallel_ctx;
+        http_server_d["running"] = False
+        http_server.join()
+    elif init_gloo:
+        wait_server_ready([parallel_env.trainer_endpoints[0]])
         gloo_strategy = core.GlooParallelStrategy()
         gloo_strategy.rank = parallel_env.rank
         gloo_strategy.rank_num = parallel_env.world_size
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 44fd4c04e20955fcbd6ad583aafb07610f8232c4..bb250e32c0ee676229f30bd9a7e0f92940f0395f 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -184,6 +184,10 @@ endif()
 
 list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_hybrid_parallel)
 
+if (NOT WITH_GLOO)
+    LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel_cpuonly)
+endif()
+
 if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
     LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
     LIST(REMOVE_ITEM TEST_OPS test_rank_attention_op) # TODO(shenliang03): rank_attention_op support CPU device in future
diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py
index 5c518976d1f36ce6f64e2228675131b62e6f2f5a..048c9b399d8040771ab6aa646205cb94f816342f 100644
--- a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py
+++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py
@@ -16,6 +16,7 @@ from __future__ import division
 from __future__ import print_function
 
 import unittest
+import os
 
 import paddle
 import numpy as np
@@ -65,7 +66,8 @@ class SimpleNet(fluid.Layer):
 
 class TestDistTraning(unittest.TestCase):
     def test_multiple_gpus(self):
-        dist.init_parallel_env()
+        backend = os.environ.get('PADDLE_DISTRI_BACKEND', 'auto')
+        dist.init_parallel_env(backend)
         self.trainer_id = dist.get_rank()
 
         model_a = SimpleNet(self.trainer_id)
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py
new file mode 100644
index 0000000000000000000000000000000000000000..6caf0c54ae6ca4b28536f4943b9fca861f06ab1e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import time
+import paddle.fluid as fluid
+import copy
+import os
+import subprocess
+
+from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc
+
+
+def get_cluster_from_args(selected_gpus):
+    cluster_node_ips = '127.0.0.1'
+    node_ip = '127.0.0.1'
+
+    node_ips = [x.strip() for x in cluster_node_ips.split(',')]
+
+    node_ips.index(node_ip)
+
+    free_ports = None
+
+    free_ports = find_free_ports(len(selected_gpus))
+    if free_ports is not None:
+        free_ports = list(free_ports)
+
+    trainer_endpoints = []
+    for ip in node_ips:
+        trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports])
+    return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)
+
+
+def get_gpus(selected_gpus):
+    selected_gpus = [x.strip() for x in selected_gpus.split(',')]
+    return selected_gpus
+
+
+def start_local_trainers(cluster,
+                         pod,
+                         training_script,
+                         training_script_args,
+                         log_dir=None):
+    current_env = copy.copy(os.environ.copy())
+    #paddle broadcast ncclUniqueId use socket, and
+    #proxy maybe make trainers unreachable, so delete them.
+    #if we set them to "", grpc will log error message "bad uri"
+    #so just delete them.
+    current_env.pop("http_proxy", None)
+    current_env.pop("https_proxy", None)
+
+    procs = []
+    for t in pod.trainers:
+        proc_env = {
+            "PADDLE_TRAINER_ID": "%d" % t.rank,
+            "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
+            "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
+            "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()),
+            "PADDLE_DISTRI_BACKEND":
+            "gloo",  # make init_parallel_env get 'gloo' argument.
+        }
+
+        current_env.update(proc_env)
+
+        print("trainer proc env:{}".format(current_env))
+
+        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
+            cmd = "python -m coverage run --branch -p " + training_script
+        else:
+            cmd = "python -u " + training_script
+
+        print("start trainer proc:{} env:{}".format(cmd, proc_env))
+
+        fn = None
+
+        proc = subprocess.Popen(cmd.split(" "), env=current_env)
+
+        tp = TrainerProc()
+        tp.proc = proc
+        tp.rank = t.rank
+        tp.log_fn = fn
+        tp.cmd = cmd
+
+        procs.append(tp)
+
+    return procs
+
+
+class TestMultipleGpus(unittest.TestCase):
+    def run_mnist_2gpu(self, target_file_name):
+        #if not fluid.core.is_compiled_with_cuda(
+        #) or fluid.core.get_cuda_device_count() == 0:
+        #    return
+
+        selected_gpus = get_gpus('0,1')
+        cluster = None
+        pod = None
+
+        cluster, pod = get_cluster_from_args(selected_gpus)
+        procs = start_local_trainers(
+            cluster,
+            pod,
+            training_script=target_file_name,
+            training_script_args=[])
+
+        while True:
+            alive = watch_local_trainers(procs, cluster.trainers_nranks())
+
+            if not alive:
+                print("Local procs complete, POD info:{}".format(pod))
+                break
+            time.sleep(3)
+
+
+class TestDataParallelGradientCheck(TestMultipleGpus):
+    def test_multiple_gpus_dynamic(self):
+        self.run_mnist_2gpu('parallel_dygraph_gradient_check.py')
+
+
+if __name__ == "__main__":
+    unittest.main()
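
For reference, below is a minimal sketch (not part of the patch) of how the new backend argument would be exercised by a trainer script. It assumes the standard paddle.nn / paddle.optimizer / paddle.DataParallel APIs and that the launcher has exported PADDLE_TRAINER_ID, PADDLE_CURRENT_ENDPOINT, PADDLE_TRAINERS_NUM and PADDLE_TRAINER_ENDPOINTS for each process, the way start_local_trainers does in the new unit test above.

# Hypothetical trainer script: CPU-only data parallel over gloo.
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist


def main():
    # backend='gloo' takes the new GLOOParallelContext path added by this
    # patch; 'auto' would still prefer NCCL/BKCL when paddle is compiled
    # with CUDA or XPU support.
    dist.init_parallel_env(backend='gloo')

    layer = nn.Linear(10, 1)
    dp_layer = paddle.DataParallel(layer)  # gradients are all-reduced via gloo
    adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())

    inputs = paddle.randn([4, 10], 'float32')
    loss = dp_layer(inputs).mean()
    loss.backward()  # Reducer fuses and all-reduces gradients on CPU
    adam.step()
    adam.clear_grad()


if __name__ == '__main__':
    main()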