From 0c851cab2294356dd292b9b4458379d1bde4eadd Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 8 Jun 2018 15:02:33 +0800 Subject: [PATCH] add SSA graph checker --- paddle/fluid/framework/details/CMakeLists.txt | 3 +- .../framework/details/ssa_graph_builder.cc | 70 --------------- .../framework/details/ssa_graph_builder.h | 3 - .../details/ssa_graph_builder_factory.cc | 3 + .../framework/details/ssa_graph_checker.cc | 87 +++++++++++++++++++ .../framework/details/ssa_graph_checker.h | 44 ++++++++++ paddle/fluid/framework/parallel_executor.cc | 2 +- 7 files changed, 137 insertions(+), 75 deletions(-) create mode 100644 paddle/fluid/framework/details/ssa_graph_checker.cc create mode 100644 paddle/fluid/framework/details/ssa_graph_checker.h diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index ced063a0977..dbd118d3382 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -8,6 +8,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) +cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) @@ -30,7 +31,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) -cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) +cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index c4ee088507b..88a21f48879 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -83,76 +83,6 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { op->AddOutput(dummy_leaf); } } - -std::unique_ptr SSAGraphBuilder::BuildAndCheck( - const ProgramDesc &program) { - std::unique_ptr graph = Build(program); - PADDLE_ENFORCE(IsValidGraph(graph.get())); - return std::move(graph); -} - -bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const { - std::unordered_map pending_ops; - std::unordered_set pending_vars; - std::unordered_set ready_vars; - std::unordered_set ready_ops; - - auto insert_pending_var = [&](VarHandleBase *var) { - pending_vars.insert(var); - if (var->generated_op_ == nullptr) { - ready_vars.emplace(var); - } - }; - - for (auto &var_map : graph->vars_) { - for (auto &name_pair : var_map) { - for (auto &version_pair : name_pair.second) { - insert_pending_var(version_pair.get()); - } - } - } - - for (auto &var : graph->dep_vars_) { - insert_pending_var(var.get()); - } - - for (auto &op : graph->ops_) { - if (op->Inputs().empty()) { - ready_ops.insert(op.get()); - } else { - pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); - } - } - - auto run_all_ops = [&](std::unordered_set &set) { - for (auto *op : set) { - for (auto out : op->Outputs()) { - ready_vars.emplace(out); - } - } - set.clear(); - }; - - while (!pending_vars.empty()) { - run_all_ops(ready_ops); - - if (ready_vars.empty()) { - return false; - } - - for (auto ready_var : ready_vars) { - pending_vars.erase(ready_var); - for (auto *op : ready_var->pending_ops_) { - auto &deps = --pending_ops[op]; - if (deps == 0) { - ready_ops.insert(op); - } - } - } - ready_vars.clear(); - } - return true; -} } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index e99a9884075..5fc12a44b51 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -31,8 +31,6 @@ class SSAGraphBuilder { virtual ~SSAGraphBuilder() {} virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; - std::unique_ptr BuildAndCheck(const ProgramDesc &program); - DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: @@ -50,7 +48,6 @@ class SSAGraphBuilder { const platform::Place &place, size_t place_offset); - bool IsValidGraph(const SSAGraph *graph) const; // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc index b5e90d6b056..b4b49d3de6d 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/ssa_graph_checker.h" #include "paddle/fluid/framework/details/ssa_graph_printer.h" namespace paddle { @@ -40,6 +41,8 @@ std::unique_ptr SSAGraphBuilderFactory::Create() { res.reset(new SSAGraghBuilderWithPrinter( std::move(fout), std::move(graphviz_printer), std::move(res))); } + res.reset(new SSAGraghBuilderWithChecker(std::move(res))); + return res; } } // namespace details diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc new file mode 100644 index 00000000000..da5428946ee --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph.h" +#include +#include "paddle/fluid/framework/details/ssa_graph_checker.h" + +namespace paddle { +namespace framework { +namespace details { + +bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { + std::unordered_map pending_ops; + std::unordered_set pending_vars; + std::unordered_set ready_vars; + std::unordered_set ready_ops; + + auto insert_pending_var = [&](VarHandleBase *var) { + pending_vars.insert(var); + if (var->generated_op_ == nullptr) { + ready_vars.emplace(var); + } + }; + + for (auto &var_map : graph->vars_) { + for (auto &name_pair : var_map) { + for (auto &version_pair : name_pair.second) { + insert_pending_var(version_pair.get()); + } + } + } + + for (auto &var : graph->dep_vars_) { + insert_pending_var(var.get()); + } + + for (auto &op : graph->ops_) { + if (op->Inputs().empty()) { + ready_ops.insert(op.get()); + } else { + pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); + } + } + + auto run_all_ops = [&](std::unordered_set &set) { + for (auto *op : set) { + for (auto out : op->Outputs()) { + ready_vars.emplace(out); + } + } + set.clear(); + }; + + while (!pending_vars.empty()) { + run_all_ops(ready_ops); + + if (ready_vars.empty()) { + return false; + } + + for (auto ready_var : ready_vars) { + pending_vars.erase(ready_var); + for (auto *op : ready_var->pending_ops_) { + auto &deps = --pending_ops[op]; + if (deps == 0) { + ready_ops.insert(op); + } + } + } + ready_vars.clear(); + } + return true; +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h new file mode 100644 index 00000000000..542c4a17289 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + +namespace paddle { +namespace framework { +namespace details { +class SSAGraph; + +class SSAGraghBuilderWithChecker : public SSAGraphBuilder { + public: + explicit SSAGraghBuilderWithChecker( + std::unique_ptr&& builder) + : builder_(std::move(builder)) {} + + std::unique_ptr Build(const ProgramDesc& program) const override { + auto graph = builder_->Build(program); + PADDLE_ENFORCE(IsValidGraph(graph.get())); + return graph; + } + + bool IsValidGraph(const SSAGraph* graph) const; + + private: + std::unique_ptr builder_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f1ab3370709..5d95dc214ac 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -114,7 +114,7 @@ ParallelExecutor::ParallelExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, - builder_factory.Create()->BuildAndCheck(main_program))); + builder_factory.Create()->Build(main_program))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), -- GitLab