提交 28eb496f 编写于 作者: W wuzewu

Add py_reader processing flow

上级 2620edc3
......@@ -29,6 +29,7 @@ class RunConfig(object):
def __init__(self,
log_interval=10,
eval_interval=100,
use_pyreader=False,
save_ckpt_interval=None,
use_cuda=True,
checkpoint_dir=None,
......@@ -44,6 +45,7 @@ class RunConfig(object):
self._checkpoint_dir = checkpoint_dir
self._num_epoch = num_epoch
self._batch_size = batch_size
self._use_pyreader = use_pyreader
if strategy is None:
self._strategy = DefaultStrategy()
else:
......@@ -93,3 +95,7 @@ class RunConfig(object):
@property
def enable_memory_optim(self):
return self._enable_memory_optim
@property
def use_pyreader(self):
return self._use_pyreader
......@@ -26,6 +26,7 @@ import paddle.fluid as fluid
from visualdl import LogWriter
import paddlehub as hub
from paddlehub.common.paddle_helper import dtype_map
from paddlehub.common.utils import mkdir
from paddlehub.common.logger import logger
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
......@@ -77,6 +78,9 @@ class BasicTask(object):
self.config)
self.exe = fluid.Executor(place=self.place)
self.feed_list = feed_list
self.feed_variables = [
main_program.global_block().vars[var_name] for var_name in feed_list
]
self.metrics = []
self.is_inititalized = False
self.current_step = 0
......@@ -127,6 +131,63 @@ class BasicTask(object):
def _add_metrics(self):
raise NotImplementedError
def _add_py_reader(self):
for program, add_label in ((self.main_program,
True), (self.test_program, True),
(self.inference_program, False)):
temp_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(temp_program, startup_program):
feed_variables = self.feed_variables
if add_label:
feed_variables = feed_variables + [self.label]
feed_list = self.feed_list
if add_label:
feed_list = feed_list + [self.label.name]
py_reader = fluid.layers.py_reader(
capacity=16,
shapes=[var.shape for var in feed_variables],
lod_levels=[var.lod_level for var in feed_variables],
dtypes=[dtype_map[var.dtype] for var in feed_variables],
use_double_buffer=True)
feed_variables = fluid.layers.read_file(py_reader)
input_dict = {
key: feed_variables[index]
for index, key in enumerate(feed_list)
}
hub.connect_program(
pre_program=temp_program,
next_program=program,
input_dict=input_dict,
inplace=True)
self.exe.run(startup_program)
if program == self.main_program:
self.main_program = temp_program
self.loss = self.main_program.global_block().vars[
self.loss.name]
for index, metric in enumerate(self.metrics):
self.metrics[index] = self.main_program.global_block().vars[
metric.name]
self.output = self.main_program.global_block().vars[
self.output.name]
self.loss.persistable = True
for metric in self.metrics:
metric.persistable = True
self.output.persistable = True
self.main_py_reader = py_reader
elif program == self.test_program:
self.test_program = temp_program
self.test_py_reader = py_reader
elif program == self.inference_program:
self.inference_program = temp_program
self.inference_py_reader = py_reader
def _init_if_necessary(self, load_best_model=False):
if not self.is_inititalized:
self._init_start_event()
......@@ -137,12 +198,17 @@ class BasicTask(object):
self._add_loss()
self._add_metrics()
self.test_program = self.main_program.clone(for_test=True)
if self.config.use_pyreader:
self._add_py_reader()
with fluid.program_guard(self.main_program):
self.config.strategy.execute(self.loss, self.data_reader,
self.config)
self.loss.persistable = True
for metrics in self.metrics:
metrics.persistable = True
for metric in self.metrics:
metric.persistable = True
self.output.persistable = True
self.build_strategy = fluid.BuildStrategy()
......@@ -187,7 +253,8 @@ class BasicTask(object):
self.current_epoch, self.current_step = load_checkpoint(
self.config.checkpoint_dir,
self.exe,
main_program=self.main_program)
main_program=self.main_program,
startup_program=self.startup_program)
if load_best_model:
model_saved_dir = os.path.join(self.config.checkpoint_dir,
......@@ -245,11 +312,17 @@ class BasicTask(object):
test_reader = self.data_reader.data_generator(
batch_size=self.config.batch_size, phase=phase)
run_states = self._run(
test_reader, phase=phase, program_compiled=self.test_program)
test_reader,
phase=phase,
program_compiled=self.test_program_compiled)
self._eval_end_event(phase, run_states)
def _run(self, reader, phase, do_eval=False, program_compiled=None):
def _run_with_data_feeder(self,
reader,
phase,
do_eval=False,
program_compiled=None):
if program_compiled is None:
program_compiled = self.main_program_compiled
feed_list = self.get_feed_list(phase=phase)
......@@ -291,6 +364,73 @@ class BasicTask(object):
global_run_states += period_run_states
return global_run_states
def _run_with_py_reader(self,
reader,
phase,
do_eval=False,
program_compiled=None):
if program_compiled is None:
program_compiled = self.main_program_compiled
if phase == "train":
py_reader = self.main_py_reader
elif phase in ["dev", "val", "test"]:
py_reader = self.test_py_reader
elif phase == "predict":
py_reader = self.inference_py_reader
py_reader.decorate_paddle_reader(reader)
fetch_list = self.get_fetch_list(phase=phase)
global_run_states = []
period_run_states = []
py_reader.start()
try:
while True:
num_batch_examples = self.config.batch_size
step_run_state = RunState(len(fetch_list))
step_run_state.run_step = 1
fetch_result = self.exe.run(
program_compiled, fetch_list=fetch_list)
for index, result in enumerate(fetch_result):
step_run_state.run_results[index] = result
step_run_state.run_examples += num_batch_examples
step_run_state.update()
period_run_states += [step_run_state]
if phase == "train":
self.current_step += 1
if self.current_step % self.config.log_interval == 0:
self._log_interval_event(period_run_states)
global_run_states += period_run_states
period_run_states = []
if self.config.save_ckpt_interval and self.current_step % self.config.save_ckpt_interval == 0:
self._save_ckpt_interval_event()
if do_eval and self.current_step % self.config.eval_interval == 0:
self._eval_interval_event()
self._run_step_event(phase, step_run_state)
except fluid.core.EOFException:
py_reader.reset()
global_run_states += period_run_states
return global_run_states
def _run(self, reader, phase, do_eval=False, program_compiled=None):
if self.config.use_pyreader:
return self._run_with_py_reader(
reader,
phase,
do_eval=do_eval,
program_compiled=program_compiled)
else:
return self._run_with_data_feeder(
reader,
phase,
do_eval=do_eval,
program_compiled=program_compiled)
def predict(self, data, load_best_model=True):
self._init_if_necessary(load_best_model=load_best_model)
with fluid.program_guard(self.inference_program):
......@@ -299,7 +439,7 @@ class BasicTask(object):
for run_state in self._run(
inference_reader,
phase='predict',
program_compiled=self.inference_program):
program_compiled=self.inference_program_compiled):
yield run_state.run_results
......@@ -408,7 +548,7 @@ class ClassifierTask(BasicTask):
save_result = fluid.io.save_persistables(
executor=self.exe,
dirname=model_saved_dir,
main_program=self.main_program)
main_program=self.test_program)
ImageClassifierTask = ClassifierTask
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册