diff --git a/paddlehub/common/paddle_helper.py b/paddlehub/common/paddle_helper.py
index 955ae80979a1bcbacc4f1d8b17bc98f5783509f7..f6e57e2b1669671a152f7e80c7adb74660b35ebd 100644
--- a/paddlehub/common/paddle_helper.py
+++ b/paddlehub/common/paddle_helper.py
@@ -142,7 +142,11 @@ def from_module_attr_to_param(module_attr):
     return param
 
 
-def connect_program(pre_program, next_program, input_dict=None, inplace=True):
+def connect_program(pre_program,
+                    next_program,
+                    input_dict=None,
+                    inplace=True,
+                    need_log=True):
     def _copy_vars_and_ops_in_blocks(from_block, to_block):
         for var in from_block.vars:
             var = from_block.var(var)
@@ -198,7 +202,8 @@ def connect_program(pre_program, next_program, input_dict=None, inplace=True):
                     outputs={'Out': output_var})
 
     block_map = {0: 0}
-    logger.info("Connect program's input tensor")
+    if need_log:
+        logger.info("Connect program's input tensor")
     for index, block in enumerate(next_program.blocks):
         if block.idx == 0:
             _copy_vars_and_ops_in_blocks(block, output_program.global_block())
@@ -210,7 +215,8 @@ def connect_program(pre_program, next_program, input_dict=None, inplace=True):
             new_block = output_program._create_block(
                 parent_idx=block_map[block.parent_idx])
             _copy_vars_and_ops_in_blocks(block, new_block)
-    logger.info("Connect program's input tensor done")
+    if need_log:
+        logger.info("Connect program's input tensor done")
     return output_program
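Reviewer note: the new `need_log` switch exists so that internal wiring (the py_reader hook-up in task.py below) does not spam the log. A minimal usage sketch; the program and variable names here are illustrative, not from this patch:

```python
import paddle.fluid as fluid
import paddlehub as hub

# `head` produces a variable; `body` consumes a feed var of the same name.
head = fluid.Program()
with fluid.program_guard(head):
    image = fluid.layers.data(name="image", shape=[3, 32, 32], dtype="float32")

body = fluid.Program()
with fluid.program_guard(body):
    feed = fluid.layers.data(name="image", shape=[3, 32, 32], dtype="float32")
    probs = fluid.layers.fc(input=feed, size=10, act="softmax")

# Keys are feed-var names in next_program, values are vars in pre_program;
# need_log=False suppresses the two "Connect program's ..." info lines.
merged = hub.connect_program(
    pre_program=head,
    next_program=body,
    input_dict={"image": image},
    need_log=False)
```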
diff --git a/paddlehub/finetune/task.py b/paddlehub/finetune/task.py
index ee7f1341486cc1a8a8ad76ec549932dfe282adb5..7f9574a30045210e39b728139a7ebe74456b2be4 100644
--- a/paddlehub/finetune/task.py
+++ b/paddlehub/finetune/task.py
@@ -18,8 +18,10 @@ from __future__ import print_function
 
 import os
 import collections
+import contextlib
 import time
 import multiprocessing
+import copy
 
 import numpy as np
 import paddle.fluid as fluid
@@ -61,6 +63,28 @@ class RunState(object):
         return self
 
 
+class RunEnv(object):
+    def __init__(self):
+        self.current_epoch = 0
+        self.current_step = 0
+        self.main_program = None
+        self.startup_program = None
+        self.main_program_compiled = None
+        self.py_reader = None
+        self.reader = None
+        self.loss = None
+        self.label = None
+        self.metrics = None
+        self.is_initialized = False
+        self.UNG = copy.deepcopy(fluid.unique_name.generator)
+
+    def __setattr__(self, key, value):
+        self.__dict__[key] = value
+
+    def __getattr__(self, key):
+        return self.__dict__[key]
+
+
 class BasicTask(object):
     def __init__(self,
                  feed_list,
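Reviewer note on `RunEnv`: the `__setattr__`/`__getattr__` pair turns it into a plain attribute bag over its own `__dict__`. One subtlety: `__getattr__` raises `KeyError` rather than `AttributeError` for missing attributes, so `hasattr`-style probing will not behave as usual. A dependency-free analog of the semantics:

```python
class AttrBag(object):
    """Simplified stand-in for RunEnv: attributes are plain dict entries."""

    def __setattr__(self, key, value):
        self.__dict__[key] = value

    def __getattr__(self, key):
        # Only invoked when normal lookup fails. Note it raises KeyError,
        # not AttributeError, which breaks hasattr()/getattr(default) use.
        return self.__dict__[key]


env = AttrBag()
env.current_step = 10
assert env.current_step == 10

try:
    env.missing
except KeyError:
    pass  # AttributeError would be the conventional signal here
```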
@@ -68,36 +92,277 @@ class BasicTask(object):
                  main_program=None,
                  startup_program=None,
                  config=None):
-        self.data_reader = data_reader
-        self.main_program = main_program if main_program else fluid.default_main_program(
-        )
-        self.startup_program = startup_program if startup_program else fluid.default_startup_program(
-        )
+
+        # base item
+        self._base_data_reader = data_reader
+        self._base_feed_list = feed_list
+        if main_program is None:
+            self._base_main_program = fluid.default_main_program().clone()
+        else:
+            self._base_main_program = main_program.clone()
+        if startup_program is None:
+            self._base_startup_program = fluid.default_startup_program().clone()
+        else:
+            self._base_startup_program = startup_program.clone()
+        self._load_checkpoint = False
+        self._base_compile_program = None
+
+        # run config
         self.config = config if config else RunConfig()
         self.place, self.device_count = hub.common.get_running_device_info(
             self.config)
         self.exe = fluid.Executor(place=self.place)
-        self.feed_list = feed_list
-        self.feed_variables = [
-            main_program.global_block().vars[var_name] for var_name in feed_list
-        ]
-        self.metrics = []
-        self.is_inititalized = False
-        self.current_step = 0
-        self.current_epoch = 0
+        self.build_strategy = fluid.BuildStrategy()
+        if self.config.enable_memory_optim:
+            self.build_strategy.memory_optimize = True
+        else:
+            self.build_strategy.memory_optimize = False
+
+        # log item
+        if not os.path.exists(self.config.checkpoint_dir):
+            mkdir(self.config.checkpoint_dir)
+        vdl_log_dir = os.path.join(self.config.checkpoint_dir, "vdllog")
+        self.log_writer = LogWriter(vdl_log_dir, sync_cycle=1)
+
+        # run environment
+        self._phases = []
+        self._envs = {}
+
+    def init_if_necessary(self):
+        if not self._load_checkpoint:
+            self.load_checkpoint()
+            self._load_checkpoint = True
+
+    @contextlib.contextmanager
+    def phase_guard(self, phase):
+        if phase not in ["train", "val", "dev", "test", "predict", "inference"]:
+            raise RuntimeError("unknown phase %s" % phase)
+        self._phases.append(phase)
+        try:
+            yield
+        finally:
+            self._phases = self._phases[:-1]
+
+    def _build_env(self):
+        if self.env.is_initialized:
+            return
+
+        self._build_env_start_event()
+        self.env.is_initialized = True
+        self.env.main_program = self._base_main_program.clone()
+        self.env.startup_program = fluid.Program()
+        with fluid.program_guard(self.env.main_program,
+                                 self._base_startup_program):
+            with fluid.unique_name.guard(self.env.UNG):
+                self.env.output = self._build_net()
+                if self.is_train_phase or self.is_test_phase:
+                    self.env.label = self._add_label()
+                    self.env.loss = self._add_loss()
+                    self.env.metrics = self._add_metrics()
 
-    def _init_start_event(self):
+        if self.config.use_pyreader:
+            t_program = fluid.Program()
+            with fluid.program_guard(t_program, self.env.startup_program):
+                self.env.py_reader = fluid.layers.py_reader(
+                    capacity=64,
+                    shapes=[var.shape for var in self.feed_var_list],
+                    dtypes=[dtype_map[var.dtype] for var in self.feed_var_list],
+                    lod_levels=[var.lod_level for var in self.feed_var_list],
+                    use_double_buffer=False)
+
+                feed_var_list = self.feed_var_list
+                py_vars = fluid.layers.read_file(self.env.py_reader)
+                input_dict = {
+                    feed_var_list[index].name: py_var
+                    for index, py_var in enumerate(py_vars)
+                }
+
+                hub.connect_program(
+                    pre_program=t_program,
+                    next_program=self.env.main_program,
+                    input_dict=input_dict,
+                    need_log=False)
+
+            self.env.main_program = t_program
+            self.env.loss = self.env.main_program.global_block().vars[
+                self.env.loss.name]
+            self.env.output = self.env.main_program.global_block().vars[
+                self.env.output.name]
+            metrics_name = [var.name for var in self.env.metrics]
+            self.env.metrics = [
+                self.env.main_program.global_block().vars[name]
+                for name in metrics_name
+            ]
+
+        if self.config.enable_memory_optim:
+            for var_name in self.fetch_list:
+                var = self.env.main_program.global_block().vars[var_name]
+                var.persistable = True
+
+        if self.is_train_phase:
+            with fluid.program_guard(self.env.main_program,
+                                     self._base_startup_program):
+                with fluid.unique_name.guard(self.env.UNG):
+                    self.config.strategy.execute(
+                        self.loss, self._base_data_reader, self.config)
+
+        if self.is_train_phase:
+            loss_name = self.env.loss.name
+            share_vars_from = None
+        else:
+            loss_name = None
+
+        if self._base_compile_program is None:
+            share_vars_from = None
+        else:
+            share_vars_from = self._base_compile_program
+
+        self.env.main_program_compiled = fluid.CompiledProgram(
+            self.env.main_program).with_data_parallel(
+                loss_name=loss_name,
+                share_vars_from=share_vars_from,
+                build_strategy=self.build_strategy)
+
+        if self._base_compile_program is None:
+            self._base_compile_program = self.env.main_program_compiled
+
+        self.exe.run(self.env.startup_program)
+        self._build_env_end_event()
+
+    @property
+    def is_train_phase(self):
+        return self.phase in ["train"]
+
+    @property
+    def is_test_phase(self):
+        return self.phase in ["val", "dev", "test"]
+
+    @property
+    def is_predict_phase(self):
+        return self.phase in ["predict", "inference"]
+
+    @property
+    def phase(self):
+        return self._phases[-1]
+
+    @property
+    def env(self):
+        phase = self.phase
+        if phase in ["val", "dev", "test"]:
+            phase = "val"
+        if phase not in self._envs:
+            self._envs[phase] = RunEnv()
+        return self._envs[phase]
+
+    @property
+    def py_reader(self):
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.py_reader
+
+    @property
+    def current_step(self):
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.current_step
+
+    @property
+    def current_epoch(self):
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.current_epoch
+
+    @property
+    def main_program(self):
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.main_program
+
+    @property
+    def startup_program(self):
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.startup_program
+
+    @property
+    def main_program_compiled(self):
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.main_program_compiled
+
+    @property
+    def reader(self):
+        self.env.reader = self._base_data_reader.data_generator(
+            batch_size=self.config.batch_size, phase=self.phase)
+        return self.env.reader
+
+    @property
+    def loss(self):
+        if self.is_predict_phase:
+            raise RuntimeError()
+
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.loss
+
+    @property
+    def label(self):
+        if self.is_predict_phase:
+            raise RuntimeError()
+
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.label
+
+    @property
+    def output(self):
+        if self.is_predict_phase:
+            raise RuntimeError()
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.output
+
+    @property
+    def metrics(self):
+        if self.is_predict_phase:
+            raise RuntimeError()
+
+        if not self.env.is_initialized:
+            self._build_env()
+        return self.env.metrics
+
+    @property
+    def unique_name_generator(self):
+        return self.env.UNG
+
+    @property
+    def feed_list(self):
+        feed_list = [varname for varname in self._base_feed_list]
+        if self.is_train_phase or self.is_test_phase:
+            feed_list += [self.label.name]
+        return feed_list
+
+    @property
+    def feed_var_list(self):
+        vars = self.main_program.global_block().vars
+        return [vars[varname] for varname in self.feed_list]
+
+    @property
+    def fetch_list(self):
+        if self.is_train_phase or self.is_test_phase:
+            return [metric.name for metric in self.metrics] + [self.loss.name]
+        return [self.output.name]
+
+    def _build_env_start_event(self):
         pass
 
-    def _init_end_event(self):
+    def _build_env_end_event(self):
         pass
 
-    def _eval_start_event(self, phase):
-        logger.info("Evaluation on {} dataset start".format(phase))
+    def _eval_start_event(self):
+        logger.info("Evaluation on {} dataset start".format(self.phase))
 
-    def _eval_end_event(self, phase, run_state):
+    def _eval_end_event(self, run_state):
         logger.info("[%s dataset evaluation result] [step/sec: %.2f]" %
-                    (phase, run_state.run_speed))
+                    (self.phase, run_state.run_speed))
 
     def _log_interval_event(self, run_state):
         logger.info("step %d: [step/sec: %.2f]" % (self.current_step,
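Reviewer note: the interplay of `phase_guard`, the shared val/dev/test env, and lazy `_build_env` is easier to see stripped of the Paddle machinery. A simplified, runnable analog (class and attribute names invented for illustration):

```python
import contextlib


class LazyTask(object):
    """Toy model of BasicTask's phase stack and per-phase lazy envs."""

    def __init__(self):
        self._phases = []
        self._envs = {}

    @property
    def phase(self):
        return self._phases[-1]

    @property
    def env(self):
        # val/dev/test share one env, mirroring BasicTask.env.
        phase = "val" if self.phase in ["val", "dev", "test"] else self.phase
        return self._envs.setdefault(phase, {"is_initialized": False})

    @contextlib.contextmanager
    def phase_guard(self, phase):
        self._phases.append(phase)
        try:
            yield
        finally:
            self._phases.pop()

    def _build_env(self):
        if self.env["is_initialized"]:
            return
        self.env["is_initialized"] = True
        self.env["main_program"] = "program-for-%s" % self.phase  # stand-in

    @property
    def main_program(self):
        self._build_env()  # first touch builds the env of the current phase
        return self.env["main_program"]


task = LazyTask()
with task.phase_guard("train"):
    print(task.main_program)  # program-for-train
with task.phase_guard("dev"):
    print(task.main_program)  # program-for-dev, cached under the "val" env
with task.phase_guard("test"):
    print(task.main_program)  # still program-for-dev: test reuses that env
```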
@@ -109,8 +374,8 @@ class BasicTask(object):
     def _eval_interval_event(self):
         self.eval(phase="dev")
 
-    def _run_step_event(self, phase, run_state):
-        if phase == "predict":
+    def _run_step_event(self, run_state):
+        if self.is_predict_phase:
             yield run_state.run_results
 
     def _finetune_start_event(self):
@@ -131,114 +396,6 @@ class BasicTask(object):
     def _add_metrics(self):
         raise NotImplementedError
 
-    def _add_py_reader(self):
-
-        for program, add_label in ((self.main_program,
-                                    True), (self.test_program, True),
-                                   (self.inference_program, False)):
-            temp_program = fluid.Program()
-            startup_program = fluid.Program()
-            with fluid.program_guard(temp_program, startup_program):
-                feed_variables = self.feed_variables
-                if add_label:
-                    feed_variables = feed_variables + [self.label]
-                feed_list = self.feed_list
-                if add_label:
-                    feed_list = feed_list + [self.label.name]
-
-                py_reader = fluid.layers.py_reader(
-                    capacity=16,
-                    shapes=[var.shape for var in feed_variables],
-                    lod_levels=[var.lod_level for var in feed_variables],
-                    dtypes=[dtype_map[var.dtype] for var in feed_variables],
-                    use_double_buffer=True)
-
-                feed_variables = fluid.layers.read_file(py_reader)
-                input_dict = {
-                    key: feed_variables[index]
-                    for index, key in enumerate(feed_list)
-                }
-
-                hub.connect_program(
-                    pre_program=temp_program,
-                    next_program=program,
-                    input_dict=input_dict,
-                    inplace=True)
-
-            self.exe.run(startup_program)
-
-            if program == self.main_program:
-                self.main_program = temp_program
-                self.loss = self.main_program.global_block().vars[
-                    self.loss.name]
-                for index, metric in enumerate(self.metrics):
-                    self.metrics[index] = self.main_program.global_block().vars[
-                        metric.name]
-                self.output = self.main_program.global_block().vars[
-                    self.output.name]
-                self.loss.persistable = True
-                for metric in self.metrics:
-                    metric.persistable = True
-                self.output.persistable = True
-                self.main_py_reader = py_reader
-            elif program == self.test_program:
-                self.test_program = temp_program
-                self.test_py_reader = py_reader
-            elif program == self.inference_program:
-                self.inference_program = temp_program
-                self.inference_py_reader = py_reader
-
-    def _init_if_necessary(self, load_best_model=False):
-        if not self.is_inititalized:
-            self._init_start_event()
-            with fluid.program_guard(self.main_program):
-                self.output = self._build_net()
-                self.inference_program = self.main_program.clone(for_test=True)
-                self._add_label()
-                self._add_loss()
-                self._add_metrics()
-                self.test_program = self.main_program.clone(for_test=True)
-
-            if self.config.use_pyreader:
-                self._add_py_reader()
-
-            with fluid.program_guard(self.main_program):
-                self.config.strategy.execute(self.loss, self.data_reader,
-                                             self.config)
-
-            self.loss.persistable = True
-            for metric in self.metrics:
-                metric.persistable = True
-            self.output.persistable = True
-
-            self.build_strategy = fluid.BuildStrategy()
-            if self.config.enable_memory_optim:
-                self.build_strategy.memory_optimize = True
-            else:
-                self.build_strategy.memory_optimize = False
-
-            self.main_program_compiled = fluid.CompiledProgram(
-                self.main_program).with_data_parallel(
-                    loss_name=self.loss.name,
-                    build_strategy=self.build_strategy)
-            self.inference_program_compiled = fluid.CompiledProgram(
-                self.inference_program).with_data_parallel(
-                    share_vars_from=self.main_program_compiled,
-                    build_strategy=self.build_strategy)
-            self.test_program_compiled = fluid.CompiledProgram(
-                self.test_program).with_data_parallel(
-                    share_vars_from=self.main_program_compiled,
-                    build_strategy=self.build_strategy)
-
-            self.load_checkpoint(load_best_model=load_best_model)
-
-            if not os.path.exists(self.config.checkpoint_dir):
-                mkdir(self.config.checkpoint_dir)
-            vdl_log_dir = os.path.join(self.config.checkpoint_dir, "vdllog")
-            self.log_writer = LogWriter(vdl_log_dir, sync_cycle=1)
-            self.is_inititalized = True
-            self._init_end_event()
-
     # NOTE: current saved checkpoint machanism is not completed,
     # it can't restore dataset training status
     def save_checkpoint(self, epoch, step):
@@ -250,11 +407,11 @@ class BasicTask(object):
             main_program=self.main_program)
 
     def load_checkpoint(self, load_best_model=False):
-        self.current_epoch, self.current_step = load_checkpoint(
+        self.env.current_epoch, self.env.current_step = load_checkpoint(
             self.config.checkpoint_dir,
             self.exe,
             main_program=self.main_program,
-            startup_program=self.startup_program)
+            startup_program=self._base_startup_program)
 
         if load_best_model:
             model_saved_dir = os.path.join(self.config.checkpoint_dir,
@@ -265,89 +422,73 @@ class BasicTask(object):
                 dirname=model_saved_dir,
                 main_program=self.main_program)
 
-    def get_feed_list(self, phase):
-        if phase in ["train", "dev", "val", "test"]:
-            return self.feed_list + [self.label.name]
-        return self.feed_list
-
-    def get_fetch_list(self, phase):
-        metrics_name = [metric.name for metric in self.metrics]
-        if phase in ["train", "dev", "val", "test"]:
-            return metrics_name + [self.loss.name]
-        return [self.output.name]
-
     def finetune_and_eval(self):
         self.finetune(do_eval=True)
 
     def finetune(self, do_eval=False):
-        self._init_if_necessary()
-        self._finetune_start_event()
-        run_states = []
-        if self.current_epoch <= self.config.num_epoch:
-            # Start to finetune
-            with fluid.program_guard(self.main_program):
+        # Start to finetune
+        with self.phase_guard(phase="train"):
+            self.init_if_necessary()
+            self._finetune_start_event()
+            run_states = []
+            if self.current_epoch <= self.config.num_epoch:
                 while self.current_epoch <= self.config.num_epoch:
-                    train_reader = self.data_reader.data_generator(
-                        batch_size=self.config.batch_size, phase='train')
-                    run_states = self._run(
-                        train_reader,
-                        phase="train",
-                        do_eval=do_eval,
-                        program_compiled=self.main_program_compiled)
-                    self.current_epoch += 1
+                    run_states = self._run(do_eval=do_eval)
+                    self.env.current_epoch += 1
 
-        # Save checkpoint after finetune
-        self.save_checkpoint(self.current_epoch + 1, self.current_step)
+            # Save checkpoint after finetune
+            self.save_checkpoint(self.current_epoch + 1, self.current_step)
 
-        # Final evaluation
-        self.eval(phase="dev")
-        self.eval(phase="test")
+            # Final evaluation
+            self.eval(phase="dev")
+            self.eval(phase="test")
 
-        self._finetune_end_event(run_states)
+            self._finetune_end_event(run_states)
 
     def eval(self, phase="dev"):
-        self._init_if_necessary()
-        self._eval_start_event(phase)
-        with fluid.program_guard(self.test_program):
-            test_reader = self.data_reader.data_generator(
-                batch_size=self.config.batch_size, phase=phase)
-            run_states = self._run(
-                test_reader,
-                phase=phase,
-                program_compiled=self.test_program_compiled)
-
-        self._eval_end_event(phase, run_states)
-
-    def _run_with_data_feeder(self,
-                              reader,
-                              phase,
-                              do_eval=False,
-                              program_compiled=None):
-        if program_compiled is None:
-            program_compiled = self.main_program_compiled
-        feed_list = self.get_feed_list(phase=phase)
-        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=self.place)
-        fetch_list = self.get_fetch_list(phase=phase)
+        with self.phase_guard(phase=phase):
+            self.init_if_necessary()
+            self._eval_start_event()
+            run_states = self._run()
+            self._eval_end_event(run_states)
+
+    def predict(self, data, load_best_model=True):
+        with self.phase_guard(phase="predict"):
+            self.init_if_necessary()
+            for run_state in self._run():
+                yield run_state.run_results
+
+    def _run(self, do_eval=False):
+        with fluid.program_guard(self.main_program, self.startup_program):
+            if self.config.use_pyreader:
+                return self._run_with_py_reader(do_eval=do_eval)
+            return self._run_with_data_feeder(do_eval=do_eval)
+
+    def _run_with_data_feeder(self, do_eval=False):
+
+        data_feeder = fluid.DataFeeder(
+            feed_list=self.feed_list, place=self.place)
+
         global_run_states = []
         period_run_states = []
 
-        for run_step, batch in enumerate(reader(), start=1):
-            step_run_state = RunState(len(fetch_list))
+        for run_step, batch in enumerate(self.reader(), start=1):
+            step_run_state = RunState(len(self.fetch_list))
             step_run_state.run_step = 1
             num_batch_examples = len(batch)
 
             fetch_result = self.exe.run(
-                program_compiled,
+                self.main_program_compiled,
                 feed=data_feeder.feed(batch),
-                fetch_list=fetch_list)
+                fetch_list=self.fetch_list)
 
             for index, result in enumerate(fetch_result):
                 step_run_state.run_results[index] = result
             step_run_state.run_examples += num_batch_examples
             step_run_state.update()
             period_run_states += [step_run_state]
-            if phase == "train":
-                self.current_step += 1
+            if self.is_train_phase:
+                self.env.current_step += 1
                 if self.current_step % self.config.log_interval == 0:
                     self._log_interval_event(period_run_states)
                     global_run_states += period_run_states
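Reviewer note: end-to-end, the new driver surface looks like the sketch below. The construction arguments are assumptions (the `ClassifierTask.__init__` signature is outside this hunk) and the module/reader boilerplate is elided; only the `finetune_and_eval`/`eval`/`predict` calls are what this patch actually defines:

```python
import paddlehub as hub
from paddlehub.finetune.task import ClassifierTask

# Assumed setup: `pooled_output` is a feature variable from a pre-trained
# module, `feed_list` names that module's inputs, `reader` is a data
# reader. All three are hypothetical placeholders, not part of this patch.
task = ClassifierTask(
    feature=pooled_output,      # assumed parameter name
    num_classes=2,              # assumed parameter name
    feed_list=feed_list,
    data_reader=reader,
    config=hub.RunConfig(use_pyreader=True))  # flag assumed to exist

task.finetune_and_eval()   # trains under phase_guard("train"), then
                           # runs the final dev/test evaluations
task.eval(phase="test")    # dev/val/test share one cached RunEnv
for batch_results in task.predict(data=predict_data):
    print(batch_results)   # predict() yields per-batch run results
```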
@@ -359,46 +500,31 @@ class BasicTask(object):
             if do_eval and self.current_step % self.config.eval_interval == 0:
                 self._eval_interval_event()
 
-            self._run_step_event(phase, step_run_state)
+            self._run_step_event(step_run_state)
 
         global_run_states += period_run_states
         return global_run_states
 
-    def _run_with_py_reader(self,
-                            reader,
-                            phase,
-                            do_eval=False,
-                            program_compiled=None):
-        if program_compiled is None:
-            program_compiled = self.main_program_compiled
-        if phase == "train":
-            py_reader = self.main_py_reader
-        elif phase in ["dev", "val", "test"]:
-            py_reader = self.test_py_reader
-        elif phase == "predict":
-            py_reader = self.inference_py_reader
-
-        py_reader.decorate_paddle_reader(reader)
-        fetch_list = self.get_fetch_list(phase=phase)
+    def _run_with_py_reader(self, do_eval=False):
         global_run_states = []
         period_run_states = []
-
-        py_reader.start()
+        self.py_reader.decorate_paddle_reader(self.reader)
+        self.py_reader.start()
         try:
             while True:
                 num_batch_examples = self.config.batch_size
-                step_run_state = RunState(len(fetch_list))
+                step_run_state = RunState(len(self.fetch_list))
                 step_run_state.run_step = 1
                 fetch_result = self.exe.run(
-                    program_compiled, fetch_list=fetch_list)
+                    self.main_program_compiled, fetch_list=self.fetch_list)
 
                 for index, result in enumerate(fetch_result):
                     step_run_state.run_results[index] = result
                 step_run_state.run_examples += num_batch_examples
                 step_run_state.update()
                 period_run_states += [step_run_state]
-                if phase == "train":
-                    self.current_step += 1
+                if self.is_train_phase:
+                    self.env.current_step += 1
                     if self.current_step % self.config.log_interval == 0:
                         self._log_interval_event(period_run_states)
                         global_run_states += period_run_states
@@ -410,38 +536,13 @@ class BasicTask(object):
                 if do_eval and self.current_step % self.config.eval_interval == 0:
                     self._eval_interval_event()
 
-                self._run_step_event(phase, step_run_state)
+                self._run_step_event(step_run_state)
 
         except fluid.core.EOFException:
-            py_reader.reset()
+            self.py_reader.reset()
 
         global_run_states += period_run_states
         return global_run_states
 
-    def _run(self, reader, phase, do_eval=False, program_compiled=None):
-        if self.config.use_pyreader:
-            return self._run_with_py_reader(
-                reader,
-                phase,
-                do_eval=do_eval,
-                program_compiled=program_compiled)
-        else:
-            return self._run_with_data_feeder(
-                reader,
-                phase,
-                do_eval=do_eval,
-                program_compiled=program_compiled)
-
-    def predict(self, data, load_best_model=True):
-        self._init_if_necessary(load_best_model=load_best_model)
-        with fluid.program_guard(self.inference_program):
-            inference_reader = self.data_reader.data_generator(
-                batch_size=self.config.batch_size, phase='predict', data=data)
-            for run_state in self._run(
-                    inference_reader,
-                    phase='predict',
-                    program_compiled=self.inference_program_compiled):
-                yield run_state.run_results
-
 
 class ClassifierTask(BasicTask):
     def __init__(self,
@@ -487,25 +588,22 @@ class ClassifierTask(BasicTask):
         return logits
 
     def _add_label(self):
-        self.label = fluid.layers.data(name="label", dtype="int64", shape=[1])
+        return fluid.layers.data(name="label", dtype="int64", shape=[1])
 
     def _add_loss(self):
         ce_loss = fluid.layers.cross_entropy(
             input=self.output, label=self.label)
-        self.loss = fluid.layers.mean(x=ce_loss)
+        return fluid.layers.mean(x=ce_loss)
 
     def _add_metrics(self):
-        self.accuracy = fluid.layers.accuracy(
-            input=self.output, label=self.label)
-        self.metrics.append(self.accuracy)
+        return [fluid.layers.accuracy(input=self.output, label=self.label)]
 
-    def _init_end_event(self):
-        with self.log_writer.mode("train") as logw:
-            self.train_loss_scalar = logw.scalar(tag="Loss [train]")
-            self.train_acc_scalar = logw.scalar(tag="Accuracy [train]")
-        with self.log_writer.mode("evaluate") as logw:
-            self.eval_loss_scalar = logw.scalar(tag="Loss [eval]")
-            self.eval_acc_scalar = logw.scalar(tag="Accuracy [eval]")
+    def _build_env_end_event(self):
+        with self.log_writer.mode(self.phase) as logw:
+            self.env.loss_scalar = logw.scalar(
+                tag="Loss [{}]".format(self.phase))
+            self.env.acc_scalar = logw.scalar(
+                tag="Accuracy [{}]".format(self.phase))
 
     def _calculate_metrics(self, run_states):
         loss_sum = acc_sum = run_examples = 0
@@ -527,19 +625,19 @@ class ClassifierTask(BasicTask):
 
     def _log_interval_event(self, run_states):
         avg_loss, avg_acc, run_speed = self._calculate_metrics(run_states)
-        self.train_loss_scalar.add_record(self.current_step, avg_loss)
-        self.train_acc_scalar.add_record(self.current_step, avg_acc)
+        self.env.loss_scalar.add_record(self.current_step, avg_loss)
+        self.env.acc_scalar.add_record(self.current_step, avg_acc)
         logger.info("step %d: loss=%.5f acc=%.5f [step/sec: %.2f]" %
                     (self.current_step, avg_loss, avg_acc, run_speed))
 
-    def _eval_end_event(self, phase, run_states):
+    def _eval_end_event(self, run_states):
         eval_loss, eval_acc, run_speed = self._calculate_metrics(run_states)
         logger.info(
             "[%s dataset evaluation result] loss=%.5f acc=%.5f [step/sec: %.2f]"
-            % (phase, eval_loss, eval_acc, run_speed))
-        if phase in ["dev", "val"] and eval_acc > self.best_accuracy:
-            self.eval_loss_scalar.add_record(self.current_step, eval_loss)
-            self.eval_acc_scalar.add_record(self.current_step, eval_acc)
+            % (self.phase, eval_loss, eval_acc, run_speed))
+        if self.phase in ["dev", "val"] and eval_acc > self.best_accuracy:
+            self.env.loss_scalar.add_record(self.current_step, eval_loss)
+            self.env.acc_scalar.add_record(self.current_step, eval_acc)
             self.best_accuracy = eval_acc
             model_saved_dir = os.path.join(self.config.checkpoint_dir,
                                            "best_model")
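Reviewer note on the subclass contract: `_build_net`/`_add_label`/`_add_loss`/`_add_metrics` now return variables instead of assigning to `self`, and `BasicTask` files the results on the per-phase env. A hypothetical regression task under the new contract (all names invented for illustration):

```python
import paddle.fluid as fluid
from paddlehub.finetune.task import BasicTask


class RegressionTask(BasicTask):
    """Hypothetical subclass showing the return-value contract."""

    def __init__(self, feature, feed_list, data_reader, config=None):
        super(RegressionTask, self).__init__(
            feed_list=feed_list, data_reader=data_reader, config=config)
        self.feature = feature

    def _build_net(self):
        # Returned variable becomes env.output.
        return fluid.layers.fc(input=self.feature, size=1, act=None)

    def _add_label(self):
        # Returned variable becomes env.label.
        return fluid.layers.data(name="label", dtype="float32", shape=[1])

    def _add_loss(self):
        # Returned variable becomes env.loss.
        cost = fluid.layers.square_error_cost(
            input=self.output, label=self.label)
        return fluid.layers.mean(x=cost)

    def _add_metrics(self):
        # Must return a list; these become env.metrics.
        mae = fluid.layers.mean(
            fluid.layers.abs(
                fluid.layers.elementwise_sub(self.output, self.label)))
        return [mae]
```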
@@ -548,7 +646,7 @@ class ClassifierTask(BasicTask):
             save_result = fluid.io.save_persistables(
                 executor=self.exe,
                 dirname=model_saved_dir,
-                main_program=self.test_program)
+                main_program=self.main_program)
 
 
 ImageClassifierTask = ClassifierTask
@@ -644,30 +742,34 @@ class SequenceLabelTask(BasicTask):
         return logits
 
     def _add_label(self):
-        self.label = fluid.layers.data(
+        label = fluid.layers.data(
             name="label", shape=[self.max_seq_len, 1], dtype='int64')
+        return label
 
     def _add_loss(self):
         labels = fluid.layers.flatten(self.label, axis=2)
         ce_loss = fluid.layers.cross_entropy(input=self.output, label=labels)
-        self.loss = fluid.layers.mean(x=ce_loss)
+        loss = fluid.layers.mean(x=ce_loss)
+        return loss
 
     def _add_metrics(self):
-        self.ret_labels = fluid.layers.reshape(x=self.label, shape=[-1, 1])
-        self.ret_infers = fluid.layers.reshape(
+        ret_labels = fluid.layers.reshape(x=self.label, shape=[-1, 1])
+        ret_infers = fluid.layers.reshape(
            x=fluid.layers.argmax(self.logits, axis=2), shape=[-1, 1])
         self.seq_len = fluid.layers.data(
             name="seq_len", shape=[1], dtype='int64')
-        self.seq_len = fluid.layers.assign(self.seq_len)
-        self.metrics += [self.ret_labels, self.ret_infers, self.seq_len]
-
-    def _init_end_event(self):
-        with self.log_writer.mode("train") as logw:
-            self.train_loss_scalar = logw.scalar(tag="Loss [train]")
-        with self.log_writer.mode("evaluate") as logw:
-            self.eval_f1_scalar = logw.scalar(tag="F1 [eval]")
-            self.eval_precision_scalar = logw.scalar(tag="Precision [eval]")
-            self.eval_recall_scalar = logw.scalar(tag="Recall [eval]")
+        seq_len = fluid.layers.assign(self.seq_len)
+        return [ret_labels, ret_infers, seq_len]
+
+    def _build_env_end_event(self):
+        with self.log_writer.mode(self.phase) as logw:
+            self.env.loss_scalar = logw.scalar(
+                tag="Loss [{}]".format(self.phase))
+            self.env.f1_scalar = logw.scalar(tag="F1 [{}]".format(self.phase))
+            self.env.precision_scalar = logw.scalar(
+                tag="Precision [{}]".format(self.phase))
+            self.env.recall_scalar = logw.scalar(
+                tag="Recall [{}]".format(self.phase))
 
     def _calculate_metrics(self, run_states):
         total_infer = total_label = total_correct = loss_sum = 0
@@ -696,22 +798,22 @@ class SequenceLabelTask(BasicTask):
     def _log_interval_event(self, run_states):
         precision, recall, f1, avg_loss, run_speed = self._calculate_metrics(
             run_states)
-        self.train_loss_scalar.add_record(self.current_step, avg_loss)
+        self.env.loss_scalar.add_record(self.current_step, avg_loss)
         logger.info("step %d: loss=%.5f [step/sec: %.2f]" %
                     (self.current_step, avg_loss, run_speed))
 
-    def _eval_end_event(self, phase, run_states):
+    def _eval_end_event(self, run_states):
         precision, recall, f1, avg_loss, run_speed = self._calculate_metrics(
             run_states)
-        self.eval_f1_scalar.add_record(self.current_step, f1)
-        self.eval_precision_scalar.add_record(self.current_step, precision)
-        self.eval_recall_scalar.add_record(self.current_step, recall)
+        self.env.f1_scalar.add_record(self.current_step, f1)
+        self.env.precision_scalar.add_record(self.current_step, precision)
+        self.env.recall_scalar.add_record(self.current_step, recall)
         logger.info("[%s dataset evaluation result] [step/sec: %.2f]" %
-                    (phase, run_speed))
+                    (self.phase, run_speed))
         logger.info(
             "[%s evaluation] F1-Score=%f, precision=%f, recall=%f [step/sec: %.2f]"
-            % (phase, f1, precision, recall, run_speed))
-        if f1 > self.best_f1:
+            % (self.phase, f1, precision, recall, run_speed))
+        if self.phase in ["dev", "val"] and f1 > self.best_f1:
             self.best_f1 = f1
             model_saved_dir = os.path.join(self.config.checkpoint_dir,
                                            "best_model")
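Reviewer note: because `_build_env_end_event` now fires once per phase, each phase gets its own VisualDL mode and scalars, replacing the fixed train/evaluate pair. Condensed to just the LogWriter calls this file already relies on:

```python
from visualdl import LogWriter

log_writer = LogWriter("./ckpt/vdllog", sync_cycle=1)

# One scalar per phase, created when that phase's env is built.
scalars = {}
for phase in ["train", "val"]:
    with log_writer.mode(phase) as logw:
        scalars[phase] = logw.scalar(tag="Loss [{}]".format(phase))

scalars["train"].add_record(1, 0.50)  # (step, value) at a log interval
scalars["val"].add_record(1, 0.42)    # recorded when evaluation finishes
```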
@@ -719,7 +821,9 @@ class SequenceLabelTask(BasicTask):
                         (model_saved_dir, self.best_f1))
             fluid.io.save_persistables(self.exe, dirname=model_saved_dir)
 
-    def get_feed_list(self, phase):
-        if phase in ["train", "dev", "val", "test"]:
-            return self.feed_list + [self.label.name] + [self.seq_len.name]
-        return self.feed_list
+    @property
+    def feed_list(self):
+        feed_list = [varname for varname in self._base_feed_list]
+        if self.is_train_phase or self.is_test_phase:
+            feed_list += [self.label.name, self.seq_len.name]
+        return feed_list