From 2378aa8af4adefb573ddb031946023ed51fe5f8b Mon Sep 17 00:00:00 2001
From: WangXi
Date: Mon, 21 Oct 2019 19:30:45 +0800
Subject: [PATCH] [Cherry-pick 1.6] Fix DGC test and DGC nan bug (#20708)

---
 cmake/external/dgc.cmake                      |  4 +-
 .../framework/details/all_reduce_op_handle.cc |  4 ++
 .../framework/details/all_reduce_op_handle.h  |  2 +
 .../details/sparse_all_reduce_op_handle.cc    | 44 ++++++++++--
 .../details/sparse_all_reduce_op_handle.h     |  3 +
 .../multi_devices_graph_pass.cc               |  3 +-
 paddle/fluid/platform/init.cc                 | 17 -----
 paddle/fluid/platform/init.h                  |  2 -
 paddle/fluid/pybind/pybind.cc                 |  1 -
 .../incubate/fleet/collective/__init__.py     |  1 -
 python/paddle/fluid/optimizer.py              |  2 -
 .../fluid/tests/unittests/test_dist_base.py   | 72 ++++++++++++++++---
 .../unittests/test_dist_mnist_dgc_nccl.py     | 18 +++++
 13 files changed, 129 insertions(+), 44 deletions(-)

diff --git a/cmake/external/dgc.cmake b/cmake/external/dgc.cmake
index 5d5fcc3d429..59b463c4c0f 100644
--- a/cmake/external/dgc.cmake
+++ b/cmake/external/dgc.cmake
@@ -23,8 +23,8 @@ INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR})
 ExternalProject_Add(
     extern_dgc
     ${EXTERNAL_PROJECT_LOG_ARGS}
-    URL "http://fleet.bj.bcebos.com/collective.tgz"
-    URL_MD5 "015d565156c3de4e30fe25473f47e7a9"
+    URL "http://fleet.bj.bcebos.com/collective_ef2216a.tgz"
+    URL_MD5 "2f67549fd5f1262383d83289abc4f88f"
     SOURCE_DIR "${DGC_SOURCES_DIR}"
     CONFIGURE_COMMAND ""
     BUILD_COMMAND make -j
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index a367772aef8..8deacf4e7dd 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -171,6 +171,10 @@ void AllReduceOpHandle::NCCLAllReduceFunc(
     }
   });
 
+  SyncNCCLAllReduce();
+}
+
+void AllReduceOpHandle::SyncNCCLAllReduce() {
   if (FLAGS_sync_nccl_allreduce) {
     for (auto &p : places_) {
       int dev_id = boost::get<platform::CUDAPlace>(p).device;
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h
index c18b0ed9290..c8ff151a882 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.h
@@ -63,6 +63,8 @@ class AllReduceOpHandle : public OpHandleBase {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   void NCCLAllReduceFunc(
       const std::vector<std::function<void()>> &all_reduce_calls);
+
+  void SyncNCCLAllReduce();
 #endif
 
   void AllReduceImpl(const std::vector<VarHandle *> &in_var_handles,
diff --git a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
index e69bda6fcf8..282f0fc053e 100644
--- a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
@@ -20,6 +20,7 @@
 #include "paddle/fluid/framework/details/variable_visitor.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/profiler.h"
 
@@ -105,7 +106,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
   size_t in_numel = 0;
   size_t out_numel = 0;
   PADDLE_ENFORCE(nranks_ > 1);
-  std::vector<std::function<void()>> all_reduce_calls;
+  std::vector<std::function<void()>> all_gather_calls;
+  std::vector<std::function<void()>> sparse_reduce_calls;
 
   std::vector<memory::allocation::AllocationPtr> allocations;
 
@@ -141,15 +143,45 @@
              << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
             << ", k:" << k << ", place:" << place << ", dtype:" << dtype;
dtype:" << dtype; - all_reduce_calls.emplace_back([=] { - PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce( - in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm, - stream)); + all_gather_calls.emplace_back([=] { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + in_tensor_buf, gather_buff, 2 * k, static_cast(dtype), + comm, stream)); + }); + + sparse_reduce_calls.emplace_back([=] { + platform::CUDADeviceGuard guard(dev_id); + PADDLE_ENFORCE_EQ(paddle::communication::dgc::sparseReduce( + gather_buff, k, out_tensor_buf, + static_cast(out_numel), nranks_, stream), + true); }); } WaitInputVarGenerated(); - NCCLAllReduceFunc(all_reduce_calls); + SparseAllReduceFunc(all_gather_calls, sparse_reduce_calls); +} + +void SparseAllReduceOpHandle::SparseAllReduceFunc( + const std::vector> &all_gather_calls, + const std::vector> &sparse_reduce_calls) { + this->RunAndRecordEvent([&] { + if (all_gather_calls.size() == 1UL) { + // Do not use NCCLGroup when manage NCCL by per thread per device + all_gather_calls[0](); + } else { + platform::NCCLGroupGuard guard; + for (auto &call : all_gather_calls) { + call(); + } + } + + for (auto &call : sparse_reduce_calls) { + call(); + } + }); + + SyncNCCLAllReduce(); } int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) { diff --git a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h index d15814a2197..87136ea0d5a 100644 --- a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h @@ -43,6 +43,9 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle { int GetKValue(const std::string &grad_name); bool IsEncoded(); void RunImplEncoded(); + void SparseAllReduceFunc( + const std::vector> &all_gather_calls, + const std::vector> &sparse_reduce_calls); private: bool is_encoded_{false}; diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc index 224ab21b478..23c463eca41 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc @@ -465,8 +465,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result, new details::SparseAllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), scopes, places, multi_nccl_ctxs_, is_encoded, - static_cast(strategy_.trainers_endpoints_.size()) * - places_.size())); + strategy_.num_trainers_ * places_.size())); } else { result->Get(kGraphOps).emplace_back( new details::AllReduceOpHandle( diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index be6519b1890..04687b0b389 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -32,9 +32,6 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/piece.h" -#if defined(PADDLE_WITH_DGC) -#include "dgc/dgc.h" -#endif DECLARE_int32(paddle_num_threads); DEFINE_int32(multiple_of_cupti_buffer_size, 1, @@ -51,10 +48,6 @@ namespace framework { std::once_flag gflags_init_flag; std::once_flag p2p_init_flag; -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) -std::once_flag dgc_init_flag; -#endif - void InitGflags(std::vector argv) { std::call_once(gflags_init_flag, [&]() { FLAGS_logtostderr = true; @@ -229,15 +222,5 @@ void InitGLOG(const std::string &prog_name) { #endif } -#if defined(PADDLE_WITH_DGC) -void InitDGC() { - std::call_once(dgc_init_flag, []() { - PADDLE_ENFORCE(paddle::communication::dgc::dynloadNcclLib()); - }); -} -#else -void InitDGC() {} -#endif - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/platform/init.h b/paddle/fluid/platform/init.h index d25e79e78fa..d189f0022bf 100644 --- a/paddle/fluid/platform/init.h +++ b/paddle/fluid/platform/init.h @@ -30,8 +30,6 @@ void InitDevices(bool init_p2p); void InitDevices(bool init_p2p, const std::vector devices); -void InitDGC(); - #ifndef _WIN32 void SignalHandle(const char *data, int size); #endif diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index cb6a77f29f5..7161efeb539 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1352,7 +1352,6 @@ All parameter, weight, gradient are variables in Paddle. m.def("init_gflags", framework::InitGflags); m.def("init_glog", framework::InitGLOG); - m.def("init_dgc", framework::InitDGC); m.def("load_op_library", framework::LoadOpLib); m.def("init_devices", [](bool init_p2p) { framework::InitDevices(init_p2p); }); diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py index fa5dd3673dc..26b8e2c3b12 100644 --- a/python/paddle/fluid/incubate/fleet/collective/__init__.py +++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py @@ -271,7 +271,6 @@ class CollectiveOptimizer(DistributedOptimizer): node_num = self._node_num() assert node_num >= 1, "nccl2 node_num must >= 1, now:{}" % node_num - self._strategy.fuse_all_reduce_ops = True exec_strategy = self._strategy.exec_strategy if node_num <= 1: diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index c85433576c0..4405200e7ac 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -959,8 +959,6 @@ class DGCMomentumOptimizer(MomentumOptimizer): super(DGCMomentumOptimizer, self).__init__( learning_rate, momentum, use_nesterov, regularization, name) - core.init_dgc() - def _add_auto_increment_var(self, counter_name, begin, step=1): helper = LayerHelper('global_step_counter') counter, is_new_var = helper.create_or_get_global_variable( diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 49cf07d67b2..0c18fea1bc9 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -291,6 +291,10 @@ class TestDistRunnerBase(object): build_stra.num_trainers = 1 build_stra.trainer_id = 0 + if args.use_dgc: + # fuse_all_reduce_ops require that gradients should not be sparse types + build_stra.fuse_all_reduce_ops = False + print_to_err(type(self).__name__, "begin to compile with data parallel") binary = 
             loss_name=avg_cost.name,
@@ -572,7 +576,8 @@ class TestDistBase(unittest.TestCase):
                    check_error_log=False,
                    batch_size=DEFAULT_BATCH_SIZE,
                    batch_merge_repeat=1,
-                   log_name=""):
+                   log_name="",
+                   gpus="0"):
 
         cmd = self._python_interp
 
@@ -592,13 +597,17 @@
         if self.__use_cuda:
             cmd += " --use_cuda"
             env_local = {
-                "CUDA_VISIBLE_DEVICES": "0",
+                "CUDA_VISIBLE_DEVICES": gpus,
                 "PADDLE_TRAINERS_NUM": "1",
                 "PADDLE_TRAINER_ID": "0"
             }
         else:
             env_local = {'CPU_NUM': '1'}
 
+        # not use dgc in single card
+        if len(gpus) > 1 and self._use_dgc:
+            cmd += " --use_dgc"
+
         env_local.update(envs)
         print("local_cmd: {}, env: {}".format(cmd, env_local))
 
@@ -825,12 +834,7 @@
             print("outs[1]:", outs[1])
         return pickle.loads(outs[0]), pickle.loads(outs[1])
 
-    def check_with_place(self,
-                         model_file,
-                         delta=1e-3,
-                         check_error_log=False,
-                         need_envs={},
-                         log_name=""):
+    def _get_required_envs(self, check_error_log=False, need_envs={}):
         # TODO(typhoonzero): should auto adapt GPU count on the machine.
         required_envs = {
             "PATH": os.getenv("PATH", ""),
@@ -844,13 +848,24 @@
             "NCCL_SHM_DISABLE": "1"
         }
 
-        required_envs.update(need_envs)
-
         if check_error_log:
             required_envs["GLOG_vmodule"] = \
-                "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10,alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10"
+                "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
+                "alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
+                "sparse_all_reduce_op_handle=10"
             required_envs["GLOG_logtostderr"] = "1"
 
+        required_envs.update(need_envs)
+        return required_envs
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={},
+                         log_name=""):
+        required_envs = self._get_required_envs(check_error_log, need_envs)
+
         local_losses \
             = self._run_local(model_file, required_envs, check_error_log,
                               log_name=log_name)
@@ -881,3 +896,38 @@
         dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
         print("=======", local_loss, ":", dist_loss[0], "=======")
         self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)
+
+    def check_with_place_multi_cards(self,
+                                     model_file,
+                                     delta=1e-3,
+                                     check_error_log=False,
+                                     need_envs={},
+                                     log_name=""):
+        # need open p2p or shm otherwise multi cards mode will hang
+        need_envs.update({"NCCL_P2P_DISABLE": "0", "NCCL_SHM_DISABLE": "0"})
+
+        required_envs = self._get_required_envs(check_error_log, need_envs)
+
+        if self._use_dgc:
+            multi_cards_losses = self._run_local(
+                model_file,
+                required_envs,
+                check_error_log,
+                log_name=log_name + "_dgc_2cards",
+                gpus="0,1")
+
+            self._use_dgc = False
+            base_losses = self._run_local(
+                model_file,
+                required_envs,
+                check_error_log,
+                log_name=log_name + "_base_2cards",
+                gpus="0,1")
+
+            self._use_dgc = True
+
+            for step_id in range(RUN_STEP):
+                base_loss = base_losses[step_id]
+                multi_cards_loss = multi_cards_losses[step_id]
+                print("=======", base_loss, ":", multi_cards_loss, "=======")
+                self.assertAlmostEqual(base_loss, multi_cards_loss, delta=delta)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_dgc_nccl.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_dgc_nccl.py
index aaa43ec10bd..43e60a9eba6 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_mnist_dgc_nccl.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_dgc_nccl.py
@@ -38,5 +38,23 @@ class TestDistMnistNCCL2DGC(TestDistBase):
                 log_name=flag_name)
 
 
+class TestDistMnistNCCL2DGCMultiCards(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._use_reduce = False
+        self._use_reader_alloc = False
+        self._nccl2_mode = True
+        self._use_dgc = True
+
+    def test_dist_train(self):
+        import paddle.fluid as fluid
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place_multi_cards(
+                "dist_mnist.py",
+                delta=1e-5,
+                check_error_log=True,
+                log_name=flag_name)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab