Unverified commit 16e4d026, authored by Yiqun Liu, committed by GitHub

Refine the cache of program, context and scope in executor. (#18483)

* Refine the cache of program, context and scope in executor.
test=develop

* Refine the unittest test_executor_and_use_program_cache.

* Add a test for PaddingRNN with use_program_cache=True.
test=develop

* Remove a check.
test=develop

* Refine the unittest to check whether the results are correct when use_program_cache=True is set.
test=develop
Parent b4897600
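
For context, the cache exercised by this change is driven by the use_program_cache argument of Executor.run in the Python API. Below is a minimal usage sketch mirroring the updated unittest later in this diff; the toy mul program, tensor names, and shapes are illustrative assumptions, not part of the commit itself.

import numpy
import paddle.fluid as fluid

# Build a trivial program so that repeated exe.run() calls can hit the cache.
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    a = fluid.layers.data(name='a', shape=[784], dtype='float32')
    b = fluid.layers.data(
        name='b', shape=[784, 100], dtype='float32', append_batch_size=False)
    out = fluid.layers.mul(x=a, y=b)

exe = fluid.Executor(fluid.CPUPlace())
a_np = numpy.random.random((100, 784)).astype('float32')
b_np = numpy.random.random((784, 100)).astype('float32')

# With use_program_cache=True, the first call appends feed/fetch ops, prepares
# an execution context, and creates variables in a sub-scope, then caches all
# of them; later calls with the same feed/fetch names reuse the cached objects.
for _ in range(3):
    outs = exe.run(program=main_program,
                   feed={'a': a_np, 'b': b_np},
                   fetch_list=[out.name],
                   use_program_cache=True)
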
......@@ -287,12 +287,6 @@ static bool has_fetch_operators(
return fetch_count > 0;
}
std::unique_ptr<ExecutorPrepareContext> Executor::PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
return Prepare(program, block_id, skip_ref_cnt_vars, force_disable_gc);
}
void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
......
......@@ -95,12 +95,6 @@ class Executor {
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
std::unique_ptr<ExecutorPrepareContext> PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
......
......@@ -1389,7 +1389,7 @@ All parameter, weight, gradient are variables in Paddle.
create_local_scope, create_vars,
feed_holder_name, fetch_holder_name);
})
.def("run_cached_prepared_ctx",
.def("run_prepared_ctx",
[](Executor &self, ExecutorPrepareContext *ctx, Scope *scope,
bool create_local_scope = true, bool create_vars = true,
bool keep_kids = false) {
......@@ -1397,10 +1397,16 @@ All parameter, weight, gradient are variables in Paddle.
self.RunPreparedContext(ctx, scope, create_local_scope,
create_vars, keep_kids);
})
.def("prepare_ctx_cache", &Executor::PrepareCtxCache,
py::call_guard<py::gil_scoped_release>())
.def("create_variables", &Executor::CreateVariables,
py::call_guard<py::gil_scoped_release>())
.def("prepare",
[](Executor &self, const ProgramDesc &program, int block_id,
const std::vector<std::string> &skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false) {
pybind11::gil_scoped_release release;
return self.Prepare(program, block_id, skip_ref_cnt_vars,
force_disable_gc);
})
.def("create_variables", &Executor::CreateVariables)
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars,
const std::vector<std::string> &fetch_vars) {
......
......@@ -489,9 +489,6 @@ class Executor(object):
self._default_executor = core.Executor(p)
self._closed = False
def _get_var_cache(self, program_cache_key):
return self.var_caches.get(program_cache_key, None)
def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None)
......@@ -510,9 +507,6 @@ class Executor(object):
def _add_scope_cache(self, scope_cache_key, scope):
self.scope_caches[scope_cache_key] = scope
def _add_var_cache(self, var_cache_key, var):
self.var_caches[var_cache_key] = var
def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
fetch_var_name):
tmp_program = program.clone()
......@@ -853,7 +847,6 @@ class Executor(object):
cached_program = self._get_program_cache(cache_key)
cached_ctx = self._get_ctx_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key)
cached_var = self._get_var_cache(cache_key)
if cached_program is None:
cached_program = self._add_feed_fetch_ops(
program=program,
......@@ -863,23 +856,21 @@ class Executor(object):
fetch_var_name=fetch_var_name)
self._add_program_cache(cache_key, cached_program)
fetch_list_str = list(map(_to_name_str, fetch_list))
cached_ctx = self._default_executor.prepare_ctx_cache(
cached_ctx = self._default_executor.prepare(
cached_program.desc, 0, fetch_list_str, False)
cached_var = self._default_executor.create_variables(
cached_program.desc, scope, 0)
# currently, we cache program, vars, sub_scope here
# we suppose that in a life cycle of training, a user
# will not create many programs. So, here the basic
# rule of caching is to cache all unseen (program, var, scope)
# when a user uses use_program_cache.
cached_scope = scope.new_scope()
self._default_executor.create_variables(cached_program.desc,
cached_scope, 0)
self._add_ctx_cache(cache_key, cached_ctx)
self._add_var_cache(cache_key, cached_var)
self._add_scope_cache(cache_key, cached_scope)
program = cached_program
ctx = cached_ctx
scope = cached_scope
var = cached_var
else:
program = self._add_feed_fetch_ops(
program=program,
......@@ -893,8 +884,8 @@ class Executor(object):
self._default_executor.run(program.desc, scope, 0, True, True,
fetch_var_name)
else:
self._default_executor.run_cached_prepared_ctx(ctx, scope, False,
False, False)
self._default_executor.run_prepared_ctx(ctx, scope, False, False,
False)
arr = scope.find_var(fetch_var_name).get_lod_tensor_array()
tensors = arr._move_to_list()
if return_numpy:
......
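
Because the hunk above interleaves removed and added lines, the resulting cached path is easier to read as one consolidated sketch. Everything below paraphrases the diff above; the construction of cache_key (derived from the feed/fetch names) and the non-cached branch are omitted.

# Sketch of the Python Executor's cached path when use_program_cache=True,
# as it reads after this change.
cached_program = self._get_program_cache(cache_key)
cached_ctx = self._get_ctx_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key)
if cached_program is None:
    # First run with this key: append feed/fetch ops, prepare the context,
    # and create variables in a fresh sub-scope, then cache all three.
    cached_program = self._add_feed_fetch_ops(
        program=program, feed=feed, fetch_list=fetch_list,
        feed_var_name=feed_var_name, fetch_var_name=fetch_var_name)
    self._add_program_cache(cache_key, cached_program)
    fetch_list_str = list(map(_to_name_str, fetch_list))
    cached_ctx = self._default_executor.prepare(
        cached_program.desc, 0, fetch_list_str, False)
    cached_scope = scope.new_scope()
    self._default_executor.create_variables(cached_program.desc,
                                            cached_scope, 0)
    self._add_ctx_cache(cache_key, cached_ctx)
    self._add_scope_cache(cache_key, cached_scope)
# Later runs skip all of the above and execute the prepared context directly
# in the cached sub-scope (run_prepared_ctx was renamed from
# run_cached_prepared_ctx in this commit).
self._default_executor.run_prepared_ctx(cached_ctx, cached_scope,
                                        False, False, False)
arr = cached_scope.find_var(fetch_var_name).get_lod_tensor_array()
tensors = arr._move_to_list()
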
......@@ -32,7 +32,7 @@ from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN
os.environ["CPU_NUM"] = "1"
class RnnConfig(object):
class RNNConfig(object):
def __init__(self, model_type, rnn_model):
self.model_type = model_type
self.rnn_model = rnn_model
......@@ -478,11 +478,12 @@ def lm_model(hidden_size,
return loss, last_hidden, last_cell, feeding_list
class EagerDeletionPaddingRnnTest(unittest.TestCase):
class PaddingRNNTestBase(unittest.TestCase):
def setUp(self):
self.reader = Reader()
self.device_count = 1
def prepare_program(self, config):
def prepare_program(self, config, parallel=True):
self.main_program = fluid.Program()
self.startup_program = fluid.Program()
self.startup_program.random_seed = config.random_seed
......@@ -517,7 +518,7 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
self.exe = Executor(fluid.CPUPlace())
self.exe.run(self.startup_program)
self.device_count = 1
if parallel:
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = self.device_count
exec_strategy.num_iteration_per_drop_scope = 100
......@@ -532,6 +533,8 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
loss_name=self.loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
else:
self.train_program = self.main_program
def generate_init_data(self):
init_hidden = np.zeros(
......@@ -572,7 +575,7 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
res['learning_rate'] = self.generate_new_lr(epoch_id, device_count)
return res
def train_an_epoch(self, epoch_id, batch_times):
def train_an_epoch(self, epoch_id, batch_times, use_program_cache=True):
train_data_iter = self.reader.get_data_iter(self.config)
total_loss = 0
......@@ -597,7 +600,7 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
self.last_hidden.name,
self.last_cell.name
],
use_program_cache=True)
use_program_cache=use_program_cache)
batch_time = time.time() - batch_start_time
batch_times.append(batch_time)
......@@ -613,47 +616,53 @@ class EagerDeletionPaddingRnnTest(unittest.TestCase):
ppl = np.append(ppl, batch_ppl)
return ppl
def train(self, config):
def train(self, config, parallel=True, use_program_cache=True):
self.config = config
self.prepare_program(config)
self.prepare_program(config, parallel)
total_time = 0.0
ppl = np.zeros(shape=(0, config.batch_size))
for epoch_id in range(config.max_epoch):
batch_times = []
epoch_start_time = time.time()
train_ppl = self.train_an_epoch(epoch_id, batch_times)
train_ppl = self.train_an_epoch(epoch_id, batch_times,
use_program_cache)
epoch_time = time.time() - epoch_start_time
total_time += epoch_time
ppl = np.append(ppl, train_ppl)
return ppl
def compare_padding_static_mode(self):
def compare_padding_static_mode(self, parallel=True,
use_program_cache=True):
'''
Test that train ppl of padding mode is the same as that of static mode
'''
config = RnnConfig('test', 'padding')
config = RNNConfig('test', 'padding')
with fluid.scope_guard(fluid.Scope()):
padding_rnn_ppl = self.train(config)
config = RnnConfig('test', 'static')
padding_rnn_ppl = self.train(config, parallel, use_program_cache)
config = RNNConfig('test', 'static')
with fluid.scope_guard(fluid.Scope()):
static_rnn_ppl = self.train(config)
static_rnn_ppl = self.train(config, parallel, use_program_cache)
self.assertTrue(
np.isclose(
padding_rnn_ppl, static_rnn_ppl, rtol=0.001).all())
class EagerDeletionPaddingRNNTest(PaddingRNNTestBase):
def test_padding_mode_no_eager_deletion(self):
'''
Test that train ppl of padding mode is the same as that of static mode without eager deletion
'''
fluid.core._set_eager_deletion_mode(-1.0, 1.0, True)
self.compare_padding_static_mode()
# When parallel is True, use_program_cache does not make a difference.
self.compare_padding_static_mode(parallel=True, use_program_cache=True)
def test_padding_mode_eager_deletion(self):
'''
Test that train ppl of padding mode is the same as that of static mode under eager deletion
'''
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
self.compare_padding_static_mode()
# When parallel is True, use_program_cache does not make a difference.
self.compare_padding_static_mode(parallel=True, use_program_cache=True)
if __name__ == '__main__':
......
......@@ -18,82 +18,130 @@ import unittest
import numpy
import paddle.fluid.core as core
from paddle.fluid.executor import Executor
from paddle.fluid.layers import mul, data
import paddle.fluid as fluid
from test_eager_deletion_padding_rnn import RNNConfig, PaddingRNNTestBase
class TestExecutor(unittest.TestCase):
def test_mul(self):
a = data(name='a', shape=[784], dtype='float32')
b = data(
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
a = fluid.layers.data(name='a', shape=[784], dtype='float32')
b = fluid.layers.data(
name='b',
shape=[784, 100],
dtype='float32',
append_batch_size=False)
output = mul(x=a, y=b)
place = core.CPUPlace()
output = fluid.layers.mul(x=a, y=b)
# Compute with numpy
a_np = numpy.random.random((100, 784)).astype('float32')
b_np = numpy.random.random((784, 100)).astype('float32')
exe = Executor(place)
out_np = numpy.dot(a_np, b_np)
place = core.CPUPlace()
exe = fluid.Executor(place)
def _train(use_program_cache, max_iters=1):
import time
use_cache = True
step_num = 3
run_time = 0.0
for i in range(step_num):
begin = time.time()
outs = exe.run(feed={'a': a_np,
'b': b_np},
fetch_list=[output.name],
use_program_cache=use_cache)
end = time.time()
run_time += end - begin
out = outs[0]
self.assertEqual((100, 100), out.shape)
self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
print("run time %f" % run_time)
use_cache = False
run_time = 0.0
for i in range(step_num):
begin = time.time()
outs = exe.run(feed={'a': a_np,
'b': b_np},
fetch_list=[output.name],
use_program_cache=use_cache)
end = time.time()
run_time += end - begin
out = outs[0]
self.assertEqual((100, 100), out.shape)
self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
print("run time %f" % run_time)
use_cache = True
run_time = 0.0
for i in range(step_num):
for i in range(max_iters):
begin = time.time()
outs = exe.run(feed={'a': a_np,
outs = exe.run(program=main_program,
feed={'a': a_np,
'b': b_np},
fetch_list=[output.name],
use_program_cache=use_cache)
use_program_cache=use_program_cache)
end = time.time()
run_time += end - begin
out = outs[0]
self.assertEqual((100, 100), out.shape)
self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
print("run time %f" % run_time)
self.assertTrue(numpy.allclose(out, out_np))
return run_time
use_cache = True
run_time = 0.0
for i in range(step_num):
begin = time.time()
outs = exe.run(feed={'a': a_np,
'b': b_np},
fetch_list=[output],
use_program_cache=use_cache)
end = time.time()
run_time += end - begin
out = outs[0]
self.assertEqual((100, 100), out.shape)
self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))
print("run time %f" % run_time)
max_iters = 3
run_time_with_cache = _train(
use_program_cache=True, max_iters=max_iters)
print("run time with program cache: %f" % run_time_with_cache)
run_time_without_cache = _train(
use_program_cache=False, max_iters=max_iters)
print("run time without program cache: %f" % run_time_without_cache)
run_time_with_cache = _train(
use_program_cache=True, max_iters=max_iters)
print("run time with program cache: %f" % run_time_with_cache)
run_time_with_cache = _train(
use_program_cache=True, max_iters=max_iters)
print("run time with program cache: %f" % run_time_with_cache)
class ExecutorPaddingRNNTest(PaddingRNNTestBase):
def train_and_save_inference_program(self,
rnn_model="static",
parallel=True,
use_program_cache=True):
config = RNNConfig("test", rnn_model)
with fluid.scope_guard(fluid.Scope()):
self.train(config, parallel, use_program_cache)
fluid.io.save_inference_model(
main_program=self.main_program,
feeded_var_names=self.feed_order,
target_vars=[self.loss, self.last_hidden, self.last_cell],
executor=self.exe,
dirname="padding_rnn." + rnn_model + ".inference_model",
params_filename="__params__")
def test_inference_output(self):
for rnn_model in ["static", "padding"]:
# Set parallel to False to use the default executor.
self.train_and_save_inference_program(
rnn_model=rnn_model, parallel=True, use_program_cache=True)
x_np = numpy.random.random(
(self.config.batch_size, self.config.num_steps,
1)).astype("int64")
y_np = numpy.random.random(
(self.config.batch_size * self.config.num_steps,
1)).astype("int64")
init_hidden_np = numpy.random.random(
(self.config.num_layers, self.config.batch_size,
self.config.hidden_size)).astype("float32")
init_cell_np = numpy.random.random(
(self.config.num_layers, self.config.batch_size,
self.config.hidden_size)).astype("float32")
for use_program_cache in [False, True]:
with fluid.scope_guard(fluid.Scope()):
save_dirname = "padding_rnn." + rnn_model + ".inference_model"
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
save_dirname, self.exe, params_filename="__params__")
results = self.exe.run(program=inference_program,
feed={
"x": x_np,
"y": y_np,
"init_hidden": init_hidden_np,
"init_cell": init_cell_np
},
fetch_list=fetch_targets,
use_program_cache=use_program_cache)
if use_program_cache is True:
results_with_cache = results
else:
results_without_cache = results
self.assertEqual(
len(results_with_cache), len(results_without_cache))
for i in range(len(results_with_cache)):
self.assertEqual(results_with_cache[i].shape,
results_without_cache[i].shape)
self.assertTrue(
numpy.allclose(results_with_cache[i], results_without_cache[
i]))
if __name__ == '__main__':
......