From df31926fcfa799a88666dd8d6e5648c1d32332e4 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 20 Jun 2018 21:46:52 +0800 Subject: [PATCH] small thread-safety fix and doc improvements. --- .../framework/details/multi_devices_graph_builder.cc | 12 +++++++++++- .../framework/details/threaded_ssa_graph_executor.cc | 2 ++ .../framework/details/threaded_ssa_graph_executor.h | 1 + 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 78356cb1be3..57e2f4265a8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -199,6 +199,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( BuildStrategy::GradientScaleStrategy::kCustomized) { CreateScaleLossGradOp(&result); } + // This assumes the backward generating code will ensure IsScaleLossOp + // is true only for the op that scales the final scalar loss. + // It also assumes backward op will always follow the forward op in + // the block. is_forwarding = false; } else { int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); @@ -243,6 +247,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( InsertAllReduceOp(&result, g_name); } break; + default: + LOG(FATAL) << "Unknown reduce strategy "; + break; } } } catch (boost::bad_get e) { @@ -261,7 +268,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } /* Dependency graph has been constructed. However, there are still data - harzaeds need to be handled. + hazards that need to be handled. */ PolishGraphToSupportDataHazards(&result); @@ -449,6 +456,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, return var; } +// Find the first occurrence of `prev_op_name` and make current `op` depend +// on it. 
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->ops_) { @@ -469,6 +478,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, } } +// Create RPC related op handles that connect their in ops and out ops. void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op) const { result->ops_.emplace_back( diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 6c5098ce85b..b1706eb12d0 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -96,6 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto cur_ready_vars = ready_vars.PopAll(1, &timeout); if (timeout) { + std::lock_guard l(exception_mu_); if (exception_) { auto exp = *exception_; exception_.reset(); @@ -199,6 +200,7 @@ void ThreadedSSAGraphExecutor::RunOp( ready_var_q->Extend(op->Outputs()); VLOG(10) << op << " " << op->Name() << "Signal posted"; } catch (platform::EnforceNotMet ex) { + std::lock_guard l(exception_mu_); exception_.reset(new platform::EnforceNotMet(ex)); } catch (...) { LOG(FATAL) << "Unknown exception catched"; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 4a2075f1ccc..90430be9967 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -56,6 +56,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { std::vector local_scopes_; std::vector places_; platform::DeviceContextPool fetch_ctxs_; + std::mutex exception_mu_; std::unique_ptr exception_; std::atomic running_ops_; -- GitLab