diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index dc17f6a21fab23e77de30f473afd128b8748c828..d1e1f0ed23d995a51a6be3eb46093eafbf89efb7 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -40,7 +40,8 @@ class ParallelExecutorPrivate {
 };
 
 ParallelExecutor::ParallelExecutor(
-    size_t num_threads, const std::vector<platform::Place> &places,
+    size_t num_threads, bool use_event,
+    const std::vector<platform::Place> &places,
     const std::unordered_set<std::string> &params,
     const ProgramDesc &startup_program, const ProgramDesc &main_program,
     const std::string &loss_var_name, Scope *scope)
@@ -73,7 +74,8 @@ ParallelExecutor::ParallelExecutor(
   auto graph = builder.Build(main_program);
 
   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
-      num_threads, true, member_->local_scopes_, places, std::move(graph)));
+      num_threads, use_event, member_->local_scopes_, places,
+      std::move(graph)));
 
   // Step 3. Create vars in each scope;
   for (auto *scope : member_->local_scopes_) {
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 14489a18c3afb67e663ffe568df54375bbfa0843..8bc09c5798854feabb43fa64d160b341bfe0e7b5 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -34,7 +34,7 @@ class ParallelExecutor {
   DISABLE_COPY_AND_ASSIGN(ParallelExecutor);
 
  public:
-  explicit ParallelExecutor(size_t num_threads,
+  explicit ParallelExecutor(size_t num_threads, bool use_event,
                             const std::vector<platform::Place>& places,
                             const std::unordered_set<std::string>& params,
                             const ProgramDesc& startup_program,
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 60662244ccb9bfe48775f9cc5a23a0ff529c035b..e1b1bbec97985aa839c62a0a82b81b020faf0008 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -499,15 +499,15 @@ All parameter, weight, gradient are variables in Paddle.
 
   py::class_<ParallelExecutor>(m, "ParallelExecutor")
       .def("__init__",
-           [](ParallelExecutor &self, size_t num_threads,
+           [](ParallelExecutor &self, size_t num_threads, bool use_event,
               const std::vector<platform::Place> &places,
               const std::unordered_set<std::string> &params,
              const ProgramDesc &startup_program,
              const ProgramDesc &main_program, const std::string &loss_var_name,
              Scope *scope) {
-             new (&self)
-                 ParallelExecutor(num_threads, places, params, startup_program,
-                                  main_program, loss_var_name, scope);
+             new (&self) ParallelExecutor(num_threads, use_event, places,
+                                          params, startup_program, main_program,
+                                          loss_var_name, scope);
           })
       .def("run", &ParallelExecutor::Run);
 
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index fcea28220485039c9daf3c5fa2688c31f9f34c42..5ea4d977f4d8d9eb56b1fefa16f429df6e2a15bb 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -41,6 +41,7 @@ from memory_optimization_transpiler import memory_optimize, release_memory
 import profiler
 import unique_name
 import recordio_writer
+from parallel_executor import ParallelExecutor
 
 Tensor = LoDTensor
 
@@ -68,6 +69,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
     'profiler',
     'unique_name',
     'recordio_writer',
+    'ParallelExecutor',
 ]
 
 
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e0588fa73241a8752e1b3195a123820165f070d
--- /dev/null
+++ b/python/paddle/fluid/parallel_executor.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import core
+import multiprocessing
+import framework
+import executor
+
+__all__ = ['ParallelExecutor']
+
+
+class ParallelExecutor(object):
+    def __init__(self, loss_name, use_cuda, num_threads=None):
+        places = []
+        if use_cuda:
+            for i in xrange(core.get_cuda_device_count()):
+                p = core.Place()
+                p.set_place(core.CUDAPlace(i))
+                places.append(p)
+        else:
+            for i in xrange(multiprocessing.cpu_count()):
+                p = core.Place()
+                p.set_place(core.CPUPlace())
+                places.append(p)
+
+        if num_threads is None:
+            num_threads = min(len(places) * 2, multiprocessing.cpu_count())
+
+        startup = framework.default_startup_program()
+        main = framework.default_main_program()
+        scope = executor.global_scope()
+
+        self.executor = core.ParallelExecutor(
+            num_threads,
+            True if use_cuda else False,  # use_event
+            places,
+            set([
+                p.name for p in main.global_block().iter_parameters()
+                if not p.stop_gradient
+            ]),
+            startup.desc,
+            main.desc,
+            loss_name,
+            scope)
+        self.scope = scope
+
+    def run(self, fetch_list):
+        fetch_var_name = '@FETCHED_VAR_NAME@'
+        self.executor.run(fetch_list, fetch_var_name)
+        arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
+        return [arr[i] for i in range(len(arr))]
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
index cabb8e769dfcad8401fcfa17d6a43fa5b3656493..2ebdbaaca65fe865f6c24f9614674b6d18eba0e7 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
@@ -19,8 +19,54 @@ import paddle.v2.dataset.mnist as mnist
 import numpy
 
 
+def simple_fc_net():
+    reader = fluid.layers.open_recordio_file(
+        filename='./mnist.recordio',
+        shapes=[[-1, 784], [-1, 1]],
+        lod_levels=[0, 0],
+        dtypes=['float32', 'int64'])
+    img, label = fluid.layers.read_file(reader)
+    hidden = img
+    for _ in xrange(4):
+        hidden = fluid.layers.fc(
+            hidden,
+            size=200,
+            act='tanh',
+            bias_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=1.0)))
+    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    loss = fluid.layers.mean(loss)
+    return loss
+
+
+def fc_with_batchnorm():
+    reader = fluid.layers.open_recordio_file(
+        filename='./mnist.recordio',
+        shapes=[[-1, 784], [-1, 1]],
+        lod_levels=[0, 0],
+        dtypes=['float32', 'int64'])
+    img, label = fluid.layers.read_file(reader)
+    hidden = img
+    for _ in xrange(4):
+        hidden = fluid.layers.fc(
+            hidden,
+            size=200,
+            act='tanh',
+            bias_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=1.0)))
+
+        hidden = fluid.layers.batch_norm(input=hidden)
+
+    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    loss = fluid.layers.mean(loss)
+    return loss
+
+
 class ParallelExecutor(unittest.TestCase):
-    def setUp(self):
+    @classmethod
+    def setUpClass(cls):
         # Convert mnist to recordio file
         with fluid.program_guard(fluid.Program(), fluid.Program()):
             reader = paddle.batch(mnist.train(), batch_size=32)
@@ -35,51 +81,28 @@ class ParallelExecutor(unittest.TestCase):
             fluid.recordio_writer.convert_reader_to_recordio_file(
                 './mnist.recordio', reader, feeder)
 
-    def test_main(self):
+    def test_simple_fc(self):
+        self.check_network_convergence(simple_fc_net)
+
+    def test_batchnorm_fc(self):
+        self.check_network_convergence(fc_with_batchnorm)
+
+    def check_network_convergence(self, method):
         main = fluid.Program()
         startup = fluid.Program()
-
         with fluid.program_guard(main, startup):
-            reader = fluid.layers.open_recordio_file(
-                filename='./mnist.recordio',
-                shapes=[[-1, 784], [-1, 1]],
-                lod_levels=[0, 0],
-                dtypes=['float32', 'int64'])
-            img, label = fluid.layers.read_file(reader)
-            hidden = img
-            for _ in xrange(4):
-                hidden = fluid.layers.fc(
-                    hidden,
-                    size=200,
-                    act='tanh',
-                    bias_attr=fluid.ParamAttr(
-                        initializer=fluid.initializer.Constant(value=1.0)))
-            prediction = fluid.layers.fc(hidden, size=10, act='softmax')
-            loss = fluid.layers.cross_entropy(input=prediction, label=label)
-            loss = fluid.layers.mean(loss)
+            loss = method()
             adam = fluid.optimizer.Adam()
             adam.minimize(loss)
-            act_places = []
-            for each in [fluid.CUDAPlace(0)]:
-                p = fluid.core.Place()
-                p.set_place(each)
-                act_places.append(p)
-
-            exe = fluid.core.ParallelExecutor(
-                act_places,
-                set([p.name for p in main.global_block().iter_parameters()]),
-                startup.desc, main.desc, loss.name, fluid.global_scope())
-            exe.run([loss.name], 'fetched_var')
+            exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True)
+            first_loss, = exe.run([loss.name])
+            first_loss = numpy.array(first_loss)
 
-            first_loss = numpy.array(fluid.global_scope().find_var('fetched_var')
-                                     .get_lod_tensor_array()[0])
-            print first_loss
+            for i in xrange(10):
+                exe.run([])
 
-            for i in xrange(10):
-                exe.run([], 'fetched_var')
-            exe.run([loss.name], 'fetched_var')
-            last_loss = numpy.array(fluid.global_scope().find_var('fetched_var')
-                                    .get_lod_tensor_array()[0])
+            last_loss, = exe.run([loss.name])
+            last_loss = numpy.array(last_loss)
 
-            print first_loss, last_loss
-            self.assertGreater(first_loss[0], last_loss[0])
+            print first_loss, last_loss
+            self.assertGreater(first_loss[0], last_loss[0])
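
Usage note (not part of the patch): the new fluid.ParallelExecutor wrapper is driven exactly as in check_network_convergence above -- build and optimize a program under program_guard, construct the executor from the loss name, then fetch through run(). Below is a minimal sketch, assuming the ./mnist.recordio file from setUpClass already exists; the one-layer network is illustrative, not taken from the patch.

    import numpy
    import paddle.fluid as fluid

    main = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main, startup):
        reader = fluid.layers.open_recordio_file(
            filename='./mnist.recordio',
            shapes=[[-1, 784], [-1, 1]],
            lod_levels=[0, 0],
            dtypes=['float32', 'int64'])
        img, label = fluid.layers.read_file(reader)
        prediction = fluid.layers.fc(img, size=10, act='softmax')
        loss = fluid.layers.mean(
            fluid.layers.cross_entropy(input=prediction, label=label))
        fluid.optimizer.Adam().minimize(loss)

        # The wrapper creates one CUDAPlace per visible device (or one
        # CPUPlace per core when use_cuda=False) and reads the default
        # main/startup programs, so it must be constructed inside the
        # program_guard and after minimize(), once all parameters exist.
        exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True)

        # run() fetches into a temporary LoDTensorArray ('@FETCHED_VAR_NAME@')
        # and returns one LoDTensor per requested variable.
        loss_value, = exe.run([loss.name])
        print numpy.array(loss_value)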