From 16e4d0267564c24cf7897bcd6f0b05faebea09c9 Mon Sep 17 00:00:00 2001
From: Yiqun Liu
Date: Thu, 31 Oct 2019 10:14:47 +0800
Subject: [PATCH] Refine the cache of program, context and scope in executor.
 (#18483)

* Refine the cache of program, context and scope in executor.
test=develop

* Refine the unittest test_executor_and_use_program_cache.

* Add a test for the PaddingRNN with use_program_cache=True.
test=develop

* Remove a check.
test=develop

* Refine the unittest to check whether it is correct when setting
use_program_cache=True.
test=develop
---
 paddle/fluid/framework/executor.cc             |   6 -
 paddle/fluid/framework/executor.h              |   6 -
 paddle/fluid/pybind/pybind.cc                  |  16 +-
 python/paddle/fluid/executor.py                |  19 +-
 .../test_eager_deletion_padding_rnn.py         |  75 ++++---
 .../test_executor_and_use_program_cache.py     | 188 +++++++++++-------
 6 files changed, 176 insertions(+), 134 deletions(-)

diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 14dd821367..a85683f9f8 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -287,12 +287,6 @@ static bool has_fetch_operators(
   return fetch_count > 0;
 }
 
-std::unique_ptr<ExecutorPrepareContext> Executor::PrepareCtxCache(
-    const ProgramDesc& program, int block_id,
-    const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
-  return Prepare(program, block_id, skip_ref_cnt_vars, force_disable_gc);
-}
-
 void Executor::Run(const ProgramDesc& program, Scope* scope,
                    std::map<std::string, const LoDTensor*>* feed_targets,
                    std::map<std::string, LoDTensor*>* fetch_targets,
diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 13bbe29d73..6785b73a05 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -95,12 +95,6 @@ class Executor {
            const std::string& feed_holder_name = "feed",
            const std::string& fetch_holder_name = "fetch");
 
-  std::unique_ptr<ExecutorPrepareContext> PrepareCtxCache(
-      const ProgramDesc& program, int block_id,
-      const std::vector<std::string>& skip_ref_cnt_vars =
-          std::vector<std::string>(),
-      bool force_disable_gc = false);
-
   static std::unique_ptr<ExecutorPrepareContext> Prepare(
       const ProgramDesc& program, int block_id,
       const std::vector<std::string>& skip_ref_cnt_vars =
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 370ebccd6c..2f2d69e50c 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1389,7 +1389,7 @@ All parameter, weight, gradient are variables in Paddle.
                        create_local_scope, create_vars, feed_holder_name,
                        fetch_holder_name);
           })
-      .def("run_cached_prepared_ctx",
+      .def("run_prepared_ctx",
           [](Executor &self, ExecutorPrepareContext *ctx, Scope *scope,
              bool create_local_scope = true, bool create_vars = true,
             bool keep_kids = false) {
@@ -1397,10 +1397,16 @@ All parameter, weight, gradient are variables in Paddle.
             self.RunPreparedContext(ctx, scope, create_local_scope,
                                     create_vars, keep_kids);
           })
-      .def("prepare_ctx_cache", &Executor::PrepareCtxCache,
-           py::call_guard<py::gil_scoped_release>())
-      .def("create_variables", &Executor::CreateVariables,
-           py::call_guard<py::gil_scoped_release>())
+      .def("prepare",
+           [](Executor &self, const ProgramDesc &program, int block_id,
+              const std::vector<std::string> &skip_ref_cnt_vars =
+                  std::vector<std::string>(),
+              bool force_disable_gc = false) {
+             pybind11::gil_scoped_release release;
+             return self.Prepare(program, block_id, skip_ref_cnt_vars,
+                                 force_disable_gc);
+           })
+      .def("create_variables", &Executor::CreateVariables)
       .def("run",
            [](Executor &self, const ProgramDesc &prog, Scope *scope,
               int block_id, bool create_local_scope, bool create_vars,
              const std::vector<std::string> &fetch_vars) {
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index 4209db5a7a..3a72c8381e 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -489,9 +489,6 @@ class Executor(object):
         self._default_executor = core.Executor(p)
         self._closed = False
 
-    def _get_var_cache(self, program_cache_key):
-        return self.var_caches.get(program_cache_key, None)
-
     def _get_scope_cache(self, program_cache_key):
         return self.scope_caches.get(program_cache_key, None)
 
@@ -510,9 +507,6 @@ class Executor(object):
     def _add_scope_cache(self, scope_cache_key, scope):
         self.scope_caches[scope_cache_key] = scope
 
-    def _add_var_cache(self, var_cache_key, var):
-        self.var_caches[var_cache_key] = var
-
     def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
                             fetch_var_name):
         tmp_program = program.clone()
@@ -853,7 +847,6 @@ class Executor(object):
         cached_program = self._get_program_cache(cache_key)
         cached_ctx = self._get_ctx_cache(cache_key)
         cached_scope = self._get_scope_cache(cache_key)
-        cached_var = self._get_var_cache(cache_key)
         if cached_program is None:
             cached_program = self._add_feed_fetch_ops(
                 program=program,
@@ -863,23 +856,21 @@ class Executor(object):
                 fetch_var_name=fetch_var_name)
             self._add_program_cache(cache_key, cached_program)
             fetch_list_str = list(map(_to_name_str, fetch_list))
-            cached_ctx = self._default_executor.prepare_ctx_cache(
+            cached_ctx = self._default_executor.prepare(
                 cached_program.desc, 0, fetch_list_str, False)
-            cached_var = self._default_executor.create_variables(
-                cached_program.desc, scope, 0)
             # currently, we cache program, vars, sub_scope here
             # we suppose that in a life cycle of training, a user
             # will not create many programs. So, here the basic
             # rule of caching is to cache all unseen (program, var, scope)
             # when a user use use_program_cache.
             cached_scope = scope.new_scope()
+            self._default_executor.create_variables(cached_program.desc,
+                                                    cached_scope, 0)
             self._add_ctx_cache(cache_key, cached_ctx)
-            self._add_var_cache(cache_key, cached_var)
             self._add_scope_cache(cache_key, cached_scope)
             program = cached_program
             ctx = cached_ctx
             scope = cached_scope
-            var = cached_var
         else:
             program = self._add_feed_fetch_ops(
                 program=program,
@@ -893,8 +884,8 @@ class Executor(object):
             self._default_executor.run(program.desc, scope, 0, True, True,
                                        fetch_var_name)
         else:
-            self._default_executor.run_cached_prepared_ctx(ctx, scope, False,
-                                                           False, False)
+            self._default_executor.run_prepared_ctx(ctx, scope, False, False,
+                                                    False)
         arr = scope.find_var(fetch_var_name).get_lod_tensor_array()
         tensors = arr._move_to_list()
         if return_numpy:
diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py
index 530964a731..c0fd448d43 100644
--- a/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py
+++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py
@@ -32,7 +32,7 @@ from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN
 os.environ["CPU_NUM"] = "1"
 
 
-class RnnConfig(object):
+class RNNConfig(object):
     def __init__(self, model_type, rnn_model):
         self.model_type = model_type
         self.rnn_model = rnn_model
@@ -478,11 +478,12 @@ def lm_model(hidden_size,
     return loss, last_hidden, last_cell, feeding_list
 
 
-class EagerDeletionPaddingRnnTest(unittest.TestCase):
+class PaddingRNNTestBase(unittest.TestCase):
     def setUp(self):
         self.reader = Reader()
+        self.device_count = 1
 
-    def prepare_program(self, config):
+    def prepare_program(self, config, parallel=True):
         self.main_program = fluid.Program()
         self.startup_program = fluid.Program()
         self.startup_program.random_seed = config.random_seed
@@ -517,21 +518,23 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
         self.exe = Executor(fluid.CPUPlace())
         self.exe.run(self.startup_program)
 
-        self.device_count = 1
-        exec_strategy = fluid.ExecutionStrategy()
-        exec_strategy.num_threads = self.device_count
-        exec_strategy.num_iteration_per_drop_scope = 100
-
-        build_strategy = fluid.BuildStrategy()
-        build_strategy.enable_inplace = True
-        build_strategy.memory_optimize = False
-        build_strategy.fuse_all_optimizer_ops = True
-
-        self.train_program = fluid.compiler.CompiledProgram(
-            self.main_program).with_data_parallel(
-                loss_name=self.loss.name,
-                build_strategy=build_strategy,
-                exec_strategy=exec_strategy)
+        if parallel:
+            exec_strategy = fluid.ExecutionStrategy()
+            exec_strategy.num_threads = self.device_count
+            exec_strategy.num_iteration_per_drop_scope = 100
+
+            build_strategy = fluid.BuildStrategy()
+            build_strategy.enable_inplace = True
+            build_strategy.memory_optimize = False
+            build_strategy.fuse_all_optimizer_ops = True
+
+            self.train_program = fluid.compiler.CompiledProgram(
+                self.main_program).with_data_parallel(
+                    loss_name=self.loss.name,
+                    build_strategy=build_strategy,
+                    exec_strategy=exec_strategy)
+        else:
+            self.train_program = self.main_program
 
     def generate_init_data(self):
         init_hidden = np.zeros(
@@ -572,7 +575,7 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
         res['learning_rate'] = self.generate_new_lr(epoch_id, device_count)
         return res
 
-    def train_an_epoch(self, epoch_id, batch_times):
+    def train_an_epoch(self, epoch_id, batch_times, use_program_cache=True):
         train_data_iter = self.reader.get_data_iter(self.config)
 
         total_loss = 0
@@ -597,7 +600,7 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
                 self.last_hidden.name, self.last_cell.name
             ],
-            use_program_cache=True)
+            use_program_cache=use_program_cache)
 
             batch_time = time.time() - batch_start_time
             batch_times.append(batch_time)
@@ -613,47 +616,53 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
             ppl = np.append(ppl, batch_ppl)
         return ppl
 
-    def train(self, config):
+    def train(self, config, parallel=True, use_program_cache=True):
         self.config = config
-        self.prepare_program(config)
+        self.prepare_program(config, parallel)
         total_time = 0.0
         ppl = np.zeros(shape=(0, config.batch_size))
         for epoch_id in range(config.max_epoch):
             batch_times = []
             epoch_start_time = time.time()
-            train_ppl = self.train_an_epoch(epoch_id, batch_times)
+            train_ppl = self.train_an_epoch(epoch_id, batch_times,
+                                            use_program_cache)
             epoch_time = time.time() - epoch_start_time
             total_time += epoch_time
             ppl = np.append(ppl, train_ppl)
         return ppl
 
-    def compare_padding_static_mode(self):
+    def compare_padding_static_mode(self, parallel=True,
+                                    use_program_cache=True):
         '''
-        Test that train ppl of padding mode is same to that of static mode 
+        Test that train ppl of padding mode is same to that of static mode
         '''
-        config = RnnConfig('test', 'padding')
+        config = RNNConfig('test', 'padding')
         with fluid.scope_guard(fluid.Scope()):
-            padding_rnn_ppl = self.train(config)
-        config = RnnConfig('test', 'static')
+            padding_rnn_ppl = self.train(config, parallel, use_program_cache)
+        config = RNNConfig('test', 'static')
         with fluid.scope_guard(fluid.Scope()):
-            static_rnn_ppl = self.train(config)
+            static_rnn_ppl = self.train(config, parallel, use_program_cache)
         self.assertTrue(
             np.isclose(
                 padding_rnn_ppl, static_rnn_ppl, rtol=0.001).all())
 
+
+class EagerDeletionPaddingRNNTest(PaddingRNNTestBase):
     def test_padding_mode_no_eager_deletion(self):
         '''
-        Test that train ppl of padding mode is same to that of static mode without eager deletion 
+        Test that train ppl of padding mode is same to that of static mode without eager deletion
         '''
         fluid.core._set_eager_deletion_mode(-1.0, 1.0, True)
-        self.compare_padding_static_mode()
+        # When parallel is True, use_program_cache does not make a difference.
+        self.compare_padding_static_mode(parallel=True, use_program_cache=True)
 
     def test_padding_mode_eager_deletion(self):
         '''
-        Test that train ppl of padding mode is same to that of static mode under eager deletion 
+        Test that train ppl of padding mode is same to that of static mode under eager deletion
         '''
         fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
-        self.compare_padding_static_mode()
+        # When parallel is True, use_program_cache does not make a difference.
+        self.compare_padding_static_mode(parallel=True, use_program_cache=True)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_executor_and_use_program_cache.py b/python/paddle/fluid/tests/unittests/test_executor_and_use_program_cache.py
index e1aaa82845..96d2317407 100644
--- a/python/paddle/fluid/tests/unittests/test_executor_and_use_program_cache.py
+++ b/python/paddle/fluid/tests/unittests/test_executor_and_use_program_cache.py
@@ -18,82 +18,130 @@ import unittest
 import numpy
 import paddle.fluid.core as core
-from paddle.fluid.executor import Executor
-from paddle.fluid.layers import mul, data
+import paddle.fluid as fluid
+from test_eager_deletion_padding_rnn import RNNConfig, PaddingRNNTestBase
 
 
 class TestExecutor(unittest.TestCase):
     def test_mul(self):
-        a = data(name='a', shape=[784], dtype='float32')
-        b = data(
-            name='b',
-            shape=[784, 100],
-            dtype='float32',
-            append_batch_size=False)
-        output = mul(x=a, y=b)
-        place = core.CPUPlace()
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            a = fluid.layers.data(name='a', shape=[784], dtype='float32')
+            b = fluid.layers.data(
+                name='b',
+                shape=[784, 100],
+                dtype='float32',
+                append_batch_size=False)
+            output = fluid.layers.mul(x=a, y=b)
+
+        # Compute with numpy
         a_np = numpy.random.random((100, 784)).astype('float32')
         b_np = numpy.random.random((784, 100)).astype('float32')
-        exe = Executor(place)
-        import time
-        use_cache = True
-        step_num = 3
-        run_time = 0.0
-        for i in range(step_num):
-            begin = time.time()
-            outs = exe.run(feed={'a': a_np,
-                                 'b': b_np},
-                           fetch_list=[output.name],
-                           use_program_cache=use_cache)
-            end = time.time()
-            run_time += end - begin
-            out = outs[0]
-            self.assertEqual((100, 100), out.shape)
-            self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
-        print("run time %f" % run_time)
-        use_cache = False
-        run_time = 0.0
-        for i in range(step_num):
-            begin = time.time()
-            outs = exe.run(feed={'a': a_np,
-                                 'b': b_np},
-                           fetch_list=[output.name],
-                           use_program_cache=use_cache)
-            end = time.time()
-            run_time += end - begin
-            out = outs[0]
-            self.assertEqual((100, 100), out.shape)
-            self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
-        print("run time %f" % run_time)
-        use_cache = True
-        run_time = 0.0
-        for i in range(step_num):
-            begin = time.time()
-            outs = exe.run(feed={'a': a_np,
-                                 'b': b_np},
-                           fetch_list=[output.name],
-                           use_program_cache=use_cache)
-            end = time.time()
-            run_time += end - begin
-            out = outs[0]
-            self.assertEqual((100, 100), out.shape)
-            self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
-        print("run time %f" % run_time)
-
-        use_cache = True
-        run_time = 0.0
-        for i in range(step_num):
-            begin = time.time()
-            outs = exe.run(feed={'a': a_np,
-                                 'b': b_np},
-                           fetch_list=[output],
-                           use_program_cache=use_cache)
-            end = time.time()
-            run_time += end - begin
-            out = outs[0]
-            self.assertEqual((100, 100), out.shape)
-            self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
-        print("run time %f" % run_time)
+        out_np = numpy.dot(a_np, b_np)
+
+        place = core.CPUPlace()
+        exe = fluid.Executor(place)
+
+        def _train(use_program_cache, max_iters=1):
+            import time
+
+            run_time = 0.0
+            for i in range(max_iters):
+                begin = time.time()
+                outs = exe.run(program=main_program,
+                               feed={'a': a_np,
+                                     'b': b_np},
+                               fetch_list=[output.name],
+                               use_program_cache=use_program_cache)
+                end = time.time()
+                run_time += end - begin
+                out = outs[0]
+                self.assertEqual((100, 100), out.shape)
+                self.assertTrue(numpy.allclose(out, out_np))
+            return run_time
+
+        max_iters = 3
+        run_time_with_cache = _train(
+            use_program_cache=True, max_iters=max_iters)
+        print("run time with program cache: %f" % run_time_with_cache)
+
+        run_time_without_cache = _train(
+            use_program_cache=False, max_iters=max_iters)
+        print("run time without program cache: %f" % run_time_without_cache)
+
+        run_time_with_cache = _train(
+            use_program_cache=True, max_iters=max_iters)
+        print("run time with program cache: %f" % run_time_with_cache)
+
+        run_time_with_cache = _train(
+            use_program_cache=True, max_iters=max_iters)
+        print("run time with program cache: %f" % run_time_with_cache)
+
+
+class ExecutorPaddingRNNTest(PaddingRNNTestBase):
+    def train_and_save_inference_program(self,
+                                         rnn_model="static",
+                                         parallel=True,
+                                         use_program_cache=True):
+        config = RNNConfig("test", rnn_model)
+        with fluid.scope_guard(fluid.Scope()):
+            self.train(config, parallel, use_program_cache)
+            fluid.io.save_inference_model(
+                main_program=self.main_program,
+                feeded_var_names=self.feed_order,
+                target_vars=[self.loss, self.last_hidden, self.last_cell],
+                executor=self.exe,
+                dirname="padding_rnn." + rnn_model + ".inference_model",
+                params_filename="__params__")
+
+    def test_inference_output(self):
+        for rnn_model in ["static", "padding"]:
+            # Set parallel to False to use the default executor.
+            self.train_and_save_inference_program(
+                rnn_model=rnn_model, parallel=True, use_program_cache=True)
+
+            x_np = numpy.random.random(
+                (self.config.batch_size, self.config.num_steps,
+                 1)).astype("int64")
+            y_np = numpy.random.random(
+                (self.config.batch_size * self.config.num_steps,
+                 1)).astype("int64")
+            init_hidden_np = numpy.random.random(
+                (self.config.num_layers, self.config.batch_size,
+                 self.config.hidden_size)).astype("float32")
+            init_cell_np = numpy.random.random(
+                (self.config.num_layers, self.config.batch_size,
+                 self.config.hidden_size)).astype("float32")
+
+            for use_program_cache in [False, True]:
+                with fluid.scope_guard(fluid.Scope()):
+                    save_dirname = "padding_rnn." + rnn_model + ".inference_model"
+                    [inference_program, feed_target_names,
+                     fetch_targets] = fluid.io.load_inference_model(
+                         save_dirname, self.exe, params_filename="__params__")
+
+                    results = self.exe.run(program=inference_program,
+                                           feed={
+                                               "x": x_np,
+                                               "y": y_np,
+                                               "init_hidden": init_hidden_np,
+                                               "init_cell": init_cell_np
+                                           },
+                                           fetch_list=fetch_targets,
+                                           use_program_cache=use_program_cache)
+                    if use_program_cache is True:
+                        results_with_cache = results
+                    else:
+                        results_without_cache = results
+            self.assertEqual(
+                len(results_with_cache), len(results_without_cache))
+            for i in range(len(results_with_cache)):
+                self.assertEqual(results_with_cache[i].shape,
+                                 results_without_cache[i].shape)
+                self.assertTrue(
+                    numpy.allclose(results_with_cache[i], results_without_cache[
+                        i]))
 
 
 if __name__ == '__main__':
--
GitLab
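
For reference, here is a minimal usage sketch (not part of the patch) of the cached-execution path this change refines, assuming the Fluid 1.x static-graph API used in the tests above; the program and shapes are borrowed from the refactored test_mul. The first run with use_program_cache=True builds the feed/fetch program, the prepared context, and a sub-scope and stores them under one cache key; later runs with the same feed names and fetch list reuse them, while use_program_cache=False always takes the uncached path.

# Minimal sketch, assuming the Fluid 1.x API shown in the tests above.
import numpy
import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    a = fluid.layers.data(name='a', shape=[784], dtype='float32')
    b = fluid.layers.data(
        name='b', shape=[784, 100], dtype='float32', append_batch_size=False)
    output = fluid.layers.mul(x=a, y=b)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)

a_np = numpy.random.random((100, 784)).astype('float32')
b_np = numpy.random.random((784, 100)).astype('float32')

for _ in range(3):
    # Same program, feed names and fetch list -> same cache key, so the
    # cached (program, context, scope) are reused after the first iteration.
    out, = exe.run(main_program,
                   feed={'a': a_np, 'b': b_np},
                   fetch_list=[output.name],
                   use_program_cache=True)

assert numpy.allclose(out, numpy.dot(a_np, b_np))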