提交 0ad41389 编写于 作者: W wuzewu

Optimize execution order

上级 90ed127a
......@@ -27,11 +27,7 @@ from paddlehub.common.logger import logger
CKPT_FILE_NAME = "ckpt.meta"
def load_checkpoint(checkpoint_dir,
exe,
main_program=fluid.default_main_program(),
startup_program=fluid.default_startup_program(),
load_best_model=False):
def load_checkpoint(checkpoint_dir, exe, main_program):
ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
ckpt = checkpoint_pb2.CheckPoint()
......@@ -41,33 +37,23 @@ def load_checkpoint(checkpoint_dir,
ckpt.ParseFromString(f.read())
current_epoch = 1
global_step = 0
pretrained_model = ""
best_model_path = os.path.join(checkpoint_dir, "best_model")
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
return os.path.exists(os.path.join(ckpt.latest_model_dir, var.name))
if load_best_model and os.path.exists(best_model_path):
pretrained_model = best_model_path
fluid.io.load_vars(
exe, best_model_path, main_program, predicate=if_exist)
logger.info("PaddleHub model best model loaded.")
return current_epoch, global_step
elif ckpt.latest_model_dir:
pretrained_model = ckpt.latest_model_dir
if ckpt.latest_model_dir:
fluid.io.load_vars(
exe, ckpt.latest_model_dir, main_program, predicate=if_exist)
logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
"global_step={}".format(ckpt.current_epoch,
ckpt.global_step))
return ckpt.current_epoch, ckpt.global_step
return True, ckpt.current_epoch, ckpt.global_step
logger.info(
"PaddleHub model checkpoint not found, start training from scratch...")
exe.run(startup_program)
return current_epoch, global_step
return False, current_epoch, global_step
def save_checkpoint(checkpoint_dir,
......
......@@ -30,7 +30,7 @@ from visualdl import LogWriter
import paddlehub as hub
from paddlehub.common.paddle_helper import dtype_map
from paddlehub.common.utils import mkdir
from paddlehub.common.utils import mkdir, to_list
from paddlehub.common.logger import logger
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
from paddlehub.finetune.evaluate import chunk_eval, calculate_f1
......@@ -105,21 +105,22 @@ class BasicTask(object):
self._base_startup_program = fluid.default_startup_program().clone()
else:
self._base_startup_program = startup_program.clone()
self._load_checkpoint = False
self._base_compile_program = None
self._base_compiled_program = None
self.is_checkpoint_loaded = False
# run config
self.config = config if config else RunConfig()
self.place = self.places[0]
self.device_count = len(self.places)
if self.config.use_data_parallel and self.config.batch_size < self.device_count:
logger.warning(
"Batch size({}) is less than the count of devices({}), which is not allowed in current Paddle versions"
.format(self.config.batch_size, self.device_count))
logger.warning("Batch size automatically adjusted to {}".format(
self.device_count))
self.config._batch_size = self.device_count
if self.config.use_data_parallel:
if not self.config.use_pyreader and self.config.batch_size < self.device_count:
logger.warning(
"Batch size({}) is less than the count of devices({}), which is not allowed in current Paddle versions"
.format(self.config.batch_size, self.device_count))
logger.warning("Batch size automatically adjusted to {}".format(
self.device_count))
self.config._batch_size = self.device_count
self.exe = fluid.Executor(place=self.place)
self.build_strategy = fluid.BuildStrategy()
......@@ -139,19 +140,29 @@ class BasicTask(object):
self._envs = {}
self._predict_data = None
def init_if_necessary(self, load_best_model=False):
if not self._load_checkpoint:
self.load_checkpoint(load_best_model=load_best_model)
self._load_checkpoint = True
# set default phase
self.enter_phase("train")
@contextlib.contextmanager
def phase_guard(self, phase):
    """Context manager that runs its body with ``phase`` as the active phase.

    The phase is pushed with ``enter_phase`` on entry and removed with
    ``exit_phase`` on exit.

    Args:
        phase: phase name, validated by ``enter_phase``.

    Raises:
        RuntimeError: propagated from ``enter_phase`` for an unknown phase.
    """
    # Entering happens outside the try block: if validation fails nothing
    # was pushed, so nothing must be popped.
    self.enter_phase(phase)
    try:
        yield
    finally:
        # Always restore the previous phase, even when the guarded body
        # raises; otherwise the failed phase would stay on the stack.
        self.exit_phase()
def enter_phase(self, phase):
    """Make ``phase`` the active phase by appending it to ``self._phases``.

    Args:
        phase: one of "train", "val", "dev", "test", "predict", "inference".

    Raises:
        RuntimeError: if ``phase`` is not one of the names above. Nothing is
            appended in that case.
    """
    valid_phases = ("train", "val", "dev", "test", "predict", "inference")
    if phase not in valid_phases:
        # Same exception type as before, but say what went wrong.
        raise RuntimeError("Unknown phase {!r}, expected one of: {}".format(
            phase, ", ".join(valid_phases)))
    self._phases.append(phase)
yield
def exit_phase(self):
    """Leave the current phase by dropping the last entry of ``self._phases``."""
    # Rebind to a sliced copy rather than popping in place: the previous
    # list object is left unmodified and an empty stack simply stays empty.
    remaining = self._phases[:-1]
    self._phases = remaining
def init_if_necessary(self):
    """Initialize variables once: load a checkpoint or run the startup program.

    The first call marks the task as initialized and calls
    ``self.load_checkpoint()``. If that returns a falsy value,
    ``self._base_startup_program`` is run on ``self.exe`` instead.
    Later calls do nothing.
    """
    if self.is_checkpoint_loaded:
        return

    # The flag is raised before the load attempt, so a load that raises is
    # not retried by subsequent calls.
    self.is_checkpoint_loaded = True
    restored = self.load_checkpoint()
    if not restored:
        self.exe.run(self._base_startup_program)
def _build_env(self):
if self.env.is_inititalized:
return
......@@ -186,6 +197,7 @@ class BasicTask(object):
feed_var_list = self.feed_var_list
py_vars = fluid.layers.read_file(self.env.py_reader)
py_vars = to_list(py_vars)
input_dict = {
feed_var_list[index].name: py_var
for index, py_var in enumerate(py_vars)
......@@ -198,18 +210,20 @@ class BasicTask(object):
need_log=False)
self.env.main_program = t_program
self.env.loss = self.env.main_program.global_block().vars[
self.env.loss.name]
if not self.is_predict_phase:
self.env.loss = self.env.main_program.global_block().vars[
self.env.loss.name]
metrics_name = [var.name for var in self.env.metrics]
self.env.metrics = [
self.env.main_program.global_block().vars[name]
for name in metrics_name
]
outputs_name = [var.name for var in self.env.outputs]
self.env.outputs = [
self.env.main_program.global_block().vars[name]
for name in outputs_name
]
metrics_name = [var.name for var in self.env.metrics]
self.env.metrics = [
self.env.main_program.global_block().vars[name]
for name in metrics_name
]
if self.config.enable_memory_optim:
for var_name in self.fetch_list:
......@@ -229,15 +243,15 @@ class BasicTask(object):
else:
loss_name = None
if self._base_compile_program is None:
if self._base_compiled_program is None:
share_vars_from = None
else:
share_vars_from = self._base_compile_program
share_vars_from = self._base_compiled_program
if not self.config.use_data_parallel:
if self.config.enable_memory_optim:
fluid.memory_optimize(self.env.main_program)
self.env.main_program_compiled = self.env.main_program
self.env.main_program_compiled = None
else:
self.env.main_program_compiled = fluid.CompiledProgram(
self.env.main_program).with_data_parallel(
......@@ -245,8 +259,8 @@ class BasicTask(object):
share_vars_from=share_vars_from,
build_strategy=self.build_strategy)
if self._base_compile_program is None:
self._base_compile_program = self.env.main_program_compiled
if self._base_compiled_program is None:
self._base_compiled_program = self.env.main_program_compiled
self.exe.run(self.env.startup_program)
self._build_env_end_event()
......@@ -318,6 +332,12 @@ class BasicTask(object):
self._build_env()
return self.env.main_program_compiled
@property
def main_program_to_be_run(self):
    """Return the program object to run for the current configuration.

    This is ``self.main_program_compiled`` when
    ``self.config.use_data_parallel`` is truthy, and ``self.main_program``
    otherwise.
    """
    parallel = self.config.use_data_parallel
    # Conditional expression keeps the lookup lazy: only the selected
    # attribute is ever read.
    return self.main_program_compiled if parallel else self.main_program
@property
def reader(self):
if self.is_predict_phase:
......@@ -443,16 +463,25 @@ class BasicTask(object):
exe=self.exe,
main_program=self.main_program)
def load_checkpoint(self, load_best_model=False):
self.env.current_epoch, self.env.current_step = load_checkpoint(
def load_checkpoint(self):
is_load_successful, self.env.current_epoch, self.env.current_step = load_checkpoint(
self.config.checkpoint_dir,
self.exe,
main_program=self.main_program,
startup_program=self._base_startup_program,
load_best_model=load_best_model)
main_program=self.main_program)
if self.is_predict_phase or self.is_test_phase:
self.env.current_step = 0
return is_load_successful
def load_parameters(self, dirname):
    """Load variables of ``self.main_program`` from ``dirname``.

    A variable is handed to ``fluid.io.load_vars`` only when a path named
    after it (``<dirname>/<var.name>``) exists; all other variables are
    filtered out by the predicate.

    Args:
        dirname: directory that holds the saved variables.
    """

    def has_saved_file(var):
        # Skip variables that were never saved under this directory.
        return os.path.exists(os.path.join(dirname, var.name))

    fluid.io.load_vars(
        self.exe,
        dirname,
        main_program=self.main_program,
        predicate=has_saved_file)
def save_parameters(self, dirname):
    """Save the parameters of ``self.main_program`` into ``dirname``.

    Thin wrapper around ``fluid.io.save_params`` using this task's executor.

    Args:
        dirname: target directory passed through to ``fluid.io.save_params``.
    """
    executor = self.exe
    program = self.main_program
    fluid.io.save_params(executor, dirname=dirname, main_program=program)
def finetune_and_eval(self):
self.finetune(do_eval=True)
......@@ -486,8 +515,12 @@ class BasicTask(object):
def predict(self, data, load_best_model=True):
with self.phase_guard(phase="predict"):
self.init_if_necessary()
if load_best_model:
best_model_path = os.path.join(self.config.checkpoint_dir,
"best_model")
self.load_parameters(best_model_path)
self._predict_data = data
self.init_if_necessary(load_best_model=load_best_model)
run_states = self._run()
self._predict_data = None
return [run_state.run_results for run_state in run_states]
......@@ -515,7 +548,7 @@ class BasicTask(object):
num_batch_examples = len(batch)
fetch_result = self.exe.run(
self.main_program_compiled,
self.main_program_to_be_run,
feed=data_feeder.feed(batch),
fetch_list=self.fetch_list)
......@@ -543,41 +576,56 @@ class BasicTask(object):
return global_run_states
def _run_with_py_reader(self, do_eval=False):
global_run_states = []
period_run_states = []
self.py_reader.decorate_paddle_reader(self.reader)
self.py_reader.start()
try:
while True:
num_batch_examples = self.config.batch_size
step_run_state = RunState(len(self.fetch_list))
step_run_state.run_step = 1
fetch_result = self.exe.run(
self.main_program_compiled, fetch_list=self.fetch_list)
for index, result in enumerate(fetch_result):
step_run_state.run_results[index] = result
step_run_state.run_examples += num_batch_examples
step_run_state.update()
period_run_states += [step_run_state]
self.env.current_step += 1
if self.is_train_phase:
if self.current_step % self.config.log_interval == 0:
self._log_interval_event(period_run_states)
global_run_states += period_run_states
period_run_states = []
if self.config.save_ckpt_interval and self.current_step % self.config.save_ckpt_interval == 0:
self._save_ckpt_interval_event()
flag = False
while True:
global_run_states = []
period_run_states = []
self.py_reader.decorate_paddle_reader(self.reader)
self.py_reader.start()
try:
while True:
num_batch_examples = self.config.batch_size * self.device_count
step_run_state = RunState(len(self.fetch_list))
step_run_state.run_step = 1
fetch_result = self.exe.run(
self.main_program_to_be_run, fetch_list=self.fetch_list)
for index, result in enumerate(fetch_result):
step_run_state.run_results[index] = result
step_run_state.run_examples += num_batch_examples
step_run_state.update()
period_run_states += [step_run_state]
self.env.current_step += 1
if self.is_train_phase:
if self.current_step % self.config.log_interval == 0:
self._log_interval_event(period_run_states)
global_run_states += period_run_states
period_run_states = []
if self.config.save_ckpt_interval and self.current_step % self.config.save_ckpt_interval == 0:
self._save_ckpt_interval_event()
if do_eval and self.current_step % self.config.eval_interval == 0:
self._eval_interval_event()
self._run_step_event(step_run_state)
except fluid.core.EOFException:
global_run_states += period_run_states
self.py_reader.reset()
'''
When opening use_data_parallel and use_pyreader, if the amount of data is too small,
the reader will have thrown EOF Exception when not fetching to the running result.
In this case, temporarily close the use_data_parallel to get the result.
'''
if flag:
self.config._use_data_parallel = use_data_parallel_backup
elif len(global_run_states) == 0:
flag = True
use_data_parallel_backup = self.config.use_data_parallel
self.config._use_data_parallel = False
continue
break
if do_eval and self.current_step % self.config.eval_interval == 0:
self._eval_interval_event()
self._run_step_event(step_run_state)
except fluid.core.EOFException:
self.py_reader.reset()
global_run_states += period_run_states
return global_run_states
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册