未验证 提交 c9ae1362 编写于 作者: W WangXi 提交者: GitHub

[hybrid performance] pipeline add program cache (#33954)

上级 6b95e674
...@@ -1135,7 +1135,10 @@ class Executor(object): ...@@ -1135,7 +1135,10 @@ class Executor(object):
if "startup_program" in program._pipeline_opt: if "startup_program" in program._pipeline_opt:
program = program._pipeline_opt["startup_program"] program = program._pipeline_opt["startup_program"]
else: else:
return self.train_from_dataset(program, fetch_list=fetch_list) return self._run_pipeline(
program,
fetch_list=fetch_list,
use_program_cache=use_program_cache)
if isinstance(program, Program) and \ if isinstance(program, Program) and \
len(program.global_block().ops) == 0: len(program.global_block().ops) == 0:
if use_default_main_program: if use_default_main_program:
...@@ -1537,6 +1540,141 @@ class Executor(object): ...@@ -1537,6 +1540,141 @@ class Executor(object):
return None return None
def _prepare_pipeline_ctx(self,
program=None,
dataset=None,
scope=None,
thread=0,
is_infer=False,
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100,
fetch_handler=None,
use_program_cache=False):
assert program._pipeline_opt is not None
assert dataset is None, "dataset should be None for pipeline mode"
cache_key = _get_strong_program_cache_key(program, None, fetch_list)
ctx = self._get_ctx_cache(cache_key)
if use_program_cache and ctx is not None:
return ctx
import paddle
# The following fake dataset is created to call
# the _prepare_trainer api, and it is meaningless.
def _get_dataset():
data_vars = []
for var in program.global_block().vars.values():
if var.is_data:
data_vars.append(var)
if core.is_compiled_with_npu():
dataset = paddle.fluid.DatasetFactory().create_dataset(
'InMemoryDataset')
else:
dataset = paddle.fluid.DatasetFactory().create_dataset(
'FileInstantDataset')
dataset.set_batch_size(1)
dataset.set_thread(1)
dataset.set_filelist(['None'])
dataset.set_use_var(data_vars)
dataset._prepare_to_run()
return dataset
dataset = _get_dataset()
def _get_real_program_fetch_list():
real_program = program._pipeline_opt["section_program"]
real_fetch_list = []
for fetch_var in fetch_list:
if isinstance(fetch_var, Variable):
fetch_var_name = fetch_var.name
else:
fetch_var_name = fetch_var
if fetch_var_name in real_program.global_block().vars:
real_fetch_list.append(fetch_var)
real_program = self._add_feed_fetch_ops(
program=real_program,
feed=[],
fetch_list=real_fetch_list,
feed_var_name='feed',
fetch_var_name='fetch')
main_block = real_program.block(0)
for op in main_block.ops:
# set the op_role of fetch op to Optimize to avoid
# erase the fetched vars by gc for pipeline
if op.type == 'fetch':
op._set_attr(
'op_role',
core.op_proto_and_checker_maker.OpRole.Optimize)
return real_program, real_fetch_list
real_program, real_fetch_list = _get_real_program_fetch_list()
program._pipeline_opt["section_program"] = real_program
fetch_list = None
scope, trainer = self._prepare_trainer(
program=program,
dataset=dataset,
scope=scope,
thread=thread,
debug=debug,
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
trainer._set_infer(is_infer)
trainer._gen_trainer_desc()
# NOTE: only for debug, very slow
# self._dump_debug_info(program=program, trainer=trainer)
# in case of calling _set_use_ps_gpu explicitly
if dataset.use_ps_gpu is False:
dataset._set_use_ps_gpu(trainer.proto_desc.use_ps_gpu)
dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)
trainer_desc = trainer._desc() # slow, cache
ctx = [trainer_desc, dataset, scope, real_fetch_list]
if use_program_cache: self._add_ctx_cache(cache_key, ctx)
return ctx
def _run_pipeline(self,
program=None,
dataset=None,
scope=None,
thread=0,
is_infer=False,
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100,
fetch_handler=None,
use_program_cache=False):
trainer_desc, dataset, scope, real_fetch_list = \
self._prepare_pipeline_ctx(program, dataset, scope, thread,
is_infer, debug, fetch_list, fetch_info,
print_period, fetch_handler,
use_program_cache)
trainer_instance = self._default_executor.init_for_dataset(
program.desc, trainer_desc, scope, dataset.dataset)
self._default_executor.run_from_dataset(trainer_instance)
self._default_executor.release_trainer(trainer_instance)
dataset._dynamic_adjust_after_train()
dataset._finish_to_run()
if real_fetch_list:
arr = scope.find_var('fetch').get_fetch_list()
tensors = arr._move_to_list()
return as_numpy(tensors)
return None
def infer_from_dataset(self, def infer_from_dataset(self,
program=None, program=None,
dataset=None, dataset=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册