From 8291b916d6cf053db779598b01dd59191fb5a1df Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 7 Jun 2018 16:24:23 +0800 Subject: [PATCH] replace graph_builder_factory with ssa_graph_builder_factory --- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/details/CMakeLists.txt | 2 +- paddle/fluid/framework/details/multi_devices_graph_builder.cc | 1 + paddle/fluid/framework/details/ssa_graph_builder.cc | 4 ++-- paddle/fluid/framework/details/ssa_graph_builder.h | 2 +- ...{graph_builder_factory.cc => ssa_graph_builder_factory.cc} | 2 +- .../{graph_builder_factory.h => ssa_graph_builder_factory.h} | 0 paddle/fluid/framework/parallel_executor.cc | 4 ++-- 8 files changed, 9 insertions(+), 8 deletions(-) rename paddle/fluid/framework/details/{graph_builder_factory.cc => ssa_graph_builder_factory.cc} (96%) rename paddle/fluid/framework/details/{graph_builder_factory.h => ssa_graph_builder_factory.h} (100%) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 627370cd2d..4271e4c1bb 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method) -cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index c106761f72..ced063a097 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -30,7 +30,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) -cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) +cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 81d5b079b8..0c4d369e88 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -272,6 +272,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); + return std::unique_ptr(graph); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index d70f95a9f5..d24669a8f8 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -85,7 +85,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { } std::unique_ptr SSAGraphBuilder::BuildAndCheck( - const ProgramDesc &program) final { + const ProgramDesc &program) { std::unique_ptr graph = Build(program); PADDLE_ENFORCE(IsValidGraph(graph.get())); return std::move(graph); @@ -138,7 +138,7 @@ bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const { if (ready_vars.empty()) { return false; } - for (auto ready_var : ready_vars.) { + for (auto ready_var : ready_vars) { pending_vars.erase(ready_var); for (auto *op : ready_var->pending_ops_) { auto &deps = --pending_ops[op]; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index da9298ac8d..e99a988407 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -31,7 +31,7 @@ class SSAGraphBuilder { virtual ~SSAGraphBuilder() {} virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; - std::unique_ptr BuildAndCheck(const ProgramDesc &program) final; + std::unique_ptr BuildAndCheck(const ProgramDesc &program); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/details/graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc similarity index 96% rename from paddle/fluid/framework/details/graph_builder_factory.cc rename to paddle/fluid/framework/details/ssa_graph_builder_factory.cc index a04b9bb63c..b5e90d6b05 100644 --- a/paddle/fluid/framework/details/graph_builder_factory.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/graph_builder_factory.h" +#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_printer.h" diff --git a/paddle/fluid/framework/details/graph_builder_factory.h b/paddle/fluid/framework/details/ssa_graph_builder_factory.h similarity index 100% rename from paddle/fluid/framework/details/graph_builder_factory.h rename to paddle/fluid/framework/details/ssa_graph_builder_factory.h diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ce56f55e41..f1ab337070 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -22,8 +22,8 @@ limitations under the License. */ #include "paddle/fluid/platform/nccl_helper.h" #endif -#include "paddle/fluid/framework/details/graph_builder_factory.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -114,7 +114,7 @@ ParallelExecutor::ParallelExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, - builder_factory.Create()->Build(main_program))); + builder_factory.Create()->BuildAndCheck(main_program))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), -- GitLab