diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index 414b0970c73d335f26995e291b26662d3a5f8362..47872a9f2af93e261b09fa069ddf712d1aa215ad 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -19,6 +19,13 @@
 #include "paddle/fluid/framework/details/variable_visitor.h"
 #include "paddle/fluid/platform/profiler.h"
 
+// async nccl allreduce or sync issue:
+// https://github.com/PaddlePaddle/Paddle/issues/15049
+DEFINE_bool(
+    sync_nccl_allreduce, true,
+    "If set true, will call `cudaStreamSynchronize(nccl_stream)` after "
+    "allreduce. This mode can get better performance in some scenarios.");
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -48,111 +55,107 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 void AllReduceOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
 
-// FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR,
-// this is a distributed or inter-process call, find a better way.
-#ifdef PADDLE_WITH_CUDA
-  // All-reduce op_handle can run on the sub-scope, find the nccl id from
-  // the global scope.
-  if (NoDummyInputSize() == 1 &&
-      local_scopes_[0]->FindVar(NCCL_ID_VARNAME) == nullptr) {
-#else
-  if (NoDummyInputSize() == 1) {
-#endif
-    return;  // No need to all reduce when GPU count = 1;
-  } else {
-    // Wait input done
-    WaitInputVarGenerated();
-    auto in_var_handles = DynamicCast(this->Inputs());
-    auto out_var_handles = DynamicCast(this->Outputs());
-    PADDLE_ENFORCE_EQ(
-        in_var_handles.size(), places_.size(),
-        "The NoDummyInputSize should be equal to the number of places.");
-    PADDLE_ENFORCE_EQ(
-        in_var_handles.size(), out_var_handles.size(),
-        "The NoDummyInputSize and NoDummyOutputSize should be equal.");
-
-    std::vector lod_tensors;
-    for (size_t i = 0; i < local_scopes_.size(); ++i) {
-      auto *s = local_scopes_[i];
-      auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get();
-      auto &lod_tensor =
-          local_scope.FindVar(in_var_handles[i]->name_)->Get();
-      lod_tensors.emplace_back(&lod_tensor);
-      PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_,
-                        "The name of input and output should be equal.");
-    }
+  // FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR,
+  // this is a distributed or inter-process call, find a better way.
+  // Wait input done
+  WaitInputVarGenerated();
+  auto in_var_handles = DynamicCast(this->Inputs());
+  auto out_var_handles = DynamicCast(this->Outputs());
+  PADDLE_ENFORCE_EQ(
+      in_var_handles.size(), places_.size(),
+      "The NoDummyInputSize should be equal to the number of places.");
+  PADDLE_ENFORCE_EQ(
+      in_var_handles.size(), out_var_handles.size(),
+      "The NoDummyInputSize and NoDummyOutputSize should be equal.");
+
+  std::vector lod_tensors;
+  for (size_t i = 0; i < local_scopes_.size(); ++i) {
+    auto *s = local_scopes_[i];
+    auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get();
+    auto &lod_tensor =
+        local_scope.FindVar(in_var_handles[i]->name_)->Get();
+    lod_tensors.emplace_back(&lod_tensor);
+    PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_,
+                      "The name of input and output should be equal.");
+  }
 
-    if (platform::is_gpu_place(lod_tensors[0]->place())) {
+  if (platform::is_gpu_place(lod_tensors[0]->place())) {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
-      int dtype = -1;
-      size_t numel = 0;
-      std::vector> all_reduce_calls;
-      for (size_t i = 0; i < local_scopes_.size(); ++i) {
-        auto &p = places_[i];
-        auto &lod_tensor = *lod_tensors[i];
-        void *buffer = const_cast(lod_tensor.data());
-
-        if (dtype == -1) {
-          dtype = platform::ToNCCLDataType(lod_tensor.type());
-        }
+    PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
+    int dtype = -1;
+    size_t numel = 0;
+    std::vector> all_reduce_calls;
+    for (size_t i = 0; i < local_scopes_.size(); ++i) {
+      auto &p = places_[i];
+      auto &lod_tensor = *lod_tensors[i];
+      void *buffer = const_cast(lod_tensor.data());
+
+      if (dtype == -1) {
+        dtype = platform::ToNCCLDataType(lod_tensor.type());
+      }
+
+      if (numel == 0) {
+        numel = static_cast(lod_tensor.numel());
+      }
+
+      int dev_id = boost::get(p).device;
+      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
+      auto stream = nccl_ctx.stream();
+      auto comm = nccl_ctx.comm_;
+      all_reduce_calls.emplace_back([=] {
+        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
+            buffer, buffer, numel, static_cast(dtype), ncclSum,
+            comm, stream));
+      });
+    }
 
-        if (numel == 0) {
-          numel = static_cast(lod_tensor.numel());
+    this->RunAndRecordEvent([&] {
+      if (all_reduce_calls.size() == 1UL) {
+        // Do not use NCCLGroup when managing NCCL per thread per device
+        all_reduce_calls[0]();
+      } else {
+        platform::NCCLGroupGuard guard;
+        for (auto &call : all_reduce_calls) {
+          call();
         }
+      }
+    });
 
+    if (FLAGS_sync_nccl_allreduce) {
+      for (auto &p : places_) {
         int dev_id = boost::get(p).device;
         auto &nccl_ctx = nccl_ctxs_->at(dev_id);
         auto stream = nccl_ctx.stream();
-        auto comm = nccl_ctx.comm_;
-        all_reduce_calls.emplace_back([=] {
-          PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-              buffer, buffer, numel, static_cast(dtype),
-              ncclSum, comm, stream));
-          // TODO(Yancey1989): synchronize here can get better performance
-          // if don't use NCCL group call, but need more profiling.
-          if (local_scopes_.size() == 1UL) cudaStreamSynchronize(stream);
-        });
+        cudaStreamSynchronize(stream);
       }
-
-      this->RunAndRecordEvent([&] {
-        if (all_reduce_calls.size() == 1UL) {
-          all_reduce_calls[0]();
-        } else {
-          platform::NCCLGroupGuard guard;
-          for (auto &call : all_reduce_calls) {
-            call();
-          }
-        }
-      });
+    }
 #else
-      PADDLE_THROW("Not compiled with CUDA");
+    PADDLE_THROW("Not compiled with CUDA");
 #endif
-  } else {  // Special handle CPU only Operator's gradient. Like CRF
-    auto &trg = *this->local_scopes_[0]
-                     ->FindVar(kLocalExecScopeName)
-                     ->Get()
-                     ->FindVar(out_var_handles[0]->name_)
-                     ->GetMutable();
-
-    // Reduce All Tensor to trg in CPU
-    ReduceLoDTensor func(lod_tensors, &trg);
-    VisitDataType(lod_tensors[0]->type(), func);
-
-    for (size_t i = 1; i < local_scopes_.size(); ++i) {
-      auto &scope =
-          *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get();
-      auto &p = places_[i];
-      auto *var = scope.FindVar(out_var_handles[i]->name_);
-      auto *dev_ctx = dev_ctxes_.at(p);
-
-      RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
-        auto &tensor_gpu = *var->GetMutable();
-        auto &tensor_cpu = trg;
-        TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
-      });
-    }
+  } else {  // Special handle CPU only Operator's gradient. Like CRF
+    auto &trg = *this->local_scopes_[0]
+                    ->FindVar(kLocalExecScopeName)
+                    ->Get()
+                    ->FindVar(out_var_handles[0]->name_)
+                    ->GetMutable();
+
+    // Reduce All Tensor to trg in CPU
+    ReduceLoDTensor func(lod_tensors, &trg);
+    VisitDataType(lod_tensors[0]->type(), func);
+
+    for (size_t i = 1; i < local_scopes_.size(); ++i) {
+      auto &scope =
+          *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get();
+      auto &p = places_[i];
+      auto *var = scope.FindVar(out_var_handles[i]->name_);
+      auto *dev_ctx = dev_ctxes_.at(p);
+
+      RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
+        auto &tensor_gpu = *var->GetMutable();
+        auto &tensor_cpu = trg;
+        TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
+      });
     }
   }
 }
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index b927b21b6f083262db0d10c6c911df3612a32e48..cb660cb8c2c3906c023f68e66ca5538aa6f6897e 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -31,6 +31,8 @@ namespace framework {
 namespace details {
 
 static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
+  // The allreduce op order should be fixed when scheduling them in
+  // multiple threads or processes, to avoid a hang.
   return (!strategy.enable_sequential_execution_ &&
           strategy.num_trainers_ > 1) ||
          strategy.enable_parallel_graph_;
@@ -88,8 +90,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     auto multi_devices_pass = AppendPass("multi_devices_pass");
     multi_devices_pass->SetNotOwned("strategy", &strategy_);
-    multi_devices_pass->Set("num_trainers",
-                            new int(strategy_.num_trainers_));
 
     // Add a graph print pass to record a graph with device info.
     if (!strategy_.debug_graphviz_path_.empty()) {
@@ -134,6 +134,7 @@ std::shared_ptr BuildStrategy::CreatePassesFromStrategy(
 std::unique_ptr BuildStrategy::Apply(
     const ProgramDesc &main_program, const std::vector &places,
     const std::string &loss_var_name, const std::vector &local_scopes,
+    const size_t &num_parallel_devices,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
 #else
@@ -152,6 +153,9 @@ std::unique_ptr BuildStrategy::Apply(
       pass->Erase("local_scopes");
       pass->SetNotOwned>("local_scopes", &local_scopes);
 
+      pass->Set("num_parallel_devices",
+                new size_t(num_parallel_devices));
+
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
       platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
       pass->Erase("nccl_ctxs");
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index f9351fb8d2d40c5bec3a63ae693790ee10704a0b..b31e60ad8e5ec353b575981718d142b086aba4de 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -112,6 +112,7 @@ struct BuildStrategy {
                                 const std::vector &places,
                                 const std::string &loss_var_name,
                                 const std::vector &local_scopes,
+                                const size_t &num_parallel_devices_,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                                 const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const;
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 0be81a48ff7e74e373d8ec79034ea09ddaebe376..a6d583777ace4069747ba2e80fc4c73fa7d6a118 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -132,7 +132,7 @@ static const char kLossVarName[] = "loss_var_name";
 static const char kPlaces[] = "places";
 static const char kLocalScopes[] = "local_scopes";
 static const char kStrategy[] = "strategy";
-static const char kNumTrainers[] = "num_trainers";
+static const char kNumParallelDevices[] = "num_parallel_devices";
 
 void MultiDevSSAGraphBuilder::Init() const {
   all_vars_.clear();
@@ -296,7 +296,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl(
   auto nodes = graph->ReleaseNodes();
   ir::Graph &result = *graph;
 
-  int num_trainers = Get(kNumTrainers);
+  size_t num_parallel_devices = Get(kNumParallelDevices);
 
   for (auto &node : nodes) {
     if (node->IsVar() && node->Var()) {
@@ -382,16 +382,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl(
       CreateComputationalOps(&result, node, places_.size());
     }
 
-// insert collective ops at the backpropagation; and
-// insert collective ops if the graph contains mutilple places.
-
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    if (!is_forwarding &&
-        (places_.size() > 1 || num_trainers > 1 ||
-         (nccl_ctxs_ && nccl_ctxs_->contexts_.size() > 1))) {
-#else
-    if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) {
-#endif
+    if (!is_forwarding && num_parallel_devices > 1) {
       // Currently, we assume that once gradient is generated, it can be
       // broadcast, and each gradient is only broadcast once.
       if (static_cast(boost::get(node->Op()->GetAttr(
@@ -668,12 +659,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
     ir::Graph *result, const std::string &loss_grad_name,
     ir::Node *out_var_node) const {
+  size_t num_parallel_devices = Get("num_parallel_devices");
   for (size_t i = 0; i < places_.size(); ++i) {
     // Insert ScaleCost OpHandle
     auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
     auto *op_handle = new ScaleLossGradOpHandle(
         result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
-        local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
+        num_parallel_devices, local_scopes_[i], places_[i], dev_ctx);
     result->Get(kGraphOps).emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -903,4 +895,4 @@ REGISTER_PASS(multi_devices_pass,
     .RequirePassAttr(paddle::framework::details::kPlaces)
     .RequirePassAttr(paddle::framework::details::kLocalScopes)
     .RequirePassAttr(paddle::framework::details::kStrategy)
-    .RequirePassAttr(paddle::framework::details::kNumTrainers);
+    .RequirePassAttr(paddle::framework::details::kNumParallelDevices);
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 1637ee3c7e77f767efad6cde20eec15119cf1a7a..ec44cae3b31cbd27690069b252edd4c069ca0e5a 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -107,6 +107,7 @@ class ParallelExecutorPrivate {
   bool own_local_scope_;
   bool use_cuda_;
   bool use_all_reduce_;
+  size_t num_parallel_devices_;
 
   // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
   // then keeps unchanged
@@ -202,6 +203,7 @@ ParallelExecutor::ParallelExecutor(
   member_->build_strategy_ = build_strategy;
   member_->use_all_reduce_ =
       build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
+  member_->num_parallel_devices_ = num_trainers * places.size();
 
   if (!member_->use_all_reduce_) {
     PADDLE_ENFORCE(places.size() > 1,
@@ -212,12 +214,12 @@ ParallelExecutor::ParallelExecutor(
   if (build_strategy.enable_parallel_graph_) {
     PADDLE_ENFORCE(
         member_->use_all_reduce_,
-        "build_strategy.reduce should be `AllReduce` if you want to use"
-        "ParallelGraph executor.");
+        "build_strategy.reduce should be `AllReduce` if you want to enable "
+        "ParallelGraph.");
     PADDLE_ENFORCE(
         member_->use_cuda_,
-        "execution_strategy.use_cuda should be True if you want to use"
-        "ParallelGraph executor.");
+        "execution_strategy.use_cuda should be True if you want to enable "
+        "ParallelGraph.");
   }
 
   // Step 1. Bcast the bcast_vars to devs.
@@ -241,27 +243,43 @@ ParallelExecutor::ParallelExecutor(
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
   ncclUniqueId *nccl_id = nullptr;
+  // In distributed training, the NCCL id is broadcast by the gen_nccl_id operator.
+  if (nccl_id_var != nullptr) {
+    nccl_id = nccl_id_var->GetMutable();
+  }
+
   if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
-    // parallel graph mode should initialize nccl by ncclCommInitRank since
-    // it call nccl operator per device per thread.
-    if (nccl_id_var == nullptr) {
+    if (nccl_id == nullptr) {
       nccl_id = new ncclUniqueId();
       PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
-      *member_->global_scope_->Var(NCCL_ID_VARNAME)
-          ->GetMutable() = *nccl_id;
-    } else {
-      nccl_id = nccl_id_var->GetMutable();
     }
-  } else if (nccl_id_var != nullptr) {  // the other executor type.
-    // the distributed training with nccl mode would initialize the nccl id in
-    // startup_program.
-    nccl_id = nccl_id_var->GetMutable();
-  } else {
-    // initlize NCCL by ncclCommInitAll, do not need to intialize the nccl_id.
   }
-
   member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
       member_->places_, nccl_id, num_trainers, trainer_id));
+
+/**
+if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
+  // parallel graph mode should initialize nccl by ncclCommInitRank since
+  // it call nccl operator per device per thread.
+  if (nccl_id_var == nullptr) {
+    nccl_id = new ncclUniqueId();
+    PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
+    *member_->global_scope_->Var(NCCL_ID_VARNAME)
+        ->GetMutable() = *nccl_id;
+  } else {
+    nccl_id = nccl_id_var->GetMutable();
+  }
+} else if (nccl_id_var != nullptr) {  // the other executor type.
+  // the distributed training with nccl mode would initialize the nccl id in
+  // startup_program.
+  nccl_id = nccl_id_var->GetMutable();
+} else {
+  // initlize NCCL by ncclCommInitAll, do not need to intialize the nccl_id.
+}
+
+member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
+    member_->places_, nccl_id, num_trainers, trainer_id));
+**/
 #else
   PADDLE_THROW("Not compiled with CUDA");
 #endif
@@ -274,25 +292,27 @@ ParallelExecutor::ParallelExecutor(
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
   std::vector> graphs;
+  member_->num_parallel_devices_ = member_->places_.size() * num_trainers;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (build_strategy.enable_parallel_graph_) {
     for (size_t i = 0; i < member_->places_.size(); ++i) {
-      std::unique_ptr graph =
-          build_strategy.Apply(main_program, {member_->places_[i]},
-                               loss_var_name, {member_->local_scopes_[i]},
-                               member_->use_cuda_, member_->nccl_ctxs_.get());
+      std::unique_ptr graph = build_strategy.Apply(
+          main_program, {member_->places_[i]}, loss_var_name,
+          {member_->local_scopes_[i]}, member_->num_parallel_devices_,
+          member_->use_cuda_, member_->nccl_ctxs_.get());
       graphs.push_back(std::move(graph));
     }
   } else {
     std::unique_ptr graph = build_strategy.Apply(
         main_program, member_->places_, loss_var_name, member_->local_scopes_,
-        member_->use_cuda_, member_->nccl_ctxs_.get());
+        member_->num_parallel_devices_, member_->use_cuda_,
+        member_->nccl_ctxs_.get());
     graphs.push_back(std::move(graph));
   }
 #else
-  std::unique_ptr graph =
-      build_strategy.Apply(main_program, member_->places_, loss_var_name,
-                           member_->local_scopes_, member_->use_cuda_);
+  std::unique_ptr graph = build_strategy.Apply(
+      main_program, member_->places_, loss_var_name, member_->local_scopes_,
+      member_->num_parallel_devices_, member_->use_cuda_);
   graphs.push_back(std::move(graph));
 #endif
   auto max_memory_size = GetEagerDeletionThreshold();
diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
index a2c8ee120f9f1f21022c8191211f69ff91c4d1d4..36b13d45582e912fb23b85faa2afb7b7e4c7771d 100644
--- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
+++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
@@ -60,71 +60,69 @@ class TestParallelExecutorBase(unittest.TestCase):
         startup = fluid.Program()
         startup.random_seed = 1  # Fix random seed
         main.random_seed = 1
-        self.scope = fluid.Scope()
-        with fluid.scope_guard(self.scope):
-            with fluid.program_guard(main, startup):
-                if seed is not None:
-                    startup.random_seed = seed
-                    main.random_seed = seed
-
-                loss = method(use_feed=feed_dict is not None)
-
-                optimizer().minimize(loss)
-
-                if memory_opt:
-                    fluid.memory_optimize(main)
-
-                place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-                startup_exe = fluid.Executor(place)
-                startup_exe.run(startup)
-                exec_strategy = fluid.ExecutionStrategy()
-                exec_strategy.allow_op_delay = allow_op_delay
-                if use_fast_executor:
-                    exec_strategy.use_experimental_executor = True
-                build_strategy.enable_parallel_graph = use_parallel_graph
-                build_strategy = fluid.BuildStrategy()
-                build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
-                    if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
-                build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
-                build_strategy.memory_optimize = use_ir_memory_optimize
-                build_strategy.enable_sequential_execution = enable_sequential_execution
-                if use_cuda and core.is_compiled_with_cuda():
-                    build_strategy.remove_unnecessary_lock = True
-
-                if use_parallel_executor:
-                    exe = fluid.ParallelExecutor(
-                        use_cuda,
-                        loss_name=loss.name,
-                        exec_strategy=exec_strategy,
-                        build_strategy=build_strategy)
-                else:
-                    exe = fluid.Executor(place=place)
-
-                if batch_size is not None:
-                    batch_size *= fluid.core.get_cuda_device_count(
-                    ) if use_cuda else int(
-                        os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-                begin = time.time()
-                first_loss, = run_executor(
-                    exe=exe, feed=feed_dict, fetch_list=[loss.name])
-
-                for i in range(iter):
-                    run_executor(exe=exe, feed=feed_dict, fetch_list=[])
-
-                last_loss, = run_executor(
-                    exe=exe, feed=feed_dict, fetch_list=[loss.name])
-                end = time.time()
-
-                if batch_size is not None:
-                    print("%.4f Instance per second" % (
-                        (batch_size * iter + 2) / (end - begin)))
-
-                avg_last_loss_val = np.array(last_loss).mean()
-                avg_first_loss_val = np.array(first_loss).mean()
-                if math.isnan(float(avg_last_loss_val)) or math.isnan(
-                        float(avg_first_loss_val)):
-                    sys.exit("got NaN loss, training failed.")
-
-                print(first_loss, last_loss)
-                # self.assertGreater(first_loss[0], last_loss[0])
-                return first_loss, last_loss
+        with fluid.program_guard(main, startup):
+            if seed is not None:
+                startup.random_seed = seed
+                main.random_seed = seed
+
+            loss = method(use_feed=feed_dict is not None)
+
+            optimizer().minimize(loss)
+
+            if memory_opt:
+                fluid.memory_optimize(main)
+
+            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+            startup_exe = fluid.Executor(place)
+            startup_exe.run(startup)
+            exec_strategy = fluid.ExecutionStrategy()
+            exec_strategy.allow_op_delay = allow_op_delay
+            if use_fast_executor:
+                exec_strategy.use_experimental_executor = True
+            build_strategy = fluid.BuildStrategy()
+            build_strategy.enable_parallel_graph = use_parallel_graph
+            build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
+                if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
+            build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
+            build_strategy.memory_optimize = use_ir_memory_optimize
+            build_strategy.enable_sequential_execution = enable_sequential_execution
+            if use_cuda and core.is_compiled_with_cuda():
+                build_strategy.remove_unnecessary_lock = True
+
+            if use_parallel_executor:
+                exe = fluid.ParallelExecutor(
+                    use_cuda,
+                    loss_name=loss.name,
+                    exec_strategy=exec_strategy,
+                    build_strategy=build_strategy)
+            else:
+                exe = fluid.Executor(place=place)
+
+            if batch_size is not None:
+                batch_size *= fluid.core.get_cuda_device_count(
+                ) if use_cuda else int(
+                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+            begin = time.time()
+            first_loss, = run_executor(
+                exe=exe, feed=feed_dict, fetch_list=[loss.name])
+
+            for i in range(iter):
+                run_executor(exe=exe, feed=feed_dict, fetch_list=[])
+
+            last_loss, = run_executor(
+                exe=exe, feed=feed_dict, fetch_list=[loss.name])
+            end = time.time()
+
+            if batch_size is not None:
+                print("%.4f Instance per second" % (
+                    (batch_size * iter + 2) / (end - begin)))
+
+            avg_last_loss_val = np.array(last_loss).mean()
+            avg_first_loss_val = np.array(first_loss).mean()
+            if math.isnan(float(avg_last_loss_val)) or math.isnan(
+                    float(avg_first_loss_val)):
+                sys.exit("got NaN loss, training failed.")
+
+            print(first_loss, last_loss)
+            # self.assertGreater(first_loss[0], last_loss[0])
+            return first_loss, last_loss
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
index d75761153c0a119b85d918b188928dbe17f4fc51..3e4490aa58e819e49bcc3c452bc1acd69137c76c 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
@@ -175,44 +175,65 @@ class TestCRFModel(unittest.TestCase):
             print(pe.run(feed=feeder.feed(cur_batch),
                          fetch_list=[avg_cost.name])[0])
 
-    def test_update_sparse_parameter_all_reduce(self):
+    def _new_build_strategy(self, use_reduce=False, use_parallel_graph=False):
         build_strategy = fluid.BuildStrategy()
-        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
+
+        if use_reduce:
+            build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
+        else:
+            build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
+        build_strategy.enable_parallel_graph = use_parallel_graph
+
+        return build_strategy
+
+    def test_update_sparse_parameter_all_reduce(self):
         if core.is_compiled_with_cuda():
             self.check_network_convergence(
-                is_sparse=True, build_strategy=build_strategy, use_cuda=True)
-            self.check_network_convergence(
-                is_sparse=True, build_strategy=build_strategy, use_cuda=True)
+                is_sparse=True,
+                build_strategy=self._new_build_strategy(),
+                use_cuda=True)
 
         self.check_network_convergence(
-            is_sparse=True, build_strategy=build_strategy, use_cuda=False)
+            is_sparse=True,
+            build_strategy=self._new_build_strategy(),
+            use_cuda=False)
 
     def test_update_dense_parameter_all_reduce(self):
-        build_strategy = fluid.BuildStrategy()
-        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
         if core.is_compiled_with_cuda():
             self.check_network_convergence(
-                is_sparse=False, build_strategy=build_strategy, use_cuda=True)
+                is_sparse=False,
+                build_strategy=self._new_build_strategy(),
+                use_cuda=True)
+            self.check_network_convergence(
+                is_sparse=False,
+                build_strategy=self._new_build_strategy(
+                    use_parallel_graph=True),
+                use_cuda=True)
+
         self.check_network_convergence(
             is_sparse=False, build_strategy=build_strategy, use_cuda=False)
 
     def test_update_sparse_parameter_reduce(self):
-        build_strategy = fluid.BuildStrategy()
-        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
         if core.is_compiled_with_cuda():
             self.check_network_convergence(
-                is_sparse=True, build_strategy=build_strategy, use_cuda=True)
+                is_sparse=True,
+                build_strategy=self._new_build_strategy(use_reduce=True),
+                use_cuda=True)
         self.check_network_convergence(
-            is_sparse=True, build_strategy=build_strategy, use_cuda=False)
+            is_sparse=True,
+            build_strategy=self._new_build_strategy(use_reduce=True),
+            use_cuda=False)
 
     def test_update_dense_parameter_reduce(self):
-        build_strategy = fluid.BuildStrategy()
-        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
         if core.is_compiled_with_cuda():
             self.check_network_convergence(
-                is_sparse=False, build_strategy=build_strategy, use_cuda=True)
+                is_sparse=False,
+                build_strategy=self._new_build_strategy(use_reduce=True),
+                use_cuda=True)
 
         self.check_network_convergence(
-            is_sparse=False, build_strategy=build_strategy, use_cuda=False)
+            is_sparse=False,
+            build_strategy=self._new_build_strategy(use_reduce=True),
+            use_cuda=False)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
index 531c99a8358a851a328d756609f61490a5cfe208..5515ff0bb20e9e6b93d892f598af558cc21b4d80 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
@@ -312,7 +312,7 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=use_reduce,
-            optimizer=optimizer(lr_scale=lr_scale),
+            optimizer=optimizer(),
             use_parallel_graph=use_parallel_graph)
 
         self.assertAlmostEquals(