diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6d8cbe5d9e491555a94e9420995149041213ab79..69c6dd02005902d9ee486a0b7f8f0c9a343be255 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -150,11 +150,10 @@ else() endif() if (NOT WIN32) - cc_library(parallel_executor SRCS parallel_executor.cc DEPS - threaded_ssa_graph_executor scope_buffered_ssa_graph_executor - graph graph_viz_pass multi_devices_graph_pass - multi_devices_graph_print_pass multi_devices_graph_check_pass - fast_threaded_ssa_graph_executor fuse_elewise_add_act_pass) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS + threaded_ssa_graph_executor scope_buffered_ssa_graph_executor + graph build_strategy + fast_threaded_ssa_graph_executor) endif() # NOT WIN32 cc_library(prune SRCS prune.cc DEPS framework_proto) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index a8e0c4a3fedfd56e38de7568be6b3f2e76a4b25f..e0a3ef5a9c6c53c42ebea1a41cac0d18a77781b2 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -54,3 +54,8 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu # device_context reduce_op_handle ) cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) + +cc_library(build_strategy SRCS build_strategy.cc DEPS + graph_viz_pass multi_devices_graph_pass + multi_devices_graph_print_pass multi_devices_graph_check_pass + fuse_elewise_add_act_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a6b497fa897e3882995688bf36704b1d77ea962 --- /dev/null +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -0,0 +1,126 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/details/build_strategy.h" + +#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" +#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_viz_pass.h" + +namespace paddle { +namespace framework { +namespace details { + +class ParallelExecutorPassBuilder : public ir::PassBuilder { + public: + explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) + : ir::PassBuilder(), strategy_(strategy) { + // Add a graph viz pass to record a graph. + if (!strategy_.debug_graphviz_path_.empty()) { + auto viz_pass = AppendPass("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph"); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + } + + // Add op fusion. + if (strategy.fuse_elewise_add_act_ops_) { + auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass"); + // Add a graph viz pass to record a graph. + if (!strategy.debug_graphviz_path_.empty()) { + auto viz_pass = AppendPass("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph"); + viz_pass->Set("graph_viz_path", + new std::string(graph_path)); + } + } + + // Convert graph to run on multi-devices. + auto multi_devices_pass = AppendPass("multi_devices_pass"); + multi_devices_pass->SetNotOwned("strategy", + &strategy_); + + // Add a graph print pass to record a graph with device info. + if (!strategy_.debug_graphviz_path_.empty()) { + auto multi_devices_print_pass = AppendPass("multi_devices_print_pass"); + multi_devices_print_pass->SetNotOwned( + "debug_graphviz_path", &strategy_.debug_graphviz_path_); + multi_devices_print_pass->Set( + "graph_printer", new details::GraphvizSSAGraphPrinter); + } + + // Verify that the graph is correct for multi-device executor. + AppendPass("multi_devices_check_pass"); + } + + private: + BuildStrategy strategy_; +}; + +std::shared_ptr BuildStrategy::CreatePassesFromStrategy() + const { + pass_builder_.reset(new ParallelExecutorPassBuilder(*this)); + return pass_builder_; +} + +std::unique_ptr BuildStrategy::Apply( + const ProgramDesc &main_program, const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶m_names, + const std::vector &local_scopes, +#ifdef PADDLE_WITH_CUDA + const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const { +#else + const bool use_cuda) const { +#endif + // Create a default one if not initialized by user. + if (!pass_builder_) { + CreatePassesFromStrategy(); + } + + std::unique_ptr graph(new ir::Graph(main_program)); + + for (std::shared_ptr &pass : pass_builder_->AllPasses()) { + if (pass->Type() == "multi_devices_pass") { + pass->Erase("places"); + pass->SetNotOwned>("places", &places); + pass->Erase("loss_var_name"); + pass->SetNotOwned("loss_var_name", &loss_var_name); + pass->Erase("params"); + pass->SetNotOwned>("params", + ¶m_names); + pass->Erase("local_scopes"); + pass->SetNotOwned>("local_scopes", + &local_scopes); +#ifdef PADDLE_WITH_CUDA + platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; + pass->Erase("nccl_ctxs"); + pass->SetNotOwned("nccl_ctxs", nctx); +#endif + } + graph = pass->Apply(std::move(graph)); + } + return graph; +} +} // namespace details +} // namespace framework +} // namespace paddle + +USE_PASS(fuse_elewise_add_act_pass); +USE_PASS(graph_viz_pass); +USE_PASS(multi_devices_pass); +USE_PASS(multi_devices_check_pass); +USE_PASS(multi_devices_print_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 77cafa49f18158ea4187908856e10adaa0f916c9..02c4bea16916d58a6d0fce8918f8fceb9ff9356e 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -15,6 +15,17 @@ #pragma once #include +#include + +#include "paddle/fluid/framework/ir/pass_builder.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif namespace paddle { namespace framework { @@ -57,6 +68,30 @@ struct BuildStrategy { bool fuse_elewise_add_act_ops_{false}; bool enable_data_balance_{false}; + + // User normally doesn't need to call this API. + // The PassBuilder allows for more customized insert, remove of passes + // from python side. + // A new PassBuilder is created based on configs defined above and + // passes are owned by the PassBuilder. + std::shared_ptr CreatePassesFromStrategy() const; + + // Apply the passes built by the pass_builder_. The passes will be + // applied to the Program and output an ir::Graph. + std::unique_ptr Apply( + const ProgramDesc &main_program, + const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶m_names, + const std::vector &local_scopes, +#ifdef PADDLE_WITH_CUDA + const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const; +#else + const bool use_cuda) const; +#endif + + private: + mutable std::shared_ptr pass_builder_; }; } // namespace details diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 4dca3ceb4569fb708c7a98621c5239acbe217586..9796f2778955ce47fe9fd1b13baa36404b1574c5 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -41,6 +41,8 @@ cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") +cc_library(pass_builder SRCS pass_builder.cc DEPS pass) + cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index d7158eba62686be57499df697466797e4034ea8f..6cf405efe63d2bc284c4650771a747b27bb4a9f6 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -19,7 +19,6 @@ namespace paddle { namespace framework { namespace ir { std::unique_ptr Pass::Apply(std::unique_ptr graph) const { - PADDLE_ENFORCE(!applied_, "Pass can only Apply() once."); PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); for (const std::string& attr : required_pass_attrs_) { PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 0f14083d259172f5b5f1ed80c7d38312d711beb5..9570c59cff2a6afeb1c607f7219b7b455974d6ce 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -42,6 +42,8 @@ class Pass { attr_dels_.clear(); } + std::string Type() const { return type_; } + std::unique_ptr Apply(std::unique_ptr graph) const; // Get a reference to the attributed previously set. @@ -52,6 +54,21 @@ class Pass { return *boost::any_cast(attrs_.at(attr_name)); } + bool Has(const std::string &attr_name) const { + return attrs_.find(attr_name) != attrs_.end(); + } + + void Erase(const std::string &attr_name) { + if (!Has(attr_name)) { + return; + } + if (attr_dels_.find(attr_name) != attr_dels_.end()) { + attr_dels_[attr_name](); + attr_dels_.erase(attr_name); + } + attrs_.erase(attr_name); + } + // Set a pointer to the attribute. Pass takes ownership of the attribute. template void Set(const std::string &attr_name, AttrType *attr) { @@ -68,13 +85,15 @@ class Pass { // should delete the attribute. template void SetNotOwned(const std::string &attr_name, AttrType *attr) { - PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass", + attr_name); attrs_[attr_name] = attr; } protected: - virtual std::unique_ptr ApplyImpl( - std::unique_ptr graph) const = 0; + virtual std::unique_ptr ApplyImpl(std::unique_ptr graph) const { + LOG(FATAL) << "Calling virtual Pass not implemented."; + } private: template @@ -89,7 +108,10 @@ class Pass { required_graph_attrs_.insert(attrs.begin(), attrs.end()); } + void RegisterType(const std::string &type) { type_ = type; } + mutable bool applied_{false}; + std::string type_; std::unordered_set required_pass_attrs_; std::unordered_set required_graph_attrs_; std::map attrs_; @@ -143,10 +165,11 @@ struct PassRegistrar : public Registrar { PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type), "'%s' is registered more than once.", pass_type); PassRegistry::Instance().Insert( - pass_type, [this]() -> std::unique_ptr { + pass_type, [this, pass_type]() -> std::unique_ptr { std::unique_ptr pass(new PassType()); pass->RegisterRequiredPassAttrs(this->required_pass_attrs_); pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_); + pass->RegisterType(pass_type); return pass; }); } diff --git a/paddle/fluid/framework/ir/pass_builder.cc b/paddle/fluid/framework/ir/pass_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..e0719867b34d13666672b22070ce14dbaf80d85d --- /dev/null +++ b/paddle/fluid/framework/ir/pass_builder.cc @@ -0,0 +1,43 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/pass_builder.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::shared_ptr PassBuilder::AppendPass(const std::string& pass_type) { + auto pass = ir::PassRegistry::Instance().Get(pass_type); + passes_.emplace_back(pass.release()); + return passes_.back(); +} + +void PassBuilder::RemovePass(size_t idx) { + PADDLE_ENFORCE(passes_.size() > idx); + passes_.erase(passes_.begin() + idx); +} + +std::shared_ptr PassBuilder::InsertPass(size_t idx, + const std::string& pass_type) { + PADDLE_ENFORCE(passes_.size() >= idx); + std::shared_ptr pass( + ir::PassRegistry::Instance().Get(pass_type).release()); + passes_.insert(passes_.begin() + idx, std::move(pass)); + return passes_[idx]; +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass_builder.h b/paddle/fluid/framework/ir/pass_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..733d3a3ad1ab8989ea30fe45cd7e1ffe9432de13 --- /dev/null +++ b/paddle/fluid/framework/ir/pass_builder.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class PassBuilder { + public: + PassBuilder() {} + + virtual ~PassBuilder() {} + + // Append a new pass to the end. + std::shared_ptr AppendPass(const std::string& pass_type); + + // Insert a new pass after `idx`. + std::shared_ptr InsertPass(size_t idx, const std::string& pass_type); + + // Remove a new pass at `idx`. + void RemovePass(size_t idx); + + // Returns a list of all passes. + std::vector> AllPasses() const { return passes_; } + + protected: + std::vector> passes_; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc index 5b5011412ed39e033a7a65921e9c64ce2d54c638..6ad7d1df8bdd016b617c820c022ef55f23ba21cd 100644 --- a/paddle/fluid/framework/ir/pass_test.cc +++ b/paddle/fluid/framework/ir/pass_test.cc @@ -82,12 +82,10 @@ TEST(PassTest, TestPassAttrCheck) { ASSERT_EQ(graph->Get("copy_test_pass_attr"), 2); ASSERT_EQ(graph->Get("copy_test_graph_attr"), 2); - try { - graph = pass->Apply(std::move(graph)); - } catch (paddle::platform::EnforceNotMet e) { - exception = std::string(e.what()); - } - ASSERT_TRUE(exception.find("Pass can only Apply() once") != exception.npos); + // Allow apply more than once. + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + graph = pass->Apply(std::move(graph)); pass = PassRegistry::Instance().Get("test_pass"); pass->SetNotOwned("test_pass_attr", &val); diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f5a54c0f48c98512dcb393d63a61f3c927e7ac1f..855870b41c9629aaae88f009ffa908a91b25a931 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -19,15 +19,13 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_viz_pass.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" #endif #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" -#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" -#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -35,80 +33,6 @@ limitations under the License. */ namespace paddle { namespace framework { -std::unique_ptr ApplyParallelExecutorPass( - const ProgramDesc &main_program, const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶m_names, - const std::vector &local_scopes, const bool use_cuda, -#ifdef PADDLE_WITH_CUDA - const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) { -#else - const BuildStrategy &strategy) { -#endif - // Convert the program to graph. - std::unique_ptr graph(new ir::Graph(main_program)); - - // Apply a graph viz pass to record a graph. - if (!strategy.debug_graphviz_path_.empty()) { - auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); - const std::string graph_path = string::Sprintf( - "%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph"); - viz_pass->Set("graph_viz_path", new std::string(graph_path)); - graph = viz_pass->Apply(std::move(graph)); - } - - // Apply op fusion. - if (strategy.fuse_elewise_add_act_ops_) { - auto fuse_elewise_add_act_pass = - ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass"); - graph = fuse_elewise_add_act_pass->Apply(std::move(graph)); - // Apply a graph viz pass to record a graph. - if (!strategy.debug_graphviz_path_.empty()) { - auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); - const std::string graph_path = string::Sprintf( - "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph"); - viz_pass->Set("graph_viz_path", new std::string(graph_path)); - graph = viz_pass->Apply(std::move(graph)); - } - } - - // Convert graph to run on multi-devices. - auto multi_devices_pass = - ir::PassRegistry::Instance().Get("multi_devices_pass"); - multi_devices_pass->SetNotOwned>("places", - &places); - multi_devices_pass->SetNotOwned("loss_var_name", - &loss_var_name); - multi_devices_pass->SetNotOwned>( - "params", ¶m_names); - multi_devices_pass->SetNotOwned>("local_scopes", - &local_scopes); - multi_devices_pass->SetNotOwned("strategy", &strategy); - -#ifdef PADDLE_WITH_CUDA - platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; - multi_devices_pass->SetNotOwned("nccl_ctxs", nctx); -#endif - graph = multi_devices_pass->Apply(std::move(graph)); - - // Apply a graph print pass to record a graph with device info. - if (!strategy.debug_graphviz_path_.empty()) { - auto multi_devices_print_pass = - ir::PassRegistry::Instance().Get("multi_devices_print_pass"); - multi_devices_print_pass->SetNotOwned( - "debug_graphviz_path", &strategy.debug_graphviz_path_); - multi_devices_print_pass->Set( - "graph_printer", new details::GraphvizSSAGraphPrinter); - graph = multi_devices_print_pass->Apply(std::move(graph)); - } - - // Verify that the graph is correct for multi-device executor. - auto multi_devices_check_pass = - ir::PassRegistry::Instance().Get("multi_devices_check_pass"); - graph = multi_devices_check_pass->Apply(std::move(graph)); - return graph; -} - class ParallelExecutorPrivate { public: explicit ParallelExecutorPrivate(const std::vector &places) @@ -199,10 +123,9 @@ ParallelExecutor::ParallelExecutor( // Step 3. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp #ifdef PADDLE_WITH_CUDA - std::unique_ptr graph = ApplyParallelExecutorPass( + std::unique_ptr graph = build_strategy.Apply( main_program, member_->places_, loss_var_name, params, - member_->local_scopes_, member_->use_cuda_, build_strategy, - member_->nccl_ctxs_.get()); + member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get()); auto max_memory_size = GetEagerDeletionThreshold(); if (max_memory_size >= 0) { @@ -228,9 +151,9 @@ ParallelExecutor::ParallelExecutor( } } #else - std::unique_ptr graph = ApplyParallelExecutorPass( - main_program, member_->places_, loss_var_name, params, - member_->local_scopes_, member_->use_cuda_, build_strategy); + std::unique_ptr graph = + build_strategy.Apply(main_program, member_->places_, loss_var_name, + params, member_->local_scopes_, member_->use_cuda_); #endif if (exec_strategy.type_ == ExecutionStrategy::kDefault) { @@ -373,12 +296,6 @@ ParallelExecutor::~ParallelExecutor() { } // namespace framework } // namespace paddle - -USE_PASS(fuse_elewise_add_act_pass); -USE_PASS(graph_viz_pass); -USE_PASS(multi_devices_pass); -USE_PASS(multi_devices_check_pass); -USE_PASS(multi_devices_print_pass); #ifdef PADDLE_WITH_CUDA USE_PASS(reference_count_pass); #endif diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index c64906ff230df5f2b7cc9f5c6b29d68956ab8f33..fd386a5987f11ff64964e95eb7e9b83572dc790c 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -14,14 +14,14 @@ limitations under the License. */ #pragma once -#include #include #include #include #include #include + +#include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h" -#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index b5bd07d401f9ebfe441bc0f84f9bad317f0e8da9..e7f634c4a622b48e97040987836406cf73cb23b6 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,5 +1,5 @@ -set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method) +set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder) set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc) if(NOT WIN32) list(APPEND PYBIND_DEPS parallel_executor profiler) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8b62502e3f920a3bf7d80f9e7edc2f3647a0e5b1..ef2f1f2a20a2eddee8ac077ee4bbf4dfd777448d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" @@ -595,6 +596,29 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_profiler_enabled", platform::IsProfileEnabled); m.def("reset_profiler", platform::ResetProfiler); + py::class_> pass(m, "Pass"); + pass.def(py::init()) + .def("set_str", [](ir::Pass &self, const std::string &name, + const std::string &attr) { + self.Set(name, new std::string(attr)); + }); + + py::class_> pb( + m, "PassBuilder"); + pb.def(py::init()) + .def("append_pass", + [](ir::PassBuilder &self, + const std::string &pass_type) -> std::shared_ptr { + return self.AppendPass(pass_type); + }) + .def("all_passes", [](ir::PassBuilder &self) { return self.AllPasses(); }) + .def("insert_pass", + [](ir::PassBuilder &self, size_t idx, const std::string &pass_type) { + return self.InsertPass(idx, pass_type); + }) + .def("remove_pass", + [](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); }); + // -- python binds for parallel executor. py::class_ pe(m, "ParallelExecutor"); py::class_ exec_strategy(pe, "ExecutionStrategy"); @@ -677,7 +701,11 @@ All parameter, weight, gradient are variables in Paddle. }, [](BuildStrategy &self, bool b) { self.fuse_elewise_add_act_ops_ = b; - }); + }) + .def("_create_passes_from_strategy", + [](BuildStrategy &self) -> std::shared_ptr { + return self.CreatePassesFromStrategy(); + }); pe.def(py::init &, const std::unordered_set &, diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..288c5f6a1f6b1760ca40c0c653e4c0726b799519 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py @@ -0,0 +1,121 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np +import unittest +import os +import sys +import math + + +def simple_fc_net(): + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + hidden = img + for _ in range(4): + hidden = fluid.layers.fc( + hidden, + size=200, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) + return loss + + +class TestPassBuilder(unittest.TestCase): + def check_network_convergence(self, use_cuda, build_strategy=None): + os.environ['CPU_NUM'] = str(4) + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = simple_fc_net() + test_program = main.clone(for_test=True) + + opt = fluid.optimizer.SGD(learning_rate=0.001) + opt.minimize(loss) + + batch_size = 32 + image = np.random.normal(size=(batch_size, 784)).astype('float32') + label = np.random.randint(0, 10, (batch_size, 1), dtype="int64") + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup) + feed_dict = {'image': image, 'label': label} + + train_exe = fluid.ParallelExecutor( + use_cuda=use_cuda, + loss_name=loss.name, + main_program=main, + build_strategy=build_strategy) + + test_exe = fluid.ParallelExecutor( + use_cuda=use_cuda, + main_program=test_program, + share_vars_from=train_exe, + build_strategy=build_strategy) + + for i in range(5): + test_loss, = test_exe.run([loss.name], feed=feed_dict) + + train_loss, = train_exe.run([loss.name], feed=feed_dict) + + avg_test_loss_val = np.array(test_loss).mean() + if math.isnan(float(avg_test_loss_val)): + sys.exit("got NaN loss, testing failed.") + + avg_train_loss_val = np.array(train_loss).mean() + if math.isnan(float(avg_train_loss_val)): + sys.exit("got NaN loss, training failed.") + + self.assertTrue( + np.allclose( + train_loss, test_loss, atol=1e-8), + "Train loss: " + str(train_loss) + "\n Test loss:" + + str(test_loss)) + + def test_parallel_testing_with_new_strategy(self): + build_strategy = fluid.BuildStrategy() + pass_builder = build_strategy._create_passes_from_strategy() + origin_len = len(pass_builder.all_passes()) + + viz_pass = pass_builder.append_pass("graph_viz_pass") + self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) + + pass_builder.insert_pass( + len(pass_builder.all_passes()), "graph_viz_pass") + self.assertEqual(origin_len + 2, len(pass_builder.all_passes())) + + pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) + self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) + viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass") + + self.check_network_convergence( + use_cuda=core.is_compiled_with_cuda(), + build_strategy=build_strategy) + try: + os.stat("/tmp/test_viz_pass") + except os.error: + self.assertFalse(True) + + +if __name__ == '__main__': + unittest.main()