diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index cf5968993fd4f31f8b6ff6ae482b1fe50310a00f..99dcaf27134e879fa85f57e5a675382442e9edf2 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -200,6 +200,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
           BuildStrategy::GradientScaleStrategy::kCustomized) {
         CreateScaleLossGradOp(&result);
       }
+      // This assumes the backward generating code will ensure IsScaleLossOp
+      // is true only for the op that scales the final scalar loss.
+      // It also assumes backward ops will always follow their forward ops in
+      // the block.
       is_forwarding = false;
     } else {
       int op_dev_id = GetOpDeviceID(*op);
@@ -244,6 +248,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
                 InsertAllReduceOp(&result, g_name);
               }
               break;
+            default:
+              LOG(FATAL) << "Unknown reduce strategy ";
+              break;
           }
         }
       } catch (boost::bad_get e) {
@@ -262,7 +269,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   }
   /*
   Dependency graph has been constructed. However, there are still data
-  harzaeds need to be handled.
+  hazards that need to be handled.
  */
   PolishGraphToSupportDataHazards(&result);

@@ -447,6 +454,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }

+// Find the first occurrence of `prev_op_name` and make the current `op`
+// depend on it.
 void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
   for (auto &prev_op : result->ops_) {
@@ -490,6 +499,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
   }
 }

+// Create RPC-related op handles that connect their input and output ops.
 void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                                           const OpDesc &op) const {
   int op_dev_id = -1;
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 6c5098ce85b784a3edcf8f48d2cc828aabd8e161..b1706eb12d080364d04108c7ef4da31e1e7c1deb 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -96,6 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     auto cur_ready_vars = ready_vars.PopAll(1, &timeout);

     if (timeout) {
+      std::lock_guard<std::mutex> l(exception_mu_);
       if (exception_) {
         auto exp = *exception_;
         exception_.reset();
@@ -199,6 +200,7 @@ void ThreadedSSAGraphExecutor::RunOp(
       ready_var_q->Extend(op->Outputs());
       VLOG(10) << op << " " << op->Name() << "Signal posted";
     } catch (platform::EnforceNotMet ex) {
+      std::lock_guard<std::mutex> l(exception_mu_);
       exception_.reset(new platform::EnforceNotMet(ex));
     } catch (...) {
       LOG(FATAL) << "Unknown exception catched";
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index 4a2075f1cccb3211316567197da56c01d26f35ce..90430be996758364387b552019762d9c2e9dfe45 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -56,6 +56,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<Scope *> local_scopes_;
   std::vector<platform::Place> places_;
   platform::DeviceContextPool fetch_ctxs_;
+  std::mutex exception_mu_;
   std::unique_ptr<platform::EnforceNotMet> exception_;
   std::atomic<int> running_ops_;
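
Note on the exception_mu_ change above: exception_ is written by worker threads in RunOp and read/cleared by the coordinating loop in Run, so both sides now take the same mutex before touching it. The following standalone sketch illustrates that hand-off pattern; it is not PaddlePaddle code, and the names (ExceptionHolder, Catch, RethrowIfCaught) are illustrative assumptions only.

// Standalone sketch (not PaddlePaddle code): a mutex-guarded exception
// hand-off between worker threads and a coordinating thread, mirroring the
// exception_mu_ / exception_ pair added in this diff.
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <vector>

class ExceptionHolder {
 public:
  // Worker side: store the caught exception under the lock.
  void Catch(const std::runtime_error &e) {
    std::lock_guard<std::mutex> l(mu_);
    exception_.reset(new std::runtime_error(e));
  }

  // Coordinator side: if an exception was stored, clear it and rethrow a
  // copy, all while holding the same lock, so the pointer is never raced.
  void RethrowIfCaught() {
    std::lock_guard<std::mutex> l(mu_);
    if (exception_) {
      auto exp = *exception_;
      exception_.reset();
      throw exp;
    }
  }

 private:
  std::mutex mu_;  // plays the role of exception_mu_
  std::unique_ptr<std::runtime_error> exception_;
};

int main() {
  ExceptionHolder holder;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&holder, i] {
      try {
        if (i == 2) throw std::runtime_error("op failed on worker 2");
        // Normal work would signal its output variables here.
      } catch (const std::runtime_error &e) {
        holder.Catch(e);
      }
    });
  }
  for (auto &t : workers) t.join();

  // The run loop would perform this check when its ready-queue pop times out.
  try {
    holder.RethrowIfCaught();
  } catch (const std::runtime_error &e) {
    std::cout << "caught: " << e.what() << std::endl;
  }
  return 0;
}

In this sketch a later exception simply overwrites an earlier one, which matches the reset-with-new behavior in RunOp; the coordinating thread only needs to see that some failure occurred in order to abort the run.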