From e5281b3c2d14fdd0cc515268307e29521eb40305 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Mon, 14 May 2018 13:23:07 +0800 Subject: [PATCH] Clean code & add execution strategy --- .../framework/details/execution_strategy.h | 29 ++++++++++ .../details/threaded_ssa_graph_executor.cc | 17 +++--- .../details/threaded_ssa_graph_executor.h | 11 ++-- paddle/fluid/framework/parallel_executor.cc | 9 ++-- paddle/fluid/framework/parallel_executor.h | 36 +++++++------ paddle/fluid/pybind/pybind.cc | 43 +++++++++------ python/paddle/fluid/__init__.py | 54 ++++++++++--------- python/paddle/fluid/parallel_executor.py | 51 ++++++++++-------- .../tests/unittests/test_parallel_executor.py | 8 +-- 9 files changed, 154 insertions(+), 104 deletions(-) create mode 100644 paddle/fluid/framework/details/execution_strategy.h diff --git a/paddle/fluid/framework/details/execution_strategy.h b/paddle/fluid/framework/details/execution_strategy.h new file mode 100644 index 0000000000..e8d510ec95 --- /dev/null +++ b/paddle/fluid/framework/details/execution_strategy.h @@ -0,0 +1,29 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace framework { +namespace details { + +struct ExecutionStrategy { + size_t num_threads_{0}; + bool use_event_{true}; + bool allow_op_delay_{false}; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index e90523ebe8..ef263d82c5 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -18,18 +18,17 @@ namespace paddle { namespace framework { namespace details { ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( - size_t num_threads, bool use_event, - const std::vector &local_scopes, + const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph, bool allow_op_delay) + std::unique_ptr &&graph) : SSAGraphExecutor(std::move(graph)), - pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr), + pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) + : nullptr), local_scopes_(local_scopes), places_(places), fetch_ctxs_(places), - use_event_(use_event), running_ops_(0), - allow_op_delay_(allow_op_delay) {} + strategy_(strategy) {} FeedFetchList ThreadedSSAGraphExecutor::Run( const std::vector &fetch_tensors) { @@ -86,7 +85,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // // NOTE: DelayedOps have a lower priority. It will be scheduled after all // ready_ops have been performed. - if (ready_ops.empty() && allow_op_delay_ && running_ops_ == 0) { + if (ready_ops.empty() && strategy_.allow_op_delay_ && running_ops_ == 0) { run_all_ops(delayed_ops); } else { run_all_ops(ready_ops); @@ -113,7 +112,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto &deps = pending_ops[op]; --deps; if (deps == 0) { - if (op->IsMultiDeviceTransfer() && allow_op_delay_) { + if (op->IsMultiDeviceTransfer() && strategy_.allow_op_delay_) { delayed_ops.insert(op); } else { ready_ops.insert(op); @@ -191,7 +190,7 @@ void ThreadedSSAGraphExecutor::RunOp( auto op_run = [ready_var_q, op, this] { try { VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); - op->Run(use_event_); + op->Run(strategy_.use_event_); VLOG(10) << op << " " << op->Name() << " Done "; running_ops_--; ready_var_q->Extend(op->Outputs()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index f18a88526b..1f7f88d752 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -23,6 +23,7 @@ #include #include "ThreadPool.h" // ThreadPool in thrird party #include "paddle/fluid/framework/blocking_queue.h" +#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" @@ -34,11 +35,10 @@ namespace details { class ThreadedSSAGraphExecutor : public SSAGraphExecutor { public: - ThreadedSSAGraphExecutor(size_t num_threads, bool use_event, + ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph, - bool allow_op_delay); + std::unique_ptr &&graph); // Run a SSAGraph by a thread pool // Use topological sort algorithm @@ -55,10 +55,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { std::vector local_scopes_; std::vector places_; platform::DeviceContextPool fetch_ctxs_; - const bool use_event_; std::unique_ptr exception_; std::atomic running_ops_; - bool allow_op_delay_; void InsertPendingOp(std::unordered_map *pending_ops, OpHandleBase *op_instance) const; @@ -74,6 +72,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { std::unordered_map *pending_ops, std::unordered_set *pending_vars, BlockingQueue *ready_vars, FeedFetchList *fetch_data); + + private: + ExecutionStrategy strategy_; }; } // namespace details diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 20ef7e09f6..cdfd0a8c07 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -52,13 +52,13 @@ std::vector &ParallelExecutor::GetLocalScopes() { } ParallelExecutor::ParallelExecutor( - size_t num_threads, bool use_event, const std::vector &places, const std::unordered_set ¶ms, const std::unordered_set &bcast_vars, const ProgramDesc &main_program, const std::string &loss_var_name, - Scope *scope, const std::vector &local_scopes, bool allow_op_delay, - bool use_default_grad_scale, bool balance_parameter_opt_between_cards) + Scope *scope, const std::vector &local_scopes, + bool use_default_grad_scale, bool balance_parameter_opt_between_cards, + const ExecutionStrategy &exec_strategy) : member_(new ParallelExecutorPrivate(places)) { member_->global_scope_ = scope; @@ -103,8 +103,7 @@ ParallelExecutor::ParallelExecutor( auto graph = builder.Build(main_program); member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - num_threads, use_event, member_->local_scopes_, places, std::move(graph), - allow_op_delay)); + exec_strategy, member_->local_scopes_, places, std::move(graph))); // Step 3. Create vars in each scope; for (auto *var : main_program.Block(0).AllVars()) { diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index b251fc9141..ab50509124 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -17,53 +17,55 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" - namespace paddle { namespace framework { class ParallelExecutorPrivate; +using details::ExecutionStrategy; + class ParallelExecutor { DISABLE_COPY_AND_ASSIGN(ParallelExecutor); public: - explicit ParallelExecutor(size_t num_threads, bool use_event, - const std::vector& places, - const std::unordered_set& params, - const std::unordered_set& bcast_vars, - const ProgramDesc& main_program, - const std::string& loss_var_name, Scope* scope, - const std::vector& local_scopes, - bool allow_op_delay, bool use_default_grad_scale, - bool balance_parameter_opt_between_cards); + explicit ParallelExecutor(const std::vector &places, + const std::unordered_set ¶ms, + const std::unordered_set &bcast_vars, + const ProgramDesc &main_program, + const std::string &loss_var_name, Scope *scope, + const std::vector &local_scopes, + bool use_default_grad_scale, + bool balance_parameter_opt_between_cards, + const ExecutionStrategy &exec_strategy); ~ParallelExecutor(); - std::vector& GetLocalScopes(); + std::vector &GetLocalScopes(); /** * Feed tensors to local scopes. The size of tensors should be equal to the * size of local scopes. */ void FeedTensorsIntoLocalScopes( - const std::vector>& tensors); + const std::vector> &tensors); void FeedAndSplitTensorIntoLocalScopes( - const std::unordered_map& tensors); + const std::unordered_map &tensors); - void Run(const std::vector& fetch_tensors, - const std::string& fetched_var_name); + void Run(const std::vector &fetch_tensors, + const std::string &fetched_var_name); - void BCastParamsToGPUs(const std::unordered_set& vars) const; + void BCastParamsToGPUs(const std::unordered_set &vars) const; private: - ParallelExecutorPrivate* member_; + ParallelExecutorPrivate *member_; }; } // namespace framework diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3e2eed31b4..c456bc1a71 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -494,22 +494,33 @@ All parameter, weight, gradient are variables in Paddle. m.def("disable_profiler", platform::DisableProfiler); m.def("reset_profiler", platform::ResetProfiler); - py::class_(m, "ParallelExecutor") - .def("__init__", - [](ParallelExecutor &self, size_t num_threads, bool use_event, - const std::vector &places, - const std::unordered_set ¶ms, - const std::unordered_set &bcast_vars, - const ProgramDesc &main_program, const std::string &loss_var_name, - Scope *scope, std::vector &local_scopes, - bool allow_op_delay, bool use_default_grad_scale, - bool balance_parameter_opt_between_cards) { - new (&self) ParallelExecutor( - num_threads, use_event, places, params, bcast_vars, - main_program, loss_var_name, scope, local_scopes, - allow_op_delay, use_default_grad_scale, - balance_parameter_opt_between_cards); - }) + py::class_ pe(m, "ParallelExecutor"); + py::class_(pe, "ExecutionStrategy") + .def(py::init()) + .def_property( + "num_threads", + [](const ExecutionStrategy &self) { return self.num_threads_; }, + [](ExecutionStrategy &self, size_t num_threads) { + self.num_threads_ = num_threads; + }) + .def_property( + "use_event", + [](const ExecutionStrategy &self) { return self.use_event_; }, + [](ExecutionStrategy &self, bool use_event) { + self.use_event_ = use_event; + }) + .def_property( + "allow_op_delay", + [](const ExecutionStrategy &self) { return self.allow_op_delay_; }, + [](ExecutionStrategy &self, bool allow_op_delay) { + self.allow_op_delay_ = allow_op_delay; + }); + + pe.def(py::init &, + const std::unordered_set &, + const std::unordered_set &, const ProgramDesc &, + const std::string &, Scope *, std::vector &, bool, + bool, const ExecutionStrategy &>()) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) // NOTE: even we return a vec* to Python use reference policy. // We still cannot get local_scope from this vector, since the element diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index c8a435748d..ef7a586475 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -44,42 +44,44 @@ import transpiler from param_attr import ParamAttr, WeightNormParamAttr from data_feeder import DataFeeder from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace -from transpiler import DistributeTranspiler, SimpleDistributeTranspiler, InferenceTranspiler, memory_optimize, release_memory +from transpiler import DistributeTranspiler, SimpleDistributeTranspiler, \ + InferenceTranspiler, memory_optimize, release_memory from concurrency import (Go, make_channel, channel_send, channel_recv, channel_close, Select) import clip import profiler import unique_name import recordio_writer -from parallel_executor import ParallelExecutor +from parallel_executor import ParallelExecutor, ExecutionStrategy Tensor = LoDTensor -__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\ +__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \ trainer.__all__ + inferencer.__all__ + transpiler.__all__ + [ - 'io', - 'initializer', - 'layers', - 'transpiler' - 'nets', - 'optimizer', - 'learning_rate_decay', - 'backward', - 'regularizer', - 'LoDTensor', - 'CPUPlace', - 'CUDAPlace', - 'CUDAPinnedPlace', - 'Tensor', - 'ParamAttr', - 'WeightNormParamAttr', - 'DataFeeder', - 'clip', - 'profiler', - 'unique_name', - 'recordio_writer', - 'ParallelExecutor', -] + 'io', + 'initializer', + 'layers', + 'transpiler' + 'nets', + 'optimizer', + 'learning_rate_decay', + 'backward', + 'regularizer', + 'LoDTensor', + 'CPUPlace', + 'CUDAPlace', + 'CUDAPinnedPlace', + 'Tensor', + 'ParamAttr', + 'WeightNormParamAttr', + 'DataFeeder', + 'clip', + 'profiler', + 'unique_name', + 'recordio_writer', + 'ParallelExecutor', + 'ExecutionStrategy', + ] def __bootstrap__(): diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 5b43f860e7..69ea9ee335 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -19,7 +19,9 @@ import executor import warnings import sys -__all__ = ['ParallelExecutor'] +__all__ = ['ParallelExecutor', 'ExecutionStrategy'] + +ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy class ParallelExecutor(object): @@ -27,11 +29,11 @@ class ParallelExecutor(object): use_cuda, loss_name=None, main_program=None, - num_threads=None, - allow_op_delay=False, share_vars_from=None, use_default_grad_scale=True, - balance_parameter_opt_between_cards=False): + balance_parameter_opt_between_cards=False, + exec_strategy=None, + **kwargs): """ ParallelExecutor can run program in parallel. @@ -40,11 +42,6 @@ class ParallelExecutor(object): loss_name(str, default None): The loss name must set in training. main_program(Program, default None): The program that need to run, if not provided, then default_main_program will be used. - num_threads(int, default None): How many threads are used for - training. - allow_op_delay(bool, default False): Whether to delay and buffer - some operators together for scheduling or not, which may - improve performance in some cases, default False. share_vars_from(ParallelExecutor, default None): If provied, it will share variables from the specified ParallelExecutor. use_default_grad_scale(bool, default True): If set True, a default @@ -76,6 +73,16 @@ class ParallelExecutor(object): train_loss, = train_exe.run([loss.name], feed=feed_dict) test_loss, = test_exe.run([loss.name], feed=feed_dict) """ + if len(kwargs) != 0: + err_msg = "" + for key in kwargs: + if key in dir(ExecutionStrategy): + err_msg += \ + "Setting {0} by constructor is deprecated. Use " \ + "strategy=ExecutionStrategy(); strategy.{0}=xxx; " \ + "pe=ParallelExecutor(exec_strategy=strategy) " \ + "instead.\n " + raise ValueError(err_msg) self._places = [] self._act_places = [] @@ -93,13 +100,20 @@ class ParallelExecutor(object): self._places.append(p) assert self._places, "no place for execution" - if num_threads is None: + if exec_strategy is None: + exec_strategy = ExecutionStrategy() + if use_cuda: + exec_strategy.use_event = True + else: + exec_strategy.use_event = False + + if exec_strategy.num_threads == 0: if use_cuda: # Experiments on se-resnext shows that too many threads hurt # performance. Worth tunning for other models in the future. - num_threads = len(self._places) * 2 + exec_strategy.num_threads = len(self._places) * 2 else: - num_threads = min( + exec_strategy.num_threads = min( len(self._places) * 2, multiprocessing.cpu_count()) main = main_program @@ -120,21 +134,14 @@ class ParallelExecutor(object): ] self.executor = core.ParallelExecutor( - num_threads, - True if use_cuda else False, # use_event self._places, set([ p.name for p in main.global_block().iter_parameters() if not p.stop_gradient ]), - set(self.persistable_vars), - main.desc, - loss_name if loss_name else '', - scope, - local_scopes, - allow_op_delay, - use_default_grad_scale, - balance_parameter_opt_between_cards) + set(self.persistable_vars), main.desc, loss_name + if loss_name else '', scope, local_scopes, use_default_grad_scale, + balance_parameter_opt_between_cards, exec_strategy) self.scope = scope diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index a3be1a8db6..4173ad1925 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -232,14 +232,14 @@ class TestParallelExecutorBase(unittest.TestCase): place = fluid.CUDAPlace(0) startup_exe = fluid.Executor(place) startup_exe.run(startup) - + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.allow_op_delay = allow_op_delay if use_parallel_executor: exe = fluid.ParallelExecutor( True, loss_name=loss.name, - allow_op_delay=allow_op_delay, - balance_parameter_opt_between_cards=balance_parameter_opt_between_cards - ) + balance_parameter_opt_between_cards=balance_parameter_opt_between_cards, + exec_strategy=exec_strategy) else: exe = fluid.Executor(place=place) -- GitLab