diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index abeca93e5f1108301e2f2cd140a4a495062f38ff..0534067254613008d7ff4e531ecb7368a0ad6925 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -247,20 +247,7 @@ static bool has_fetch_operators( std::unique_ptr Executor::PrepareCtxCache( const ProgramDesc& program, int block_id, const std::vector& skip_ref_cnt_vars, bool force_disable_gc) { - std::unique_ptr ctx; - ctx.reset(new ExecutorPrepareContext(program, block_id)); - auto& block = program.Block(block_id); - for (auto& op_desc : block.AllOps()) { - ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); - } -#ifdef PADDLE_WITH_NGRAPH - if (FLAGS_use_ngraph) { - paddle::operators::NgraphEngine::FuseNgraphOps( - ctx->prog_.Block(ctx->block_id_), &ctx->ops_); - } -#endif - ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc); - return ctx; + return Prepare(program, block_id, skip_ref_cnt_vars, force_disable_gc); } void Executor::Run(const ProgramDesc& program, Scope* scope, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a9bfb99333d6b8124e9a89b97ed058cd4f582132..9fc8801b27f932139a5e39535bb980a4464558c4 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -943,6 +943,8 @@ All parameter, weight, gradient are variables in Paddle. }) .def("prepare_ctx_cache", &Executor::PrepareCtxCache, py::call_guard()) + .def("create_variables", &Executor::CreateVariables, + py::call_guard()) .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, int block_id, bool create_local_scope, bool create_vars, const std::vector &fetch_vars) { diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index c44535f3dc6d762296a76f9a2248cbd0cffd32d4..dfa9a0f4d37291d0411e971c553dbcf0489d605d 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -361,11 +361,19 @@ class Executor(object): self.place = place self.program_caches = dict() self.ctx_caches = dict() + self.scope_caches = dict() + self.var_caches = dict() p = core.Place() p.set_place(self.place) self._default_executor = core.Executor(p) self._closed = False + def _get_var_cache(self, program_cache_key): + return self.var_caches.get(program_cache_key, None) + + def _get_scope_cache(self, program_cache_key): + return self.scope_caches.get(program_cache_key, None) + def _get_ctx_cache(self, program_cache_key): return self.ctx_caches.get(program_cache_key, None) @@ -378,6 +386,12 @@ class Executor(object): def _add_ctx_cache(self, ctx_cache_key, ctx): self.ctx_caches[ctx_cache_key] = ctx + def _add_scope_cache(self, scope_cache_key, scope): + self.scope_caches[scope_cache_key] = scope + + def _add_var_cache(self, var_cache_key, var): + self.var_caches[var_cache_key] = var + def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name, fetch_var_name): tmp_program = program.clone() @@ -689,10 +703,12 @@ class Executor(object): "Executor requires Program as its Parameter. But you passed in %s" % (type(program))) - cache_key = _get_strong_program_cache_key(program, feed, fetch_list) if use_program_cache: + cache_key = _get_strong_program_cache_key(program, feed, fetch_list) cached_program = self._get_program_cache(cache_key) cached_ctx = self._get_ctx_cache(cache_key) + cached_scope = self._get_scope_cache(cache_key) + cached_var = self._get_var_cache(cache_key) if cached_program is None: cached_program = self._add_feed_fetch_ops( program=program, @@ -701,13 +717,25 @@ class Executor(object): feed_var_name=feed_var_name, fetch_var_name=fetch_var_name) self._add_program_cache(cache_key, cached_program) + fetch_list_str = list(map(_to_name_str, fetch_list)) cached_ctx = self._default_executor.prepare_ctx_cache( - cached_program.desc, 0, fetch_list, False) + cached_program.desc, 0, fetch_list_str, False) + cached_var = self._default_executor.create_variables( + cached_program.desc, scope, 0) + # currently, we cache program, vars, sub_scope here + # we suppose that in a life cycle of training, a user + # will not create many programs. So, here the basic + # rule of caching is to cache all unseen (program, var, scope) + # when a user use use_program_cache. + cached_scope = scope.new_scope() self._add_ctx_cache(cache_key, cached_ctx) + self._add_var_cache(cache_key, cached_var) + self._add_scope_cache(cache_key, cached_scope) program = cached_program ctx = cached_ctx + scope = cached_scope + var = cached_var else: - self.program_caches.pop(cache_key, None) program = self._add_feed_fetch_ops( program=program, feed=feed, @@ -719,7 +747,7 @@ class Executor(object): if not use_program_cache: exe.run(program.desc, scope, 0, True, True, fetch_var_name) else: - exe.run_cached_prepared_ctx(ctx, scope, True, True, False) + exe.run_cached_prepared_ctx(ctx, scope, False, False, False) outs = self._fetch_data(fetch_list, fetch_var_name, scope) if return_numpy: outs = as_numpy(outs) diff --git a/python/paddle/fluid/tests/unittests/test_executor_and_use_program_cache.py b/python/paddle/fluid/tests/unittests/test_executor_and_use_program_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..e1aaa82845bac6c02b8825adbbfa6dcf33d97894 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_executor_and_use_program_cache.py @@ -0,0 +1,100 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import numpy +import paddle.fluid.core as core +from paddle.fluid.executor import Executor +from paddle.fluid.layers import mul, data + + +class TestExecutor(unittest.TestCase): + def test_mul(self): + a = data(name='a', shape=[784], dtype='float32') + b = data( + name='b', + shape=[784, 100], + dtype='float32', + append_batch_size=False) + output = mul(x=a, y=b) + place = core.CPUPlace() + a_np = numpy.random.random((100, 784)).astype('float32') + b_np = numpy.random.random((784, 100)).astype('float32') + exe = Executor(place) + import time + use_cache = True + step_num = 3 + run_time = 0.0 + for i in range(step_num): + begin = time.time() + outs = exe.run(feed={'a': a_np, + 'b': b_np}, + fetch_list=[output.name], + use_program_cache=use_cache) + end = time.time() + run_time += end - begin + out = outs[0] + self.assertEqual((100, 100), out.shape) + self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np))) + print("run time %f" % run_time) + use_cache = False + run_time = 0.0 + for i in range(step_num): + begin = time.time() + outs = exe.run(feed={'a': a_np, + 'b': b_np}, + fetch_list=[output.name], + use_program_cache=use_cache) + end = time.time() + run_time += end - begin + out = outs[0] + self.assertEqual((100, 100), out.shape) + self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np))) + print("run time %f" % run_time) + use_cache = True + run_time = 0.0 + for i in range(step_num): + begin = time.time() + outs = exe.run(feed={'a': a_np, + 'b': b_np}, + fetch_list=[output.name], + use_program_cache=use_cache) + end = time.time() + run_time += end - begin + out = outs[0] + self.assertEqual((100, 100), out.shape) + self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np))) + print("run time %f" % run_time) + + use_cache = True + run_time = 0.0 + for i in range(step_num): + begin = time.time() + outs = exe.run(feed={'a': a_np, + 'b': b_np}, + fetch_list=[output], + use_program_cache=use_cache) + end = time.time() + run_time += end - begin + out = outs[0] + self.assertEqual((100, 100), out.shape) + self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np))) + print("run time %f" % run_time) + + +if __name__ == '__main__': + unittest.main()