From 5c7a523326b98b9c4fee1eca0c0c74e3112bc19a Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 26 Mar 2018 11:50:52 +0800 Subject: [PATCH] Add Graphviz output --- .../details/computation_op_handle.cc | 2 + .../framework/details/computation_op_handle.h | 2 + .../framework/details/fetch_op_handle.cc | 2 + .../fluid/framework/details/fetch_op_handle.h | 2 + .../details/multi_devices_graph_builder.cc | 6 ++ .../details/nccl_all_reduce_op_handle.cc | 2 + .../details/nccl_all_reduce_op_handle.h | 2 + .../fluid/framework/details/op_handle_base.h | 2 + .../details/scale_loss_grad_op_handle.cc | 2 + .../details/scale_loss_grad_op_handle.h | 2 + .../framework/details/ssa_graph_builder.cc | 58 +++++++++++++++++++ .../framework/details/ssa_graph_builder.h | 2 + .../details/threaded_ssa_graph_executor.cc | 6 ++ .../tests/unittests/test_parallel_executor.py | 2 +- 14 files changed, 91 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 5867f8fc554..348b944cf92 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -35,6 +35,8 @@ void ComputationOpHandle::RunImpl() { op_->Run(*scope_, place_); } + +std::string ComputationOpHandle::Name() const { return op_->Type(); } } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index 1fbfd4eabe0..d6d2d731ca8 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -31,6 +31,8 @@ struct ComputationOpHandle : public OpHandleBase { ComputationOpHandle(const OpDesc &op_desc, Scope *scope, platform::Place place); + std::string Name() const override; + protected: void RunImpl() override; }; diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index ab552081a4a..c697a1c9378 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -72,6 +72,8 @@ void FetchOpHandle::RunImpl() { } } +std::string FetchOpHandle::Name() const { return "Fetch"; } + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h index 3123f7ba232..904b2d669f8 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.h +++ b/paddle/fluid/framework/details/fetch_op_handle.h @@ -38,6 +38,8 @@ struct FetchOpHandle : public OpHandleBase { void WaitAndMergeCPUTensors() const; + std::string Name() const override; + protected: void RunImpl() override; }; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index b27647a8eeb..cb02d36714d 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -136,6 +136,12 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( */ PolishGraphToSupportDataHazards(&result); + if (VLOG_IS_ON(10)) { + std::ostringstream sout; + PrintGraphviz(*graph, sout); + VLOG(10) << sout.str(); + } + return std::unique_ptr(graph); } } // namespace details diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index a79c61f3593..f2303ff4cab 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -69,6 +69,8 @@ void NCCLAllReduceOpHandle::RunImpl() { } } } + +std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; } } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h index 7152d1a587e..045070bb6a9 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h @@ -32,6 +32,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { const std::vector &places, const platform::NCCLContextMap &ctxs); + std::string Name() const override; + protected: void RunImpl() override; }; diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 5178b51d8d7..99d89684867 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -33,6 +33,8 @@ struct OpHandleBase { std::string DebugString() const; + virtual std::string Name() const = 0; + virtual ~OpHandleBase(); void Run(bool use_event); diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index 2e69f1e5e84..a6a67c9b145 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -45,6 +45,8 @@ void ScaleLossGradOpHandle::RunImpl() { #endif } } + +std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; } } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h index 3a355749192..ab7353a4fc5 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -32,6 +32,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase { ~ScaleLossGradOpHandle() final; + std::string Name() const override; + protected: void RunImpl() override; }; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 7a80a4b1e73..e0209fce76b 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -83,6 +83,64 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, var.place_ = place; op_handle->AddOutput(&var); } + +template +void IterAllVar(const SSAGraph &graph, Callback callback) { + for (auto &each : graph.vars_) { + for (auto &pair1 : each) { + for (auto &pair2 : pair1.second) { + callback(pair2.second); + } + } + } + + for (auto &var : graph.dep_vars_) { + callback(*var); + } +} + +void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) { + size_t var_id = 0; + std::unordered_map vars; + + sout << "digraph G {\n"; + + IterAllVar(graph, [&](const VarHandleBase &var) { + auto *var_ptr = &var; + auto *var_handle_ptr = dynamic_cast(var_ptr); + auto *dummy_ptr = dynamic_cast(var_ptr); + + size_t cur_var_id = var_id++; + vars[var_ptr] = cur_var_id; + + if (var_handle_ptr) { + sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_ + << "\\n" + << var_handle_ptr->place_ << "\\n" + << var_handle_ptr->version_ << "\"]" << std::endl; + } else if (dummy_ptr) { + sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl; + } + }); + + size_t op_id = 0; + for (auto &op : graph.ops_) { + std::string op_name = "op_" + std::to_string(op_id++); + sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" + << std::endl; + for (auto in : op->inputs_) { + std::string var_name = "var_" + std::to_string(vars[in]); + sout << var_name << " -> " << op_name << std::endl; + } + + for (auto out : op->outputs_) { + std::string var_name = "var_" + std::to_string(vars[out]); + sout << op_name << " -> " << var_name << std::endl; + } + } + + sout << "}\n"; +} } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index df05bb73942..bf20e7164a1 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -51,6 +51,8 @@ class SSAGraphBuilder { static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, const std::string &each_var_name, const platform::Place &place, size_t place_offset); + + static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout); }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 86e880ed72e..f609395d40f 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -133,6 +133,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( if (exception_) { throw * exception_; } + + VLOG(10) << "============================="; + for (auto &op : pending_ops) { + VLOG(10) << op.first->DebugString(); + } + // keep waiting the ready variables continue; } diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index 2ebdbaaca65..dd6e70eadbd 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -48,7 +48,7 @@ def fc_with_batchnorm(): dtypes=['float32', 'int64']) img, label = fluid.layers.read_file(reader) hidden = img - for _ in xrange(4): + for _ in xrange(1): hidden = fluid.layers.fc( hidden, size=200, -- GitLab