Commit 0ad41389 authored by wuzewu

Optimize execution order

Parent 90ed127a
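The checkpoint hunks below rework load_checkpoint so that it no longer runs the startup program itself: it now returns a success flag in front of the epoch and step counters, and the caller decides whether to initialize from scratch. A minimal sketch of the new calling convention, assuming a checkpoint directory written by an earlier run (the wiring names here are illustrative, not part of the commit):

import paddle.fluid as fluid
from paddlehub.finetune.checkpoint import load_checkpoint

# Illustrative wiring only; "ckpt_dir" stands for a directory that
# already contains ckpt.meta from a previous run.
exe = fluid.Executor(fluid.CPUPlace())
main_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()

is_loaded, current_epoch, global_step = load_checkpoint(
    "ckpt_dir", exe, main_program)
if not is_loaded:
    # After this commit the caller, not load_checkpoint, runs the
    # startup program when no checkpoint is found (see
    # BasicTask.init_if_necessary in the second file).
    exe.run(startup_program)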
...
@@ -27,11 +27,7 @@ from paddlehub.common.logger import logger
 CKPT_FILE_NAME = "ckpt.meta"
 
-def load_checkpoint(checkpoint_dir,
-                    exe,
-                    main_program=fluid.default_main_program(),
-                    startup_program=fluid.default_startup_program(),
-                    load_best_model=False):
+def load_checkpoint(checkpoint_dir, exe, main_program):
     ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
     ckpt = checkpoint_pb2.CheckPoint()
@@ -41,33 +37,23 @@ def load_checkpoint(checkpoint_dir,
         ckpt.ParseFromString(f.read())
 
     current_epoch = 1
     global_step = 0
-    pretrained_model = ""
-    best_model_path = os.path.join(checkpoint_dir, "best_model")
 
     def if_exist(var):
-        return os.path.exists(os.path.join(pretrained_model, var.name))
+        return os.path.exists(os.path.join(ckpt.latest_model_dir, var.name))
 
-    if load_best_model and os.path.exists(best_model_path):
-        pretrained_model = best_model_path
-        fluid.io.load_vars(
-            exe, best_model_path, main_program, predicate=if_exist)
-        logger.info("PaddleHub model best model loaded.")
-        return current_epoch, global_step
-    elif ckpt.latest_model_dir:
-        pretrained_model = ckpt.latest_model_dir
+    if ckpt.latest_model_dir:
         fluid.io.load_vars(
             exe, ckpt.latest_model_dir, main_program, predicate=if_exist)
         logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
                     "global_step={}".format(ckpt.current_epoch,
                                             ckpt.global_step))
-        return ckpt.current_epoch, ckpt.global_step
+        return True, ckpt.current_epoch, ckpt.global_step
 
     logger.info(
         "PaddleHub model checkpoint not found, start training from scratch...")
-    exe.run(startup_program)
-    return current_epoch, global_step
+    return False, current_epoch, global_step
 
 
 def save_checkpoint(checkpoint_dir,
...
@@ -30,7 +30,7 @@ from visualdl import LogWriter
 import paddlehub as hub
 from paddlehub.common.paddle_helper import dtype_map
-from paddlehub.common.utils import mkdir
+from paddlehub.common.utils import mkdir, to_list
 from paddlehub.common.logger import logger
 from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
 from paddlehub.finetune.evaluate import chunk_eval, calculate_f1
@@ -105,21 +105,22 @@ class BasicTask(object):
             self._base_startup_program = fluid.default_startup_program().clone()
         else:
             self._base_startup_program = startup_program.clone()
-        self._load_checkpoint = False
-        self._base_compile_program = None
+        self._base_compiled_program = None
+        self.is_checkpoint_loaded = False
 
         # run config
         self.config = config if config else RunConfig()
         self.place = self.places[0]
         self.device_count = len(self.places)
 
-        if self.config.use_data_parallel and self.config.batch_size < self.device_count:
-            logger.warning(
-                "Batch size({}) is less than the count of devices({}), which is not allowed in current Paddle versions"
-                .format(self.config.batch_size, self.device_count))
-            logger.warning("Batch size automatically adjusted to {}".format(
-                self.device_count))
-            self.config._batch_size = self.device_count
+        if self.config.use_data_parallel:
+            if not self.config.use_pyreader and self.config.batch_size < self.device_count:
+                logger.warning(
+                    "Batch size({}) is less than the count of devices({}), which is not allowed in current Paddle versions"
+                    .format(self.config.batch_size, self.device_count))
+                logger.warning("Batch size automatically adjusted to {}".format(
+                    self.device_count))
+                self.config._batch_size = self.device_count
 
         self.exe = fluid.Executor(place=self.place)
         self.build_strategy = fluid.BuildStrategy()
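The batch-size guard above now fires only on the feed-based path (use_pyreader off), presumably because the py_reader path feeds each device a full batch, as the later num_batch_examples = batch_size * device_count change suggests. A tiny numeric illustration with assumed values:

# Assumed values for illustration only.
batch_size, device_count, use_pyreader = 2, 4, False

if not use_pyreader and batch_size < device_count:
    # Each device must receive at least one example per step.
    batch_size = device_count
print(batch_size)  # -> 4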
@@ -139,19 +140,29 @@ class BasicTask(object):
         self._envs = {}
         self._predict_data = None
 
-    def init_if_necessary(self, load_best_model=False):
-        if not self._load_checkpoint:
-            self.load_checkpoint(load_best_model=load_best_model)
-            self._load_checkpoint = True
+        # set default phase
+        self.enter_phase("train")
 
     @contextlib.contextmanager
     def phase_guard(self, phase):
+        self.enter_phase(phase)
+        yield
+        self.exit_phase()
+
+    def enter_phase(self, phase):
         if phase not in ["train", "val", "dev", "test", "predict", "inference"]:
             raise RuntimeError()
         self._phases.append(phase)
-        yield
+
+    def exit_phase(self):
         self._phases = self._phases[:-1]
 
+    def init_if_necessary(self):
+        if not self.is_checkpoint_loaded:
+            self.is_checkpoint_loaded = True
+            if not self.load_checkpoint():
+                self.exe.run(self._base_startup_program)
+
     def _build_env(self):
         if self.env.is_inititalized:
             return
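Splitting phase_guard into enter_phase/exit_phase keeps the context manager while also letting code switch phases imperatively, and the constructor now pushes a default "train" phase. A self-contained sketch of the pattern (a toy class for illustration, not the PaddleHub BasicTask itself):

import contextlib

class PhaseStack(object):
    # Minimal standalone illustration of the enter/exit phase pattern.
    VALID = ["train", "val", "dev", "test", "predict", "inference"]

    def __init__(self):
        self._phases = []
        self.enter_phase("train")  # default phase, as in the diff

    @contextlib.contextmanager
    def phase_guard(self, phase):
        self.enter_phase(phase)
        yield
        self.exit_phase()

    def enter_phase(self, phase):
        if phase not in self.VALID:
            raise RuntimeError()
        self._phases.append(phase)

    def exit_phase(self):
        self._phases = self._phases[:-1]

stack = PhaseStack()
with stack.phase_guard("predict"):
    assert stack._phases[-1] == "predict"
assert stack._phases[-1] == "train"  # restored on exit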
@@ -186,6 +197,7 @@ class BasicTask(object):
             feed_var_list = self.feed_var_list
             py_vars = fluid.layers.read_file(self.env.py_reader)
+            py_vars = to_list(py_vars)
             input_dict = {
                 feed_var_list[index].name: py_var
                 for index, py_var in enumerate(py_vars)
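fluid.layers.read_file can hand back a single Variable when the reader has only one output slot, so the new to_list call presumably normalizes the result before it is zipped with feed_var_list. The implementation of to_list is not shown in this commit; a plausible sketch under that assumption:

def to_list(result):
    # Assumed behavior of paddlehub.common.utils.to_list: wrap a bare
    # object in a list, pass real lists through unchanged.
    if not isinstance(result, list):
        return [result]
    return result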
@@ -198,18 +210,20 @@ class BasicTask(object):
                 need_log=False)
             self.env.main_program = t_program
 
-            self.env.loss = self.env.main_program.global_block().vars[
-                self.env.loss.name]
-
-            metrics_name = [var.name for var in self.env.metrics]
-            self.env.metrics = [
-                self.env.main_program.global_block().vars[name]
-                for name in metrics_name
-            ]
+            if not self.is_predict_phase:
+                self.env.loss = self.env.main_program.global_block().vars[
+                    self.env.loss.name]
 
             outputs_name = [var.name for var in self.env.outputs]
             self.env.outputs = [
                 self.env.main_program.global_block().vars[name]
                 for name in outputs_name
             ]
 
+            metrics_name = [var.name for var in self.env.metrics]
+            self.env.metrics = [
+                self.env.main_program.global_block().vars[name]
+                for name in metrics_name
+            ]
+
         if self.config.enable_memory_optim:
             for var_name in self.fetch_list:
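The loss, outputs, and metrics handles are re-fetched by name here because self.env.main_program was just replaced with a transformed program, and fluid Variable objects are bound to the program that created them; in the predict phase there is no loss, hence the new guard. A minimal sketch of the lookup-by-name idiom, using assumed fluid 1.x layer calls:

import paddle.fluid as fluid

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.layers.data(name="x", shape=[1], dtype="float32")
    y = fluid.layers.fc(input=x, size=1)

cloned = main.clone()
# `y` belongs to `main`; the clone holds its own Variable of the same
# name, so it must be looked up again in the new program's block.
y_in_clone = cloned.global_block().vars[y.name]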
@@ -229,15 +243,15 @@ class BasicTask(object):
         else:
             loss_name = None
 
-        if self._base_compile_program is None:
+        if self._base_compiled_program is None:
             share_vars_from = None
         else:
-            share_vars_from = self._base_compile_program
+            share_vars_from = self._base_compiled_program
 
         if not self.config.use_data_parallel:
             if self.config.enable_memory_optim:
                 fluid.memory_optimize(self.env.main_program)
-            self.env.main_program_compiled = self.env.main_program
+            self.env.main_program_compiled = None
         else:
             self.env.main_program_compiled = fluid.CompiledProgram(
                 self.env.main_program).with_data_parallel(
@@ -245,8 +259,8 @@ class BasicTask(object):
                     share_vars_from=share_vars_from,
                     build_strategy=self.build_strategy)
 
-        if self._base_compile_program is None:
-            self._base_compile_program = self.env.main_program_compiled
+        if self._base_compiled_program is None:
+            self._base_compiled_program = self.env.main_program_compiled
 
         self.exe.run(self.env.startup_program)
         self._build_env_end_event()
@@ -318,6 +332,12 @@ class BasicTask(object):
             self._build_env()
         return self.env.main_program_compiled
 
+    @property
+    def main_program_to_be_run(self):
+        if self.config.use_data_parallel:
+            return self.main_program_compiled
+        return self.main_program
+
     @property
     def reader(self):
         if self.is_predict_phase:
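Setting main_program_compiled to None on the single-device path makes the new main_program_to_be_run property the one dispatch point for executor calls. A toy stand-in reproducing just that dispatch logic (illustrative names only):

class _StubTask(object):
    # Minimal stand-in for the property's behavior, not the real class.
    def __init__(self, use_data_parallel):
        self.use_data_parallel = use_data_parallel
        self.main_program = "plain Program"
        self.main_program_compiled = (
            "CompiledProgram" if use_data_parallel else None)

    @property
    def main_program_to_be_run(self):
        if self.use_data_parallel:
            return self.main_program_compiled
        return self.main_program

print(_StubTask(True).main_program_to_be_run)   # -> CompiledProgram
print(_StubTask(False).main_program_to_be_run)  # -> plain Program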
@@ -443,16 +463,25 @@ class BasicTask(object):
             exe=self.exe,
             main_program=self.main_program)
 
-    def load_checkpoint(self, load_best_model=False):
-        self.env.current_epoch, self.env.current_step = load_checkpoint(
+    def load_checkpoint(self):
+        is_load_successful, self.env.current_epoch, self.env.current_step = load_checkpoint(
             self.config.checkpoint_dir,
             self.exe,
-            main_program=self.main_program,
-            startup_program=self._base_startup_program,
-            load_best_model=load_best_model)
+            main_program=self.main_program)
 
-        if self.is_predict_phase or self.is_test_phase:
-            self.env.current_step = 0
+        return is_load_successful
+
+    def load_parameters(self, dirname):
+        def if_exist(var):
+            path = os.path.join(dirname, var.name)
+            return os.path.exists(path)
+
+        fluid.io.load_vars(
+            self.exe, dirname, self.main_program, predicate=if_exist)
+
+    def save_parameters(self, dirname):
+        fluid.io.save_params(
+            self.exe, dirname=dirname, main_program=self.main_program)
 
     def finetune_and_eval(self):
         self.finetune(do_eval=True)
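The new load_parameters/save_parameters pair replaces the old load_best_model flag on checkpoint loading: best-model weights are now a plain parameter directory that any phase can load explicitly, as predict() does in the next hunk. A hedged usage sketch, where task stands for any already-built BasicTask subclass instance:

import os

# `task` is assumed to be a finetuned BasicTask subclass instance.
best_model_path = os.path.join(task.config.checkpoint_dir, "best_model")

task.save_parameters(best_model_path)  # persist the current parameters
task.load_parameters(best_model_path)  # restore them later, e.g. for predict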
@@ -486,8 +515,12 @@ class BasicTask(object):
     def predict(self, data, load_best_model=True):
         with self.phase_guard(phase="predict"):
+            self.init_if_necessary()
+            if load_best_model:
+                best_model_path = os.path.join(self.config.checkpoint_dir,
+                                               "best_model")
+                self.load_parameters(best_model_path)
             self._predict_data = data
-            self.init_if_necessary(load_best_model=load_best_model)
             run_states = self._run()
             self._predict_data = None
             return [run_state.run_results for run_state in run_states]
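With this change, predict() loads the best-model parameters itself instead of threading load_best_model through checkpoint loading. A hedged call-site sketch, where task and data are assumed to already exist:

# `task` is a finetuned BasicTask subclass; `data` matches its reader format.
results = task.predict(data, load_best_model=True)
for batch_results in results:
    print(batch_results)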
@@ -515,7 +548,7 @@ class BasicTask(object):
                 num_batch_examples = len(batch)
                 fetch_result = self.exe.run(
-                    self.main_program_compiled,
+                    self.main_program_to_be_run,
                     feed=data_feeder.feed(batch),
                     fetch_list=self.fetch_list)
@@ -543,41 +576,56 @@ class BasicTask(object):
         return global_run_states
 
     def _run_with_py_reader(self, do_eval=False):
-        global_run_states = []
-        period_run_states = []
-        self.py_reader.decorate_paddle_reader(self.reader)
-        self.py_reader.start()
-        try:
-            while True:
-                num_batch_examples = self.config.batch_size
-                step_run_state = RunState(len(self.fetch_list))
-                step_run_state.run_step = 1
-                fetch_result = self.exe.run(
-                    self.main_program_compiled, fetch_list=self.fetch_list)
-
-                for index, result in enumerate(fetch_result):
-                    step_run_state.run_results[index] = result
-                step_run_state.run_examples += num_batch_examples
-                step_run_state.update()
-                period_run_states += [step_run_state]
-                self.env.current_step += 1
-                if self.is_train_phase:
-                    if self.current_step % self.config.log_interval == 0:
-                        self._log_interval_event(period_run_states)
-                        global_run_states += period_run_states
-                        period_run_states = []
-
-                    if self.config.save_ckpt_interval and self.current_step % self.config.save_ckpt_interval == 0:
-                        self._save_ckpt_interval_event()
-
-                    if do_eval and self.current_step % self.config.eval_interval == 0:
-                        self._eval_interval_event()
-
-                self._run_step_event(step_run_state)
-        except fluid.core.EOFException:
-            self.py_reader.reset()
-        global_run_states += period_run_states
+        flag = False
+        while True:
+            global_run_states = []
+            period_run_states = []
+            self.py_reader.decorate_paddle_reader(self.reader)
+            self.py_reader.start()
+            try:
+                while True:
+                    num_batch_examples = self.config.batch_size * self.device_count
+                    step_run_state = RunState(len(self.fetch_list))
+                    step_run_state.run_step = 1
+                    fetch_result = self.exe.run(
+                        self.main_program_to_be_run, fetch_list=self.fetch_list)
+
+                    for index, result in enumerate(fetch_result):
+                        step_run_state.run_results[index] = result
+                    step_run_state.run_examples += num_batch_examples
+                    step_run_state.update()
+                    period_run_states += [step_run_state]
+                    self.env.current_step += 1
+                    if self.is_train_phase:
+                        if self.current_step % self.config.log_interval == 0:
+                            self._log_interval_event(period_run_states)
+                            global_run_states += period_run_states
+                            period_run_states = []
+
+                        if self.config.save_ckpt_interval and self.current_step % self.config.save_ckpt_interval == 0:
+                            self._save_ckpt_interval_event()
+
+                        if do_eval and self.current_step % self.config.eval_interval == 0:
+                            self._eval_interval_event()
+
+                    self._run_step_event(step_run_state)
+            except fluid.core.EOFException:
+                global_run_states += period_run_states
+                self.py_reader.reset()
+
+            '''
+            When use_data_parallel and use_pyreader are both enabled and the
+            dataset is very small, the reader may raise an EOF exception before
+            any result has been fetched. In that case, temporarily disable
+            use_data_parallel and rerun to get the result.
+            '''
+            if flag:
+                self.config._use_data_parallel = use_data_parallel_backup
+            elif len(global_run_states) == 0:
+                flag = True
+                use_data_parallel_backup = self.config.use_data_parallel
+                self.config._use_data_parallel = False
+                continue
+            break
+
         return global_run_states
...
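The retry loop added to _run_with_py_reader boils down to: run once, and if EOF arrived before any result was fetched, disable data parallelism for a single rerun, then restore the original setting. A distilled, standalone rendering of that control flow (run_once stands in for the whole py_reader loop; all names are illustrative):

def run_with_fallback(config, run_once):
    # Distilled control flow of the new _run_with_py_reader loop.
    # `config` needs a mutable `use_data_parallel` attribute; `run_once`
    # returns the list of collected run states.
    retried = False
    backup = None
    while True:
        results = run_once()
        if retried:
            # Second pass finished: restore the original setting.
            config.use_data_parallel = backup
        elif not results:
            # EOF before any fetch: retry once without data parallel.
            retried = True
            backup = config.use_data_parallel
            config.use_data_parallel = False
            continue
        break
    return results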