diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 627370cd2df7317b4d32aa967565aaf9cf0c7a08..4271e4c1bb6bc7b83f2633191ea2d464f4f56c4c 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method) -cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index c43826b64cc5140c539df17fdd13d9bee7fefdcd..3ba9c1bba1812f2363aaca3d1b2f9eb1fa411c7a 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -8,6 +8,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) +cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) @@ -31,7 +32,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) -cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) +cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 211113c7979ee95d896c0a57879f7b3ad13b36ef..88a21f48879a15450051ad94ed76e1c48bf23014 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -11,8 +11,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/details/graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc similarity index 90% rename from paddle/fluid/framework/details/graph_builder_factory.cc rename to paddle/fluid/framework/details/ssa_graph_builder_factory.cc index a04b9bb63c06b40ff5c30c9792cdfad5d64d404c..b4b49d3de6da2e5fd7836668619e42d10bb6b35a 100644 --- a/paddle/fluid/framework/details/graph_builder_factory.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/graph_builder_factory.h" +#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/ssa_graph_checker.h" #include "paddle/fluid/framework/details/ssa_graph_printer.h" namespace paddle { @@ -40,6 +41,8 @@ std::unique_ptr SSAGraphBuilderFactory::Create() { res.reset(new SSAGraghBuilderWithPrinter( std::move(fout), std::move(graphviz_printer), std::move(res))); } + res.reset(new SSAGraghBuilderWithChecker(std::move(res))); + return res; } } // namespace details diff --git a/paddle/fluid/framework/details/graph_builder_factory.h b/paddle/fluid/framework/details/ssa_graph_builder_factory.h similarity index 100% rename from paddle/fluid/framework/details/graph_builder_factory.h rename to paddle/fluid/framework/details/ssa_graph_builder_factory.h diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc new file mode 100644 index 0000000000000000000000000000000000000000..da5428946ee588e8eac1f78929dc0432df532975 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph.h" +#include +#include "paddle/fluid/framework/details/ssa_graph_checker.h" + +namespace paddle { +namespace framework { +namespace details { + +bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { + std::unordered_map pending_ops; + std::unordered_set pending_vars; + std::unordered_set ready_vars; + std::unordered_set ready_ops; + + auto insert_pending_var = [&](VarHandleBase *var) { + pending_vars.insert(var); + if (var->generated_op_ == nullptr) { + ready_vars.emplace(var); + } + }; + + for (auto &var_map : graph->vars_) { + for (auto &name_pair : var_map) { + for (auto &version_pair : name_pair.second) { + insert_pending_var(version_pair.get()); + } + } + } + + for (auto &var : graph->dep_vars_) { + insert_pending_var(var.get()); + } + + for (auto &op : graph->ops_) { + if (op->Inputs().empty()) { + ready_ops.insert(op.get()); + } else { + pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); + } + } + + auto run_all_ops = [&](std::unordered_set &set) { + for (auto *op : set) { + for (auto out : op->Outputs()) { + ready_vars.emplace(out); + } + } + set.clear(); + }; + + while (!pending_vars.empty()) { + run_all_ops(ready_ops); + + if (ready_vars.empty()) { + return false; + } + + for (auto ready_var : ready_vars) { + pending_vars.erase(ready_var); + for (auto *op : ready_var->pending_ops_) { + auto &deps = --pending_ops[op]; + if (deps == 0) { + ready_ops.insert(op); + } + } + } + ready_vars.clear(); + } + return true; +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h new file mode 100644 index 0000000000000000000000000000000000000000..542c4a172891ba9d3621918986089b2e400b6ae8 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + +namespace paddle { +namespace framework { +namespace details { +class SSAGraph; + +class SSAGraghBuilderWithChecker : public SSAGraphBuilder { + public: + explicit SSAGraghBuilderWithChecker( + std::unique_ptr&& builder) + : builder_(std::move(builder)) {} + + std::unique_ptr Build(const ProgramDesc& program) const override { + auto graph = builder_->Build(program); + PADDLE_ENFORCE(IsValidGraph(graph.get())); + return graph; + } + + bool IsValidGraph(const SSAGraph* graph) const; + + private: + std::unique_ptr builder_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 496fadd04dac982b87b9d9e14f599ed37d9709d0..bcbf5736267f0d760d14c96784f0994c6bd013ac 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( ready_vars->Push(var); } } + void ThreadedSSAGraphExecutor::RunOp( BlockingQueue *ready_var_q, details::OpHandleBase *op) { auto op_run = [ready_var_q, op, this] { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ce56f55e4195a0625cd0754152285b80e4282183..5d95dc214ac39117a2ec4674da7f3bd50fa6d3d0 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -22,8 +22,8 @@ limitations under the License. */ #include "paddle/fluid/platform/nccl_helper.h" #endif -#include "paddle/fluid/framework/details/graph_builder_factory.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h"