未验证 提交 d5239109 编写于 作者: G guru4elephant 提交者: GitHub

fix prepare context redundant code problem, optimize executor by cach… (#17743)

* fix prepare context redundant code problem, optimize executor by caching create_varaiables
test=develop

* cache sub_scope, program, var when use_program_cache=True is set

* make fetch_list runable with variables, add more unittest for use_program_cache
上级 2c58f1a8
...@@ -247,20 +247,7 @@ static bool has_fetch_operators( ...@@ -247,20 +247,7 @@ static bool has_fetch_operators(
std::unique_ptr<ExecutorPrepareContext> Executor::PrepareCtxCache( std::unique_ptr<ExecutorPrepareContext> Executor::PrepareCtxCache(
const ProgramDesc& program, int block_id, const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) { const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
std::unique_ptr<ExecutorPrepareContext> ctx; return Prepare(program, block_id, skip_ref_cnt_vars, force_disable_gc);
ctx.reset(new ExecutorPrepareContext(program, block_id));
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
paddle::operators::NgraphEngine::FuseNgraphOps(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
}
#endif
ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
return ctx;
} }
void Executor::Run(const ProgramDesc& program, Scope* scope, void Executor::Run(const ProgramDesc& program, Scope* scope,
......
...@@ -943,6 +943,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -943,6 +943,8 @@ All parameter, weight, gradient are variables in Paddle.
}) })
.def("prepare_ctx_cache", &Executor::PrepareCtxCache, .def("prepare_ctx_cache", &Executor::PrepareCtxCache,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("create_variables", &Executor::CreateVariables,
py::call_guard<py::gil_scoped_release>())
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars, int block_id, bool create_local_scope, bool create_vars,
const std::vector<std::string> &fetch_vars) { const std::vector<std::string> &fetch_vars) {
......
...@@ -361,11 +361,19 @@ class Executor(object): ...@@ -361,11 +361,19 @@ class Executor(object):
self.place = place self.place = place
self.program_caches = dict() self.program_caches = dict()
self.ctx_caches = dict() self.ctx_caches = dict()
self.scope_caches = dict()
self.var_caches = dict()
p = core.Place() p = core.Place()
p.set_place(self.place) p.set_place(self.place)
self._default_executor = core.Executor(p) self._default_executor = core.Executor(p)
self._closed = False self._closed = False
def _get_var_cache(self, program_cache_key):
return self.var_caches.get(program_cache_key, None)
def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None)
def _get_ctx_cache(self, program_cache_key): def _get_ctx_cache(self, program_cache_key):
return self.ctx_caches.get(program_cache_key, None) return self.ctx_caches.get(program_cache_key, None)
...@@ -378,6 +386,12 @@ class Executor(object): ...@@ -378,6 +386,12 @@ class Executor(object):
def _add_ctx_cache(self, ctx_cache_key, ctx): def _add_ctx_cache(self, ctx_cache_key, ctx):
self.ctx_caches[ctx_cache_key] = ctx self.ctx_caches[ctx_cache_key] = ctx
def _add_scope_cache(self, scope_cache_key, scope):
self.scope_caches[scope_cache_key] = scope
def _add_var_cache(self, var_cache_key, var):
self.var_caches[var_cache_key] = var
def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name, def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
fetch_var_name): fetch_var_name):
tmp_program = program.clone() tmp_program = program.clone()
...@@ -689,10 +703,12 @@ class Executor(object): ...@@ -689,10 +703,12 @@ class Executor(object):
"Executor requires Program as its Parameter. But you passed in %s" "Executor requires Program as its Parameter. But you passed in %s"
% (type(program))) % (type(program)))
cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
if use_program_cache: if use_program_cache:
cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
cached_program = self._get_program_cache(cache_key) cached_program = self._get_program_cache(cache_key)
cached_ctx = self._get_ctx_cache(cache_key) cached_ctx = self._get_ctx_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key)
cached_var = self._get_var_cache(cache_key)
if cached_program is None: if cached_program is None:
cached_program = self._add_feed_fetch_ops( cached_program = self._add_feed_fetch_ops(
program=program, program=program,
...@@ -701,13 +717,25 @@ class Executor(object): ...@@ -701,13 +717,25 @@ class Executor(object):
feed_var_name=feed_var_name, feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name) fetch_var_name=fetch_var_name)
self._add_program_cache(cache_key, cached_program) self._add_program_cache(cache_key, cached_program)
fetch_list_str = list(map(_to_name_str, fetch_list))
cached_ctx = self._default_executor.prepare_ctx_cache( cached_ctx = self._default_executor.prepare_ctx_cache(
cached_program.desc, 0, fetch_list, False) cached_program.desc, 0, fetch_list_str, False)
cached_var = self._default_executor.create_variables(
cached_program.desc, scope, 0)
# currently, we cache program, vars, sub_scope here
# we suppose that in a life cycle of training, a user
# will not create many programs. So, here the basic
# rule of caching is to cache all unseen (program, var, scope)
# when a user use use_program_cache.
cached_scope = scope.new_scope()
self._add_ctx_cache(cache_key, cached_ctx) self._add_ctx_cache(cache_key, cached_ctx)
self._add_var_cache(cache_key, cached_var)
self._add_scope_cache(cache_key, cached_scope)
program = cached_program program = cached_program
ctx = cached_ctx ctx = cached_ctx
scope = cached_scope
var = cached_var
else: else:
self.program_caches.pop(cache_key, None)
program = self._add_feed_fetch_ops( program = self._add_feed_fetch_ops(
program=program, program=program,
feed=feed, feed=feed,
...@@ -719,7 +747,7 @@ class Executor(object): ...@@ -719,7 +747,7 @@ class Executor(object):
if not use_program_cache: if not use_program_cache:
exe.run(program.desc, scope, 0, True, True, fetch_var_name) exe.run(program.desc, scope, 0, True, True, fetch_var_name)
else: else:
exe.run_cached_prepared_ctx(ctx, scope, True, True, False) exe.run_cached_prepared_ctx(ctx, scope, False, False, False)
outs = self._fetch_data(fetch_list, fetch_var_name, scope) outs = self._fetch_data(fetch_list, fetch_var_name, scope)
if return_numpy: if return_numpy:
outs = as_numpy(outs) outs = as_numpy(outs)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy
import paddle.fluid.core as core
from paddle.fluid.executor import Executor
from paddle.fluid.layers import mul, data
class TestExecutor(unittest.TestCase):
def test_mul(self):
a = data(name='a', shape=[784], dtype='float32')
b = data(
name='b',
shape=[784, 100],
dtype='float32',
append_batch_size=False)
output = mul(x=a, y=b)
place = core.CPUPlace()
a_np = numpy.random.random((100, 784)).astype('float32')
b_np = numpy.random.random((784, 100)).astype('float32')
exe = Executor(place)
import time
use_cache = True
step_num = 3
run_time = 0.0
for i in range(step_num):
begin = time.time()
outs = exe.run(feed={'a': a_np,
'b': b_np},
fetch_list=[output.name],
use_program_cache=use_cache)
end = time.time()
run_time += end - begin
out = outs[0]
self.assertEqual((100, 100), out.shape)
self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
print("run time %f" % run_time)
use_cache = False
run_time = 0.0
for i in range(step_num):
begin = time.time()
outs = exe.run(feed={'a': a_np,
'b': b_np},
fetch_list=[output.name],
use_program_cache=use_cache)
end = time.time()
run_time += end - begin
out = outs[0]
self.assertEqual((100, 100), out.shape)
self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
print("run time %f" % run_time)
use_cache = True
run_time = 0.0
for i in range(step_num):
begin = time.time()
outs = exe.run(feed={'a': a_np,
'b': b_np},
fetch_list=[output.name],
use_program_cache=use_cache)
end = time.time()
run_time += end - begin
out = outs[0]
self.assertEqual((100, 100), out.shape)
self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
print("run time %f" % run_time)
use_cache = True
run_time = 0.0
for i in range(step_num):
begin = time.time()
outs = exe.run(feed={'a': a_np,
'b': b_np},
fetch_list=[output],
use_program_cache=use_cache)
end = time.time()
run_time += end - begin
out = outs[0]
self.assertEqual((100, 100), out.shape)
self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
print("run time %f" % run_time)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册