From 36c2a9af27da71524aae97899f82c9e5847320e4 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 17 Sep 2018 09:55:04 +0800 Subject: [PATCH] pass builder allow cutomize pass in python. --- paddle/fluid/framework/CMakeLists.txt | 9 +- paddle/fluid/framework/details/CMakeLists.txt | 5 + .../fluid/framework/details/build_strategy.cc | 150 ++++++++++++++++++ .../fluid/framework/details/build_strategy.h | 32 ++++ paddle/fluid/framework/ir/CMakeLists.txt | 2 + paddle/fluid/framework/ir/pass.cc | 1 - paddle/fluid/framework/ir/pass.h | 14 +- paddle/fluid/framework/ir/pass_builder.cc | 43 +++++ paddle/fluid/framework/ir/pass_builder.h | 45 ++++++ paddle/fluid/framework/parallel_executor.cc | 95 +---------- paddle/fluid/framework/parallel_executor.h | 4 +- paddle/fluid/pybind/CMakeLists.txt | 2 +- paddle/fluid/pybind/pybind.cc | 28 +++- .../tests/unittests/test_pass_builder.py | 110 +++++++++++++ 14 files changed, 437 insertions(+), 103 deletions(-) create mode 100644 paddle/fluid/framework/details/build_strategy.cc create mode 100644 paddle/fluid/framework/ir/pass_builder.cc create mode 100644 paddle/fluid/framework/ir/pass_builder.h create mode 100644 python/paddle/fluid/tests/unittests/test_pass_builder.py diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6d8cbe5d9e4..69c6dd02005 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -150,11 +150,10 @@ else() endif() if (NOT WIN32) - cc_library(parallel_executor SRCS parallel_executor.cc DEPS - threaded_ssa_graph_executor scope_buffered_ssa_graph_executor - graph graph_viz_pass multi_devices_graph_pass - multi_devices_graph_print_pass multi_devices_graph_check_pass - fast_threaded_ssa_graph_executor fuse_elewise_add_act_pass) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS + threaded_ssa_graph_executor scope_buffered_ssa_graph_executor + graph build_strategy + fast_threaded_ssa_graph_executor) endif() # NOT WIN32 cc_library(prune SRCS prune.cc DEPS framework_proto) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index a8e0c4a3fed..0cf11bc9abd 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -54,3 +54,8 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu # device_context reduce_op_handle ) cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) + +cc_library(build_strategy SRCS build_strategy.cc DEPS + graph_viz_pass multi_devices_graph_pass + multi_devices_graph_print_pass multi_devices_graph_check_pass, + fuse_elewise_add_act_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc new file mode 100644 index 00000000000..2a3bc85ff79 --- /dev/null +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -0,0 +1,150 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/details/build_strategy.h" + +#include +#include + +#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" +#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_viz_pass.h" + +namespace paddle { +namespace framework { +namespace details { + +class ParallelExecutorPassBuilder : public ir::PassBuilder { + public: + explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) + : ir::PassBuilder(), strategy_(strategy) { + // Apply a graph viz pass to record a graph. + if (!strategy_.debug_graphviz_path_.empty()) { + auto viz_pass = AppendPass("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph"); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + } + + // Apply op fusion. + if (strategy.fuse_elewise_add_act_ops_) { + auto fuse_elewise_add_act_pass = + ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass"); + graph = fuse_elewise_add_act_pass->Apply(std::move(graph)); + // Apply a graph viz pass to record a graph. + if (!strategy.debug_graphviz_path_.empty()) { + auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph"); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + graph = viz_pass->Apply(std::move(graph)); + } + } + + // Convert graph to run on multi-devices. + auto multi_devices_pass = AppendPass("multi_devices_pass"); + multi_devices_pass->SetNotOwned("strategy", + &strategy_); + + // Apply a graph print pass to record a graph with device info. + if (!strategy_.debug_graphviz_path_.empty()) { + auto multi_devices_print_pass = AppendPass("multi_devices_print_pass"); + multi_devices_print_pass->SetNotOwned( + "debug_graphviz_path", &strategy_.debug_graphviz_path_); + multi_devices_print_pass->Set( + "graph_printer", new details::GraphvizSSAGraphPrinter); + } + + // Verify that the graph is correct for multi-device executor. + AppendPass("multi_devices_check_pass"); + } + + std::unique_ptr Build( + const ProgramDesc &main_program, + const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶m_names, + const std::vector &local_scopes, +#ifdef PADDLE_WITH_CUDA + const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const { +#else + const bool use_cuda) const { +#endif + // Convert the program to graph. + std::unique_ptr graph(new ir::Graph(main_program)); + + for (std::shared_ptr &pass : AllPasses()) { + if (pass->Type() == "multi_devices_pass") { + pass->SetNotOwned>("places", + &places); + pass->SetNotOwned("loss_var_name", &loss_var_name); + pass->SetNotOwned>("params", + ¶m_names); + pass->SetNotOwned>("local_scopes", + &local_scopes); +#ifdef PADDLE_WITH_CUDA + platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; + pass->SetNotOwned("nccl_ctxs", nctx); +#endif + } + graph = pass->Apply(std::move(graph)); + } + return graph; + } + + private: + BuildStrategy strategy_; +}; + +ir::PassBuilder *BuildStrategy::CreatePassBuilder() const { + pass_builder_.reset(new ParallelExecutorPassBuilder(*this)); + return pass_builder_.get(); +} + +std::unique_ptr BuildStrategy::Apply( + const ProgramDesc &main_program, const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶m_names, + const std::vector &local_scopes, +#ifdef PADDLE_WITH_CUDA + const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const { +#else + const bool use_cuda) const { +#endif + if (!pass_builder_) { + CreatePassBuilder(); + } + // std::unique_ptr graph; + ParallelExecutorPassBuilder *builder = + reinterpret_cast(pass_builder_.get()); +#ifdef PADDLE_WITH_CUDA + std::unique_ptr graph = + builder->Build(main_program, places, loss_var_name, param_names, + local_scopes, use_cuda, nccl_ctxs); +#else + std::unique_ptr graph = builder->Build( + main_program, places, loss_var_name, param_names, local_scopes, use_cuda); +#endif + return graph; +} +} // namespace details +} // namespace framework +} // namespace paddle + +USE_PASS(fuse_elewise_add_act_pass); +USE_PASS(graph_viz_pass); +USE_PASS(multi_devices_pass); +USE_PASS(multi_devices_check_pass); +USE_PASS(multi_devices_print_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 77cafa49f18..4468708d09f 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -15,11 +15,25 @@ #pragma once #include +#include + +#include "paddle/fluid/framework/ir/pass_builder.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif namespace paddle { namespace framework { namespace details { +class ParallelExecutorPassBuilder; +struct BuildStrategy; + struct BuildStrategy { // ParallelExecutor supports two modes of ReduceStrategy, kAllReduce and // kReduce, for CPU and GPU. If you use kAllReduce, different threads @@ -57,6 +71,24 @@ struct BuildStrategy { bool fuse_elewise_add_act_ops_{false}; bool enable_data_balance_{false}; + + ir::PassBuilder *CreatePassBuilder() const; + + std::unique_ptr Apply( + const ProgramDesc &main_program, + const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶m_names, + const std::vector &local_scopes, +#ifdef PADDLE_WITH_CUDA + const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const; +#else + const bool use_cuda) const; +#endif + + private: + // TODO(panyx0718): This should probably be unique_ptr. + mutable std::shared_ptr pass_builder_; }; } // namespace details diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 4dca3ceb456..9796f277895 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -41,6 +41,8 @@ cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") +cc_library(pass_builder SRCS pass_builder.cc DEPS pass) + cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index d7158eba626..6cf405efe63 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -19,7 +19,6 @@ namespace paddle { namespace framework { namespace ir { std::unique_ptr Pass::Apply(std::unique_ptr graph) const { - PADDLE_ENFORCE(!applied_, "Pass can only Apply() once."); PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); for (const std::string& attr : required_pass_attrs_) { PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 0f14083d259..042a7461b42 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -42,6 +42,8 @@ class Pass { attr_dels_.clear(); } + std::string Type() const { return type_; } + std::unique_ptr Apply(std::unique_ptr graph) const; // Get a reference to the attributed previously set. @@ -68,13 +70,13 @@ class Pass { // should delete the attribute. template void SetNotOwned(const std::string &attr_name, AttrType *attr) { - PADDLE_ENFORCE(attrs_.count(attr_name) == 0); attrs_[attr_name] = attr; } protected: - virtual std::unique_ptr ApplyImpl( - std::unique_ptr graph) const = 0; + virtual std::unique_ptr ApplyImpl(std::unique_ptr graph) const { + LOG(FATAL) << "Calling virtual Pass not implemented."; + } private: template @@ -89,7 +91,10 @@ class Pass { required_graph_attrs_.insert(attrs.begin(), attrs.end()); } + void RegisterType(const std::string &type) { type_ = type; } + mutable bool applied_{false}; + std::string type_; std::unordered_set required_pass_attrs_; std::unordered_set required_graph_attrs_; std::map attrs_; @@ -143,10 +148,11 @@ struct PassRegistrar : public Registrar { PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type), "'%s' is registered more than once.", pass_type); PassRegistry::Instance().Insert( - pass_type, [this]() -> std::unique_ptr { + pass_type, [this, pass_type]() -> std::unique_ptr { std::unique_ptr pass(new PassType()); pass->RegisterRequiredPassAttrs(this->required_pass_attrs_); pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_); + pass->RegisterType(pass_type); return pass; }); } diff --git a/paddle/fluid/framework/ir/pass_builder.cc b/paddle/fluid/framework/ir/pass_builder.cc new file mode 100644 index 00000000000..e0719867b34 --- /dev/null +++ b/paddle/fluid/framework/ir/pass_builder.cc @@ -0,0 +1,43 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/pass_builder.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::shared_ptr PassBuilder::AppendPass(const std::string& pass_type) { + auto pass = ir::PassRegistry::Instance().Get(pass_type); + passes_.emplace_back(pass.release()); + return passes_.back(); +} + +void PassBuilder::RemovePass(size_t idx) { + PADDLE_ENFORCE(passes_.size() > idx); + passes_.erase(passes_.begin() + idx); +} + +std::shared_ptr PassBuilder::InsertPass(size_t idx, + const std::string& pass_type) { + PADDLE_ENFORCE(passes_.size() >= idx); + std::shared_ptr pass( + ir::PassRegistry::Instance().Get(pass_type).release()); + passes_.insert(passes_.begin() + idx, std::move(pass)); + return passes_[idx]; +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass_builder.h b/paddle/fluid/framework/ir/pass_builder.h new file mode 100644 index 00000000000..9969cc90f38 --- /dev/null +++ b/paddle/fluid/framework/ir/pass_builder.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class PassBuilder { + public: + PassBuilder() {} + + virtual ~PassBuilder() {} + + std::shared_ptr AppendPass(const std::string& pass_type); + + std::shared_ptr InsertPass(size_t idx, const std::string& pass_type); + + void RemovePass(size_t idx); + + std::vector> AllPasses() const { return passes_; } + + protected: + std::vector> passes_; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f5a54c0f48c..855870b41c9 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -19,15 +19,13 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_viz_pass.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" #endif #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" -#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" -#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -35,80 +33,6 @@ limitations under the License. */ namespace paddle { namespace framework { -std::unique_ptr ApplyParallelExecutorPass( - const ProgramDesc &main_program, const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶m_names, - const std::vector &local_scopes, const bool use_cuda, -#ifdef PADDLE_WITH_CUDA - const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) { -#else - const BuildStrategy &strategy) { -#endif - // Convert the program to graph. - std::unique_ptr graph(new ir::Graph(main_program)); - - // Apply a graph viz pass to record a graph. - if (!strategy.debug_graphviz_path_.empty()) { - auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); - const std::string graph_path = string::Sprintf( - "%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph"); - viz_pass->Set("graph_viz_path", new std::string(graph_path)); - graph = viz_pass->Apply(std::move(graph)); - } - - // Apply op fusion. - if (strategy.fuse_elewise_add_act_ops_) { - auto fuse_elewise_add_act_pass = - ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass"); - graph = fuse_elewise_add_act_pass->Apply(std::move(graph)); - // Apply a graph viz pass to record a graph. - if (!strategy.debug_graphviz_path_.empty()) { - auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); - const std::string graph_path = string::Sprintf( - "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph"); - viz_pass->Set("graph_viz_path", new std::string(graph_path)); - graph = viz_pass->Apply(std::move(graph)); - } - } - - // Convert graph to run on multi-devices. - auto multi_devices_pass = - ir::PassRegistry::Instance().Get("multi_devices_pass"); - multi_devices_pass->SetNotOwned>("places", - &places); - multi_devices_pass->SetNotOwned("loss_var_name", - &loss_var_name); - multi_devices_pass->SetNotOwned>( - "params", ¶m_names); - multi_devices_pass->SetNotOwned>("local_scopes", - &local_scopes); - multi_devices_pass->SetNotOwned("strategy", &strategy); - -#ifdef PADDLE_WITH_CUDA - platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; - multi_devices_pass->SetNotOwned("nccl_ctxs", nctx); -#endif - graph = multi_devices_pass->Apply(std::move(graph)); - - // Apply a graph print pass to record a graph with device info. - if (!strategy.debug_graphviz_path_.empty()) { - auto multi_devices_print_pass = - ir::PassRegistry::Instance().Get("multi_devices_print_pass"); - multi_devices_print_pass->SetNotOwned( - "debug_graphviz_path", &strategy.debug_graphviz_path_); - multi_devices_print_pass->Set( - "graph_printer", new details::GraphvizSSAGraphPrinter); - graph = multi_devices_print_pass->Apply(std::move(graph)); - } - - // Verify that the graph is correct for multi-device executor. - auto multi_devices_check_pass = - ir::PassRegistry::Instance().Get("multi_devices_check_pass"); - graph = multi_devices_check_pass->Apply(std::move(graph)); - return graph; -} - class ParallelExecutorPrivate { public: explicit ParallelExecutorPrivate(const std::vector &places) @@ -199,10 +123,9 @@ ParallelExecutor::ParallelExecutor( // Step 3. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp #ifdef PADDLE_WITH_CUDA - std::unique_ptr graph = ApplyParallelExecutorPass( + std::unique_ptr graph = build_strategy.Apply( main_program, member_->places_, loss_var_name, params, - member_->local_scopes_, member_->use_cuda_, build_strategy, - member_->nccl_ctxs_.get()); + member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get()); auto max_memory_size = GetEagerDeletionThreshold(); if (max_memory_size >= 0) { @@ -228,9 +151,9 @@ ParallelExecutor::ParallelExecutor( } } #else - std::unique_ptr graph = ApplyParallelExecutorPass( - main_program, member_->places_, loss_var_name, params, - member_->local_scopes_, member_->use_cuda_, build_strategy); + std::unique_ptr graph = + build_strategy.Apply(main_program, member_->places_, loss_var_name, + params, member_->local_scopes_, member_->use_cuda_); #endif if (exec_strategy.type_ == ExecutionStrategy::kDefault) { @@ -373,12 +296,6 @@ ParallelExecutor::~ParallelExecutor() { } // namespace framework } // namespace paddle - -USE_PASS(fuse_elewise_add_act_pass); -USE_PASS(graph_viz_pass); -USE_PASS(multi_devices_pass); -USE_PASS(multi_devices_check_pass); -USE_PASS(multi_devices_print_pass); #ifdef PADDLE_WITH_CUDA USE_PASS(reference_count_pass); #endif diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index c64906ff230..fd386a5987f 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -14,14 +14,14 @@ limitations under the License. */ #pragma once -#include #include #include #include #include #include + +#include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h" -#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index b5bd07d401f..e7f634c4a62 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,5 +1,5 @@ -set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method) +set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder) set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc) if(NOT WIN32) list(APPEND PYBIND_DEPS parallel_executor profiler) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8b62502e3f9..c14b893fa40 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" @@ -595,6 +596,28 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_profiler_enabled", platform::IsProfileEnabled); m.def("reset_profiler", platform::ResetProfiler); + py::class_> pass(m, "Pass"); + pass.def(py::init()) + .def("set_str", [](ir::Pass &self, const std::string &name, + const std::string &attr) { + self.Set(name, new std::string(attr)); + }); + + py::class_ pb(m, "PassBuilder"); + pb.def(py::init()) + .def("append_pass", + [](ir::PassBuilder &self, + const std::string &pass_type) -> std::shared_ptr { + return self.AppendPass(pass_type); + }) + .def("all_passes", [](ir::PassBuilder &self) { return self.AllPasses(); }) + .def("insert_pass", + [](ir::PassBuilder &self, size_t idx, const std::string &pass_type) { + return self.InsertPass(idx, pass_type); + }) + .def("remove_pass", + [](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); }); + // -- python binds for parallel executor. py::class_ pe(m, "ParallelExecutor"); py::class_ exec_strategy(pe, "ExecutionStrategy"); @@ -677,7 +700,10 @@ All parameter, weight, gradient are variables in Paddle. }, [](BuildStrategy &self, bool b) { self.fuse_elewise_add_act_ops_ = b; - }); + }) + .def("create_pass_builder", + [](BuildStrategy &self) { return *self.CreatePassBuilder(); }, + py::return_value_policy::reference); pe.def(py::init &, const std::unordered_set &, diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py new file mode 100644 index 00000000000..2da4c097d92 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py @@ -0,0 +1,110 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np +import unittest +import os +import sys +import math + + +def simple_fc_net(): + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + hidden = img + for _ in range(4): + hidden = fluid.layers.fc( + hidden, + size=200, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) + return loss + + +class TestPassBuilder(unittest.TestCase): + def check_network_convergence(self, use_cuda, build_strategy=None): + os.environ['CPU_NUM'] = str(4) + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = simple_fc_net() + test_program = main.clone(for_test=True) + + opt = fluid.optimizer.SGD(learning_rate=0.001) + opt.minimize(loss) + + batch_size = 32 + image = np.random.normal(size=(batch_size, 784)).astype('float32') + label = np.random.randint(0, 10, (batch_size, 1), dtype="int64") + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup) + feed_dict = {'image': image, 'label': label} + + train_exe = fluid.ParallelExecutor( + use_cuda=use_cuda, + loss_name=loss.name, + main_program=main, + build_strategy=build_strategy) + + test_exe = fluid.ParallelExecutor( + use_cuda=use_cuda, + main_program=test_program, + share_vars_from=train_exe, + build_strategy=build_strategy) + + for i in range(5): + test_loss, = test_exe.run([loss.name], feed=feed_dict) + + train_loss, = train_exe.run([loss.name], feed=feed_dict) + + avg_test_loss_val = np.array(test_loss).mean() + if math.isnan(float(avg_test_loss_val)): + sys.exit("got NaN loss, testing failed.") + + avg_train_loss_val = np.array(train_loss).mean() + if math.isnan(float(avg_train_loss_val)): + sys.exit("got NaN loss, training failed.") + + self.assertTrue( + np.allclose( + train_loss, test_loss, atol=1e-8), + "Train loss: " + str(train_loss) + "\n Test loss:" + + str(test_loss)) + + def test_parallel_testing_with_new_strategy(self): + build_strategy = fluid.BuildStrategy() + pass_builder = build_strategy.create_pass_builder() + viz_pass = pass_builder.append_pass("graph_viz_pass") + all_passes = pass_builder.all_passes() + pass_builder.insert_pass(len(all_passes), "graph_viz_pass") + pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) + viz_pass.set_str("graph_viz_path", "/tmp/viz_pass") + + self.check_network_convergence( + use_cuda=core.is_compiled_with_cuda(), + build_strategy=build_strategy) + + +if __name__ == '__main__': + unittest.main() -- GitLab