Commit ead7f993 authored by W wuzewu

Update predict interface

Parent 3b2cceb2
@@ -8,7 +8,7 @@ import numpy as np
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--use_gpu", type=bool, default=False, help="Whether use GPU for predict.")
parser.add_argument("--use_gpu", type=bool, default=True, help="Whether use GPU for predict.")
parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
parser.add_argument("--module", type=str, default="resnet50", help="Module used as a feature extractor.")
@@ -70,15 +70,17 @@ def predict(args):
data = ["./test/test_img_daisy.jpg", "./test/test_img_roses.jpg"]
label_map = dataset.label_dict()
for result in task.predict(data=data):
result = np.argmax(result, axis=2)
index = 0
for batch in result:
for predict_result in batch:
index += 1
predict_result = label_map[predict_result]
print("input %i is %s, and the predict result is %s" %
(index, data[index - 1], predict_result))
index = 0
# get classification result
results = task.predict(data=data)
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
for result in batch_result:
index += 1
result = label_map[result]
print("input %i is %s, and the predict result is %s" %
(index, data[index - 1], result))
if __name__ == "__main__":
......
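The demo change above reflects the new return type of `task.predict`. A minimal caller-side sketch of the updated flow, assuming `task` and `label_map` are set up as in the rest of predict.py (image paths are illustrative):

```python
import numpy as np

# Before this commit the demo iterated over a generator:
#     for result in task.predict(data=data):
#         result = np.argmax(result, axis=2)
#         ...
# After it, predict() returns a list of per-batch results that can be
# post-processed (or re-used) after the call has finished:
results = task.predict(data=["./test/test_img_daisy.jpg",
                             "./test/test_img_roses.jpg"])
for batch_result in results:
    # raw fetch output -> predicted class index for each sample in the batch
    predictions = np.argmax(batch_result, axis=2)[0]
    print([label_map[p] for p in predictions])
```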
@@ -30,31 +30,36 @@ CKPT_FILE_NAME = "ckpt.meta"
def load_checkpoint(checkpoint_dir,
exe,
main_program=fluid.default_main_program(),
startup_program=fluid.default_startup_program()):
startup_program=fluid.default_startup_program(),
load_best_model=False):
ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
ckpt = checkpoint_pb2.CheckPoint()
logger.info("Try loading checkpoint from {}".format(ckpt_meta_path))
if os.path.exists(ckpt_meta_path):
ckpt = checkpoint_pb2.CheckPoint()
with open(ckpt_meta_path, "rb") as f:
ckpt.ParseFromString(f.read())
current_epoch = 1
global_step = 0
best_model_path = os.path.join(checkpoint_dir, "best_model")
if load_best_model and os.path.exists(best_model_path):
fluid.io.load_persistables(exe, best_model_path, main_program)
logger.info("PaddleHub model best model loaded.")
return current_epoch, global_step
elif ckpt.latest_model_dir:
fluid.io.load_persistables(exe, ckpt.latest_model_dir, main_program)
logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
"global_step={}".format(ckpt.current_epoch,
ckpt.global_step))
return ckpt.current_epoch, ckpt.global_step
else:
current_epoch = 1
global_step = 0
latest_model_dir = None
logger.info(
"PaddleHub model checkpoint not found, start training from scratch..."
)
exe.run(startup_program)
return current_epoch, global_step
logger.info(
"PaddleHub model checkpoint not found, start training from scratch...")
exe.run(startup_program)
return current_epoch, global_step
def save_checkpoint(checkpoint_dir,
......
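For context, a hedged sketch of how a finetune or predict script might call the extended `load_checkpoint` (the checkpoint directory name is illustrative; `main_program` and `startup_program` fall back to the fluid defaults shown in the signature above):

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())

# With load_best_model=True the loader prefers <checkpoint_dir>/best_model
# and only falls back to the latest checkpoint recorded in ckpt.meta (or to
# running the startup program) when no best model has been saved yet.
current_epoch, global_step = load_checkpoint(
    checkpoint_dir="paddlehub_finetune_ckpt",
    exe=exe,
    load_best_model=True)
```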
@@ -128,10 +128,11 @@ class BasicTask(object):
# run environment
self._phases = []
self._envs = {}
self._predict_data = None
def init_if_necessary(self):
def init_if_necessary(self, load_best_model=False):
if not self._load_checkpoint:
self.load_checkpoint()
self.load_checkpoint(load_best_model=load_best_model)
self._load_checkpoint = True
@contextlib.contextmanager
@@ -159,6 +160,11 @@ class BasicTask(object):
self.env.loss = self._add_loss()
self.env.metrics = self._add_metrics()
if self.is_predict_phase or self.is_test_phase:
self.env.main_program = self.env.main_program.clone(for_test=True)
hub.common.paddle_helper.set_op_attr(
self.env.main_program, is_test=True)
if self.config.use_pyreader:
t_program = fluid.Program()
with fluid.program_guard(t_program, self.env.startup_program):
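The clone added above relies on standard Fluid behavior rather than anything PaddleHub-specific; a small standalone sketch of what `clone(for_test=True)` does (network definition is illustrative):

```python
import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    img = fluid.layers.data(name="img", shape=[1, 28, 28], dtype="float32")
    prediction = fluid.layers.fc(input=img, size=10, act="softmax")

# clone(for_test=True) prunes training-only ops and switches operators such
# as dropout/batch_norm into inference mode, which is why the task clones
# env.main_program before the test and predict phases.
inference_program = main_program.clone(for_test=True)
```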
@@ -291,8 +297,12 @@ class BasicTask(object):
@property
def reader(self):
if self.is_predict_phase:
data = self._predict_data
else:
data = None
self.env.reader = self._base_data_reader.data_generator(
batch_size=self.config.batch_size, phase=self.phase)
batch_size=self.config.batch_size, phase=self.phase, data=data)
return self.env.reader
@property
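The `data=` argument forwarded here implies a reader-side contract; the reader implementation is not part of this diff, but a hypothetical `data_generator` honoring it might look like the sketch below (all helper names are illustrative, not PaddleHub code):

```python
def data_generator(self, batch_size, phase="train", data=None):
    # Hypothetical sketch: in the predict phase the task passes its raw
    # input list via `data`; otherwise the reader falls back to the
    # dataset split selected by `phase`.
    if data is not None:
        samples = [self._process_image(path) for path in data]  # illustrative helper
    else:
        samples = self._load_split(phase)  # illustrative helper

    def _reader():
        batch = []
        for sample in samples:
            batch.append(sample)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    return _reader
```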
@@ -315,8 +325,6 @@ class BasicTask(object):
@property
def output(self):
if self.is_predict_phase:
raise RuntimeError()
if not self.env.is_inititalized:
self._build_env()
return self.env.output
@@ -412,7 +420,8 @@ class BasicTask(object):
self.config.checkpoint_dir,
self.exe,
main_program=self.main_program,
startup_program=self._base_startup_program)
startup_program=self._base_startup_program,
load_best_model=load_best_model)
if load_best_model:
model_saved_dir = os.path.join(self.config.checkpoint_dir,
@@ -454,10 +463,12 @@ class BasicTask(object):
self._eval_end_event(run_states)
def predict(self, data, load_best_model=True):
with self.phase_guard(phase=phase):
self.init_if_necessary()
for run_state in self._run():
yield run_state.run_results
with self.phase_guard(phase="predict"):
self._predict_data = data
self.init_if_necessary(load_best_model=load_best_model)
run_states = self._run()
self._predict_data = None
return [run_state.run_results for run_state in run_states]
def _run(self, do_eval=False):
with fluid.program_guard(self.main_program, self.startup_program):
......
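From the caller's perspective the rewritten `predict` is eager rather than lazy; a short usage sketch (file paths illustrative):

```python
# Each entry of `results` is the run_results of one batch (the fetched
# output tensors), so the list can be indexed or iterated more than once,
# unlike the previous generator, which was exhausted after a single pass.
results = task.predict(data=["./test/test_img_daisy.jpg",
                             "./test/test_img_roses.jpg"],
                       load_best_model=True)

for batch_idx, run_results in enumerate(results):
    print("batch %d fetched %d output tensors" % (batch_idx, len(run_results)))
```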