提交 ead7f993 编写于 作者: W wuzewu

Update predict interface

上级 3b2cceb2
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--use_gpu", type=bool, default=False, help="Whether use GPU for predict.") parser.add_argument("--use_gpu", type=bool, default=True, help="Whether use GPU for predict.")
parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.") parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.") parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
parser.add_argument("--module", type=str, default="resnet50", help="Module used as a feature extractor.") parser.add_argument("--module", type=str, default="resnet50", help="Module used as a feature extractor.")
...@@ -70,15 +70,17 @@ def predict(args): ...@@ -70,15 +70,17 @@ def predict(args):
data = ["./test/test_img_daisy.jpg", "./test/test_img_roses.jpg"] data = ["./test/test_img_daisy.jpg", "./test/test_img_roses.jpg"]
label_map = dataset.label_dict() label_map = dataset.label_dict()
for result in task.predict(data=data):
result = np.argmax(result, axis=2)
index = 0 index = 0
for batch in result: # get classification result
for predict_result in batch: results = task.predict(data=data)
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
for result in batch_result:
index += 1 index += 1
predict_result = label_map[predict_result] result = label_map[result]
print("input %i is %s, and the predict result is %s" % print("input %i is %s, and the predict result is %s" %
(index, data[index - 1], predict_result)) (index, data[index - 1], result))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -30,28 +30,33 @@ CKPT_FILE_NAME = "ckpt.meta" ...@@ -30,28 +30,33 @@ CKPT_FILE_NAME = "ckpt.meta"
def load_checkpoint(checkpoint_dir, def load_checkpoint(checkpoint_dir,
exe, exe,
main_program=fluid.default_main_program(), main_program=fluid.default_main_program(),
startup_program=fluid.default_startup_program()): startup_program=fluid.default_startup_program(),
load_best_model=False):
ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME) ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
ckpt = checkpoint_pb2.CheckPoint()
logger.info("Try loading checkpoint from {}".format(ckpt_meta_path)) logger.info("Try loading checkpoint from {}".format(ckpt_meta_path))
if os.path.exists(ckpt_meta_path): if os.path.exists(ckpt_meta_path):
ckpt = checkpoint_pb2.CheckPoint()
with open(ckpt_meta_path, "rb") as f: with open(ckpt_meta_path, "rb") as f:
ckpt.ParseFromString(f.read()) ckpt.ParseFromString(f.read())
current_epoch = 1
global_step = 0
best_model_path = os.path.join(checkpoint_dir, "best_model")
if load_best_model and os.path.exists(best_model_path):
fluid.io.load_persistables(exe, best_model_path, main_program)
logger.info("PaddleHub model best model loaded.")
return current_epoch, global_step
elif ckpt.latest_model_dir:
fluid.io.load_persistables(exe, ckpt.latest_model_dir, main_program) fluid.io.load_persistables(exe, ckpt.latest_model_dir, main_program)
logger.info("PaddleHub model checkpoint loaded. current_epoch={}, " logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
"global_step={}".format(ckpt.current_epoch, "global_step={}".format(ckpt.current_epoch,
ckpt.global_step)) ckpt.global_step))
return ckpt.current_epoch, ckpt.global_step return ckpt.current_epoch, ckpt.global_step
else:
current_epoch = 1
global_step = 0
latest_model_dir = None
logger.info( logger.info(
"PaddleHub model checkpoint not found, start training from scratch..." "PaddleHub model checkpoint not found, start training from scratch...")
)
exe.run(startup_program) exe.run(startup_program)
return current_epoch, global_step return current_epoch, global_step
......
...@@ -128,10 +128,11 @@ class BasicTask(object): ...@@ -128,10 +128,11 @@ class BasicTask(object):
# run environment # run environment
self._phases = [] self._phases = []
self._envs = {} self._envs = {}
self._predict_data = None
def init_if_necessary(self): def init_if_necessary(self, load_best_model=False):
if not self._load_checkpoint: if not self._load_checkpoint:
self.load_checkpoint() self.load_checkpoint(load_best_model=load_best_model)
self._load_checkpoint = True self._load_checkpoint = True
@contextlib.contextmanager @contextlib.contextmanager
...@@ -159,6 +160,11 @@ class BasicTask(object): ...@@ -159,6 +160,11 @@ class BasicTask(object):
self.env.loss = self._add_loss() self.env.loss = self._add_loss()
self.env.metrics = self._add_metrics() self.env.metrics = self._add_metrics()
if self.is_predict_phase or self.is_test_phase:
self.env.main_program = self.env.main_program.clone(for_test=True)
hub.common.paddle_helper.set_op_attr(
self.env.main_program, is_test=True)
if self.config.use_pyreader: if self.config.use_pyreader:
t_program = fluid.Program() t_program = fluid.Program()
with fluid.program_guard(t_program, self.env.startup_program): with fluid.program_guard(t_program, self.env.startup_program):
...@@ -291,8 +297,12 @@ class BasicTask(object): ...@@ -291,8 +297,12 @@ class BasicTask(object):
@property @property
def reader(self): def reader(self):
if self.is_predict_phase:
data = self._predict_data
else:
data = None
self.env.reader = self._base_data_reader.data_generator( self.env.reader = self._base_data_reader.data_generator(
batch_size=self.config.batch_size, phase=self.phase) batch_size=self.config.batch_size, phase=self.phase, data=data)
return self.env.reader return self.env.reader
@property @property
...@@ -315,8 +325,6 @@ class BasicTask(object): ...@@ -315,8 +325,6 @@ class BasicTask(object):
@property @property
def output(self): def output(self):
if self.is_predict_phase:
raise RuntimeError()
if not self.env.is_inititalized: if not self.env.is_inititalized:
self._build_env() self._build_env()
return self.env.output return self.env.output
...@@ -412,7 +420,8 @@ class BasicTask(object): ...@@ -412,7 +420,8 @@ class BasicTask(object):
self.config.checkpoint_dir, self.config.checkpoint_dir,
self.exe, self.exe,
main_program=self.main_program, main_program=self.main_program,
startup_program=self._base_startup_program) startup_program=self._base_startup_program,
load_best_model=load_best_model)
if load_best_model: if load_best_model:
model_saved_dir = os.path.join(self.config.checkpoint_dir, model_saved_dir = os.path.join(self.config.checkpoint_dir,
...@@ -454,10 +463,12 @@ class BasicTask(object): ...@@ -454,10 +463,12 @@ class BasicTask(object):
self._eval_end_event(run_states) self._eval_end_event(run_states)
def predict(self, data, load_best_model=True): def predict(self, data, load_best_model=True):
with self.phase_guard(phase=phase): with self.phase_guard(phase="predict"):
self.init_if_necessary() self._predict_data = data
for run_state in self._run(): self.init_if_necessary(load_best_model=load_best_model)
yield run_state.run_results run_states = self._run()
self._predict_data = None
return [run_state.run_results for run_state in run_states]
def _run(self, do_eval=False): def _run(self, do_eval=False):
with fluid.program_guard(self.main_program, self.startup_program): with fluid.program_guard(self.main_program, self.startup_program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册