diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 627370cd2df7317b4d32aa967565aaf9cf0c7a08..4271e4c1bb6bc7b83f2633191ea2d464f4f56c4c 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method) -cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index c106761f72e689ff53867ecad8e36b6038173d0e..ced063a0977cb9d04a32ef3a97d62d7865d179bd 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -30,7 +30,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) -cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) +cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 81d5b079b86a5652496812caa3e9c4ab2a989f7f..0c4d369e889cf2cca7722dac14a5268fdacabeb4 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -272,6 +272,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); + return std::unique_ptr(graph); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index d70f95a9f5da1a9a018f6fd89a45bb0bcdec9ad1..d24669a8f8dfd5801239c036619531b4c16433d8 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -85,7 +85,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { } std::unique_ptr SSAGraphBuilder::BuildAndCheck( - const ProgramDesc &program) final { + const ProgramDesc &program) { std::unique_ptr graph = Build(program); PADDLE_ENFORCE(IsValidGraph(graph.get())); return std::move(graph); @@ -138,7 +138,7 @@ bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const { if (ready_vars.empty()) { return false; } - for (auto ready_var : ready_vars.) { + for (auto ready_var : ready_vars) { pending_vars.erase(ready_var); for (auto *op : ready_var->pending_ops_) { auto &deps = --pending_ops[op]; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index da9298ac8d0fb27ffd0a0bbed3cec55d0bae3e87..e99a98840756efcaf74f0a6a741cc0cce9fae939 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -31,7 +31,7 @@ class SSAGraphBuilder { virtual ~SSAGraphBuilder() {} virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; - std::unique_ptr BuildAndCheck(const ProgramDesc &program) final; + std::unique_ptr BuildAndCheck(const ProgramDesc &program); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/details/graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc similarity index 96% rename from paddle/fluid/framework/details/graph_builder_factory.cc rename to paddle/fluid/framework/details/ssa_graph_builder_factory.cc index a04b9bb63c06b40ff5c30c9792cdfad5d64d404c..b5e90d6b056c5cc0fccecc7f8d6dd4849f71a075 100644 --- a/paddle/fluid/framework/details/graph_builder_factory.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/graph_builder_factory.h" +#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_printer.h" diff --git a/paddle/fluid/framework/details/graph_builder_factory.h b/paddle/fluid/framework/details/ssa_graph_builder_factory.h similarity index 100% rename from paddle/fluid/framework/details/graph_builder_factory.h rename to paddle/fluid/framework/details/ssa_graph_builder_factory.h diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ce56f55e4195a0625cd0754152285b80e4282183..f1ab337070976fbdb8dcac1a61e168e17036d990 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -22,8 +22,8 @@ limitations under the License. */ #include "paddle/fluid/platform/nccl_helper.h" #endif -#include "paddle/fluid/framework/details/graph_builder_factory.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -114,7 +114,7 @@ ParallelExecutor::ParallelExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, - builder_factory.Create()->Build(main_program))); + builder_factory.Create()->BuildAndCheck(main_program))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos),