From 9ac785be396bd21d3f152a299f5fa7cb5e268e08 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 7 Jun 2018 15:40:58 +0800 Subject: [PATCH] check graph's validation --- .../details/multi_devices_graph_builder.cc | 1 - .../framework/details/ssa_graph_builder.cc | 70 ++++++++++++++++++- .../framework/details/ssa_graph_builder.h | 3 + .../details/threaded_ssa_graph_executor.cc | 1 + 4 files changed, 73 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 0c4d369e8..81d5b079b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -272,7 +272,6 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); - return std::unique_ptr(graph); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 211113c79..d70f95a9f 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -11,8 +11,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include namespace paddle { namespace framework { @@ -83,6 +83,74 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { op->AddOutput(dummy_leaf); } } + +std::unique_ptr SSAGraphBuilder::BuildAndCheck( + const ProgramDesc &program) final { + std::unique_ptr graph = Build(program); + PADDLE_ENFORCE(IsValidGraph(graph.get())); + return std::move(graph); +} + +bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const { + std::unordered_map pending_ops; + std::unordered_set pending_vars; + std::unordered_set ready_vars; + std::unordered_set ready_ops; + + auto insert_pending_var = [&](VarHandleBase *var) { + pending_vars.insert(var); + if (var->generated_op_ == nullptr) { + ready_vars.emplace(var); + } + }; + + for (auto &var_map : graph->vars_) { + for (auto &name_pair : var_map) { + for (auto &version_pair : name_pair.second) { + insert_pending_var(version_pair.get()); + } + } + } + + for (auto &var : graph->dep_vars_) { + insert_pending_var(var.get()); + } + + for (auto &op : graph->ops_) { + if (op->Inputs().empty()) { + ready_ops.insert(op.get()); + } else { + pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); + } + } + + auto run_all_ops = [&](std::unordered_set &set) { + for (auto *op : set) { + for (auto out : op->Outputs()) { + ready_vars.emplace(out); + } + } + set.clear(); + }; + + while (!pending_vars.empty()) { + run_all_ops(ready_ops); + if (ready_vars.empty()) { + return false; + } + for (auto ready_var : ready_vars.) { + pending_vars.erase(ready_var); + for (auto *op : ready_var->pending_ops_) { + auto &deps = --pending_ops[op]; + if (deps == 0) { + ready_ops.insert(op); + } + } + } + ready_vars.clear(); + } + return true; +} } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 5fc12a44b..da9298ac8 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -31,6 +31,8 @@ class SSAGraphBuilder { virtual ~SSAGraphBuilder() {} virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; + std::unique_ptr BuildAndCheck(const ProgramDesc &program) final; + DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: @@ -48,6 +50,7 @@ class SSAGraphBuilder { const platform::Place &place, size_t place_offset); + bool IsValidGraph(const SSAGraph *graph) const; // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 496fadd04..bcbf57362 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( ready_vars->Push(var); } } + void ThreadedSSAGraphExecutor::RunOp( BlockingQueue *ready_var_q, details::OpHandleBase *op) { auto op_run = [ready_var_q, op, this] { -- GitLab