Unverified commit 16e4d026, authored by Yiqun Liu and committed by GitHub

Refine the cache of program, context and scope in executor. (#18483)

* Refine the cache of program, context and scope in executor.
test=develop

* Refine the unittest test_executor_and_use_program_cache.

* Add a test of the PaddingRNN with use_program_cache=True.
test=develop

* Remove a check.
test=develop

* Refine the unittest to check whether it is correct when setting use_program_cache=True.
test=develop
Parent: b4897600
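
In user code, the caches this commit refines are driven entirely by the `use_program_cache` argument of `Executor.run`. Below is a minimal sketch of that usage, assuming the fluid 1.x Python API of this era and mirroring the unittest further down; the program and variable names are illustrative:

```python
import numpy
import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    a = fluid.layers.data(name='a', shape=[784], dtype='float32')
    b = fluid.layers.data(
        name='b', shape=[784, 100], dtype='float32', append_batch_size=False)
    output = fluid.layers.mul(x=a, y=b)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)

a_np = numpy.random.random((100, 784)).astype('float32')
b_np = numpy.random.random((784, 100)).astype('float32')

# With use_program_cache=True, repeated runs of the same (program, feed,
# fetch) combination reuse the cached feed/fetch program, the prepared
# context and the sub-scope instead of rebuilding them on every call.
for _ in range(3):
    out, = exe.run(main_program,
                   feed={'a': a_np, 'b': b_np},
                   fetch_list=[output.name],
                   use_program_cache=True)
```

On the first such run the executor builds the feed/fetch program, prepares an execution context, and creates the variables in a dedicated sub-scope; every later run with the same cache key reuses all three, which is what the diff below restructures.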
...
@@ -287,12 +287,6 @@ static bool has_fetch_operators(
   return fetch_count > 0;
 }
 
-std::unique_ptr<ExecutorPrepareContext> Executor::PrepareCtxCache(
-    const ProgramDesc& program, int block_id,
-    const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
-  return Prepare(program, block_id, skip_ref_cnt_vars, force_disable_gc);
-}
-
 void Executor::Run(const ProgramDesc& program, Scope* scope,
                    std::map<std::string, const LoDTensor*>* feed_targets,
                    std::map<std::string, LoDTensor*>* fetch_targets,
...
...
@@ -95,12 +95,6 @@ class Executor {
                   const std::string& feed_holder_name = "feed",
                   const std::string& fetch_holder_name = "fetch");
 
-  std::unique_ptr<ExecutorPrepareContext> PrepareCtxCache(
-      const ProgramDesc& program, int block_id,
-      const std::vector<std::string>& skip_ref_cnt_vars =
-          std::vector<std::string>(),
-      bool force_disable_gc = false);
-
   static std::unique_ptr<ExecutorPrepareContext> Prepare(
       const ProgramDesc& program, int block_id,
       const std::vector<std::string>& skip_ref_cnt_vars =
...
...
@@ -1389,7 +1389,7 @@ All parameter, weight, gradient are variables in Paddle.
                             create_local_scope, create_vars,
                             feed_holder_name, fetch_holder_name);
            })
-      .def("run_cached_prepared_ctx",
+      .def("run_prepared_ctx",
            [](Executor &self, ExecutorPrepareContext *ctx, Scope *scope,
               bool create_local_scope = true, bool create_vars = true,
               bool keep_kids = false) {
...
@@ -1397,10 +1397,16 @@ All parameter, weight, gradient are variables in Paddle.
              self.RunPreparedContext(ctx, scope, create_local_scope,
                                      create_vars, keep_kids);
            })
-      .def("prepare_ctx_cache", &Executor::PrepareCtxCache,
-           py::call_guard<py::gil_scoped_release>())
-      .def("create_variables", &Executor::CreateVariables,
-           py::call_guard<py::gil_scoped_release>())
+      .def("prepare",
+           [](Executor &self, const ProgramDesc &program, int block_id,
+              const std::vector<std::string> &skip_ref_cnt_vars =
+                  std::vector<std::string>(),
+              bool force_disable_gc = false) {
+             pybind11::gil_scoped_release release;
+             return self.Prepare(program, block_id, skip_ref_cnt_vars,
+                                 force_disable_gc);
+           })
+      .def("create_variables", &Executor::CreateVariables)
       .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
                      int block_id, bool create_local_scope, bool create_vars,
                      const std::vector<std::string> &fetch_vars) {
...
...
@@ -489,9 +489,6 @@ class Executor(object):
         self._default_executor = core.Executor(p)
         self._closed = False
 
-    def _get_var_cache(self, program_cache_key):
-        return self.var_caches.get(program_cache_key, None)
-
     def _get_scope_cache(self, program_cache_key):
         return self.scope_caches.get(program_cache_key, None)
 
...
@@ -510,9 +507,6 @@ class Executor(object):
     def _add_scope_cache(self, scope_cache_key, scope):
         self.scope_caches[scope_cache_key] = scope
 
-    def _add_var_cache(self, var_cache_key, var):
-        self.var_caches[var_cache_key] = var
-
     def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
                             fetch_var_name):
         tmp_program = program.clone()
...
@@ -853,7 +847,6 @@ class Executor(object):
             cached_program = self._get_program_cache(cache_key)
             cached_ctx = self._get_ctx_cache(cache_key)
             cached_scope = self._get_scope_cache(cache_key)
-            cached_var = self._get_var_cache(cache_key)
             if cached_program is None:
                 cached_program = self._add_feed_fetch_ops(
                     program=program,
...
@@ -863,23 +856,21 @@ class Executor(object):
                     fetch_var_name=fetch_var_name)
                 self._add_program_cache(cache_key, cached_program)
                 fetch_list_str = list(map(_to_name_str, fetch_list))
-                cached_ctx = self._default_executor.prepare_ctx_cache(
+                cached_ctx = self._default_executor.prepare(
                     cached_program.desc, 0, fetch_list_str, False)
-                cached_var = self._default_executor.create_variables(
-                    cached_program.desc, scope, 0)
                 # currently, we cache program, vars, sub_scope here
                 # we suppose that in a life cycle of training, a user
                 # will not create many programs. So, here the basic
                 # rule of caching is to cache all unseen (program, var, scope)
                 # when a user use use_program_cache.
                 cached_scope = scope.new_scope()
+                self._default_executor.create_variables(cached_program.desc,
+                                                        cached_scope, 0)
                 self._add_ctx_cache(cache_key, cached_ctx)
-                self._add_var_cache(cache_key, cached_var)
                 self._add_scope_cache(cache_key, cached_scope)
             program = cached_program
             ctx = cached_ctx
             scope = cached_scope
-            var = cached_var
         else:
             program = self._add_feed_fetch_ops(
                 program=program,
...
@@ -893,8 +884,8 @@ class Executor(object):
             self._default_executor.run(program.desc, scope, 0, True, True,
                                        fetch_var_name)
         else:
-            self._default_executor.run_cached_prepared_ctx(ctx, scope, False,
-                                                           False, False)
+            self._default_executor.run_prepared_ctx(ctx, scope, False, False,
+                                                    False)
         arr = scope.find_var(fetch_var_name).get_lod_tensor_array()
         tensors = arr._move_to_list()
         if return_numpy:
...
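
Taken together, the Python-side change caches three objects per program cache key — the feed/fetch-augmented program, the context returned by the new `prepare` binding, and a sub-scope populated once by `create_variables` — and every subsequent run goes straight to `run_prepared_ctx`. The sketch below restates that flow; it is not PaddlePaddle source, and `add_feed_fetch_ops` is a hypothetical stand-in for the executor's private helper:

```python
class ProgramCacheSketch(object):
    """A minimal sketch of the caching scheme this commit settles on."""

    def __init__(self, core_executor):
        self._exe = core_executor  # the pybind-wrapped core.Executor
        self._programs = {}        # cache_key -> program with feed/fetch ops
        self._ctxs = {}            # cache_key -> prepared execution context
        self._scopes = {}          # cache_key -> sub-scope holding variables

    def run(self, program, scope, fetch_list, cache_key):
        if cache_key not in self._programs:
            # First sight of this key: build and cache all three pieces.
            cached_program = add_feed_fetch_ops(program, fetch_list)  # hypothetical helper
            self._programs[cache_key] = cached_program
            # Fetch var names, as _to_name_str produces in executor.py.
            self._ctxs[cache_key] = self._exe.prepare(
                cached_program.desc, 0, [str(v) for v in fetch_list], False)
            cached_scope = scope.new_scope()
            self._exe.create_variables(cached_program.desc, cached_scope, 0)
            self._scopes[cache_key] = cached_scope
        # Every later call reuses the prepared context and the sub-scope.
        self._exe.run_prepared_ctx(self._ctxs[cache_key],
                                   self._scopes[cache_key], False, False,
                                   False)
```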
...
@@ -32,7 +32,7 @@ from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN
 os.environ["CPU_NUM"] = "1"
 
 
-class RnnConfig(object):
+class RNNConfig(object):
     def __init__(self, model_type, rnn_model):
         self.model_type = model_type
         self.rnn_model = rnn_model
...
@@ -478,11 +478,12 @@ def lm_model(hidden_size,
     return loss, last_hidden, last_cell, feeding_list
 
 
-class EagerDeletionPaddingRnnTest(unittest.TestCase):
+class PaddingRNNTestBase(unittest.TestCase):
     def setUp(self):
         self.reader = Reader()
+        self.device_count = 1
 
-    def prepare_program(self, config):
+    def prepare_program(self, config, parallel=True):
         self.main_program = fluid.Program()
         self.startup_program = fluid.Program()
         self.startup_program.random_seed = config.random_seed
...
@@ -517,21 +518,23 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
         self.exe = Executor(fluid.CPUPlace())
         self.exe.run(self.startup_program)
 
-        self.device_count = 1
-        exec_strategy = fluid.ExecutionStrategy()
-        exec_strategy.num_threads = self.device_count
-        exec_strategy.num_iteration_per_drop_scope = 100
-        build_strategy = fluid.BuildStrategy()
-        build_strategy.enable_inplace = True
-        build_strategy.memory_optimize = False
-        build_strategy.fuse_all_optimizer_ops = True
-        self.train_program = fluid.compiler.CompiledProgram(
-            self.main_program).with_data_parallel(
-                loss_name=self.loss.name,
-                build_strategy=build_strategy,
-                exec_strategy=exec_strategy)
+        if parallel:
+            exec_strategy = fluid.ExecutionStrategy()
+            exec_strategy.num_threads = self.device_count
+            exec_strategy.num_iteration_per_drop_scope = 100
+            build_strategy = fluid.BuildStrategy()
+            build_strategy.enable_inplace = True
+            build_strategy.memory_optimize = False
+            build_strategy.fuse_all_optimizer_ops = True
+            self.train_program = fluid.compiler.CompiledProgram(
+                self.main_program).with_data_parallel(
+                    loss_name=self.loss.name,
+                    build_strategy=build_strategy,
+                    exec_strategy=exec_strategy)
+        else:
+            self.train_program = self.main_program
 
     def generate_init_data(self):
         init_hidden = np.zeros(
...
@@ -572,7 +575,7 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
         res['learning_rate'] = self.generate_new_lr(epoch_id, device_count)
         return res
 
-    def train_an_epoch(self, epoch_id, batch_times):
+    def train_an_epoch(self, epoch_id, batch_times, use_program_cache=True):
         train_data_iter = self.reader.get_data_iter(self.config)
 
         total_loss = 0
...
@@ -597,7 +600,7 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
                     self.last_hidden.name,
                     self.last_cell.name
                 ],
-                use_program_cache=True)
+                use_program_cache=use_program_cache)
 
             batch_time = time.time() - batch_start_time
             batch_times.append(batch_time)
...
@@ -613,47 +616,53 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
             ppl = np.append(ppl, batch_ppl)
         return ppl
 
-    def train(self, config):
+    def train(self, config, parallel=True, use_program_cache=True):
         self.config = config
-        self.prepare_program(config)
+        self.prepare_program(config, parallel)
         total_time = 0.0
         ppl = np.zeros(shape=(0, config.batch_size))
         for epoch_id in range(config.max_epoch):
             batch_times = []
             epoch_start_time = time.time()
-            train_ppl = self.train_an_epoch(epoch_id, batch_times)
+            train_ppl = self.train_an_epoch(epoch_id, batch_times,
+                                            use_program_cache)
             epoch_time = time.time() - epoch_start_time
             total_time += epoch_time
             ppl = np.append(ppl, train_ppl)
         return ppl
 
-    def compare_padding_static_mode(self):
+    def compare_padding_static_mode(self, parallel=True,
+                                    use_program_cache=True):
         '''
         Test that train ppl of padding mode is same to that of static mode
         '''
-        config = RnnConfig('test', 'padding')
+        config = RNNConfig('test', 'padding')
         with fluid.scope_guard(fluid.Scope()):
-            padding_rnn_ppl = self.train(config)
+            padding_rnn_ppl = self.train(config, parallel, use_program_cache)
-        config = RnnConfig('test', 'static')
+        config = RNNConfig('test', 'static')
         with fluid.scope_guard(fluid.Scope()):
-            static_rnn_ppl = self.train(config)
+            static_rnn_ppl = self.train(config, parallel, use_program_cache)
         self.assertTrue(
             np.isclose(
                 padding_rnn_ppl, static_rnn_ppl, rtol=0.001).all())
 
+
+class EagerDeletionPaddingRNNTest(PaddingRNNTestBase):
     def test_padding_mode_no_eager_deletion(self):
         '''
         Test that train ppl of padding mode is same to that of static mode without eager deletion
         '''
         fluid.core._set_eager_deletion_mode(-1.0, 1.0, True)
-        self.compare_padding_static_mode()
+        # When parallel is True, use_program_cache does not make a difference.
+        self.compare_padding_static_mode(parallel=True, use_program_cache=True)
 
     def test_padding_mode_eager_deletion(self):
         '''
         Test that train ppl of padding mode is same to that of static mode under eager deletion
         '''
         fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
-        self.compare_padding_static_mode()
+        # When parallel is True, use_program_cache does not make a difference.
+        self.compare_padding_static_mode(parallel=True, use_program_cache=True)
 
 if __name__ == '__main__':
...
...
@@ -18,82 +18,130 @@ import unittest
 
 import numpy
 import paddle.fluid.core as core
-from paddle.fluid.executor import Executor
-from paddle.fluid.layers import mul, data
+import paddle.fluid as fluid
+from test_eager_deletion_padding_rnn import RNNConfig, PaddingRNNTestBase
 
 
 class TestExecutor(unittest.TestCase):
     def test_mul(self):
-        a = data(name='a', shape=[784], dtype='float32')
-        b = data(
-            name='b',
-            shape=[784, 100],
-            dtype='float32',
-            append_batch_size=False)
-        output = mul(x=a, y=b)
-        place = core.CPUPlace()
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            a = fluid.layers.data(name='a', shape=[784], dtype='float32')
+            b = fluid.layers.data(
+                name='b',
+                shape=[784, 100],
+                dtype='float32',
+                append_batch_size=False)
+            output = fluid.layers.mul(x=a, y=b)
+
+        # Compute with numpy
         a_np = numpy.random.random((100, 784)).astype('float32')
         b_np = numpy.random.random((784, 100)).astype('float32')
-        exe = Executor(place)
-        import time
-        use_cache = True
-        step_num = 3
-        run_time = 0.0
-        for i in range(step_num):
-            begin = time.time()
-            outs = exe.run(feed={'a': a_np,
-                                 'b': b_np},
-                           fetch_list=[output.name],
-                           use_program_cache=use_cache)
-            end = time.time()
-            run_time += end - begin
-            out = outs[0]
-            self.assertEqual((100, 100), out.shape)
-            self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
-        print("run time %f" % run_time)
-        use_cache = False
-        run_time = 0.0
-        for i in range(step_num):
-            begin = time.time()
-            outs = exe.run(feed={'a': a_np,
-                                 'b': b_np},
-                           fetch_list=[output.name],
-                           use_program_cache=use_cache)
-            end = time.time()
-            run_time += end - begin
-            out = outs[0]
-            self.assertEqual((100, 100), out.shape)
-            self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
-        print("run time %f" % run_time)
-        use_cache = True
-        run_time = 0.0
-        for i in range(step_num):
-            begin = time.time()
-            outs = exe.run(feed={'a': a_np,
-                                 'b': b_np},
-                           fetch_list=[output.name],
-                           use_program_cache=use_cache)
-            end = time.time()
-            run_time += end - begin
-            out = outs[0]
-            self.assertEqual((100, 100), out.shape)
-            self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
-        print("run time %f" % run_time)
-        use_cache = True
-        run_time = 0.0
-        for i in range(step_num):
-            begin = time.time()
-            outs = exe.run(feed={'a': a_np,
-                                 'b': b_np},
-                           fetch_list=[output],
-                           use_program_cache=use_cache)
-            end = time.time()
-            run_time += end - begin
-            out = outs[0]
-            self.assertEqual((100, 100), out.shape)
-            self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
-        print("run time %f" % run_time)
+        out_np = numpy.dot(a_np, b_np)
+
+        place = core.CPUPlace()
+        exe = fluid.Executor(place)
+
+        def _train(use_program_cache, max_iters=1):
+            import time
+
+            run_time = 0.0
+            for i in range(max_iters):
+                begin = time.time()
+                outs = exe.run(program=main_program,
+                               feed={'a': a_np,
+                                     'b': b_np},
+                               fetch_list=[output.name],
+                               use_program_cache=use_program_cache)
+                end = time.time()
+                run_time += end - begin
+                out = outs[0]
+                self.assertEqual((100, 100), out.shape)
+                self.assertTrue(numpy.allclose(out, out_np))
+            return run_time
+
+        max_iters = 3
+        run_time_with_cache = _train(
+            use_program_cache=True, max_iters=max_iters)
+        print("run time with program cache: %f" % run_time_with_cache)
+
+        run_time_without_cache = _train(
+            use_program_cache=False, max_iters=max_iters)
+        print("run time without program cache: %f" % run_time_without_cache)
+
+        run_time_with_cache = _train(
+            use_program_cache=True, max_iters=max_iters)
+        print("run time with program cache: %f" % run_time_with_cache)
+
+        run_time_with_cache = _train(
+            use_program_cache=True, max_iters=max_iters)
+        print("run time with program cache: %f" % run_time_with_cache)
+
+
+class ExecutorPaddingRNNTest(PaddingRNNTestBase):
+    def train_and_save_inference_program(self,
+                                         rnn_model="static",
+                                         parallel=True,
+                                         use_program_cache=True):
+        config = RNNConfig("test", rnn_model)
+        with fluid.scope_guard(fluid.Scope()):
+            self.train(config, parallel, use_program_cache)
+            fluid.io.save_inference_model(
+                main_program=self.main_program,
+                feeded_var_names=self.feed_order,
+                target_vars=[self.loss, self.last_hidden, self.last_cell],
+                executor=self.exe,
+                dirname="padding_rnn." + rnn_model + ".inference_model",
+                params_filename="__params__")
+
+    def test_inference_output(self):
+        for rnn_model in ["static", "padding"]:
+            # Set parallel to False to use the default executor.
+            self.train_and_save_inference_program(
+                rnn_model=rnn_model, parallel=True, use_program_cache=True)
+
+            x_np = numpy.random.random(
+                (self.config.batch_size, self.config.num_steps,
+                 1)).astype("int64")
+            y_np = numpy.random.random(
+                (self.config.batch_size * self.config.num_steps,
+                 1)).astype("int64")
+            init_hidden_np = numpy.random.random(
+                (self.config.num_layers, self.config.batch_size,
+                 self.config.hidden_size)).astype("float32")
+            init_cell_np = numpy.random.random(
+                (self.config.num_layers, self.config.batch_size,
+                 self.config.hidden_size)).astype("float32")
+
+            for use_program_cache in [False, True]:
+                with fluid.scope_guard(fluid.Scope()):
+                    save_dirname = "padding_rnn." + rnn_model + ".inference_model"
+                    [inference_program, feed_target_names,
+                     fetch_targets] = fluid.io.load_inference_model(
+                         save_dirname, self.exe, params_filename="__params__")
+
+                    results = self.exe.run(program=inference_program,
+                                           feed={
+                                               "x": x_np,
+                                               "y": y_np,
+                                               "init_hidden": init_hidden_np,
+                                               "init_cell": init_cell_np
+                                           },
+                                           fetch_list=fetch_targets,
+                                           use_program_cache=use_program_cache)
+                    if use_program_cache is True:
+                        results_with_cache = results
+                    else:
+                        results_without_cache = results
+
+            self.assertEqual(
+                len(results_with_cache), len(results_without_cache))
+            for i in range(len(results_with_cache)):
+                self.assertEqual(results_with_cache[i].shape,
+                                 results_without_cache[i].shape)
+                self.assertTrue(
+                    numpy.allclose(results_with_cache[i],
+                                   results_without_cache[i]))
 
 if __name__ == '__main__':
...