提交 ee097042 编写于 作者: W wuzewu

add use_data_parallel config

上级 ef3a1931
......@@ -8,7 +8,7 @@ import numpy as np
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--use_gpu", type=bool, default=True, help="Whether use GPU for predict.")
parser.add_argument("--use_gpu", type=bool, default=False, help="Whether use GPU for predict.")
parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
parser.add_argument("--module", type=str, default="resnet50", help="Module used as a feature extractor.")
......@@ -55,6 +55,7 @@ def predict(args):
feed_list = [img.name]
config = hub.RunConfig(
use_data_parallel=False,
use_cuda=args.use_gpu,
batch_size=args.batch_size,
enable_memory_optim=False,
......
......@@ -41,14 +41,22 @@ def load_checkpoint(checkpoint_dir,
ckpt.ParseFromString(f.read())
current_epoch = 1
global_step = 0
pretrained_model = ""
best_model_path = os.path.join(checkpoint_dir, "best_model")
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
if load_best_model and os.path.exists(best_model_path):
fluid.io.load_persistables(exe, best_model_path, main_program)
pretrained_model = best_model_path
fluid.io.load_vars(
exe, best_model_path, main_program, predicate=if_exist)
logger.info("PaddleHub model best model loaded.")
return current_epoch, global_step
elif ckpt.latest_model_dir:
fluid.io.load_persistables(exe, ckpt.latest_model_dir, main_program)
pretrained_model = ckpt.latest_model_dir
fluid.io.load_vars(
exe, ckpt.latest_model_dir, main_program, predicate=if_exist)
logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
"global_step={}".format(ckpt.current_epoch,
......
......@@ -31,6 +31,7 @@ class RunConfig(object):
log_interval=10,
eval_interval=100,
use_pyreader=False,
use_data_parallel=True,
save_ckpt_interval=None,
use_cuda=True,
checkpoint_dir=None,
......@@ -47,6 +48,7 @@ class RunConfig(object):
self._num_epoch = num_epoch
self._batch_size = batch_size
self._use_pyreader = use_pyreader
self._use_data_parallel = use_data_parallel
if strategy is None:
self._strategy = DefaultStrategy()
else:
......@@ -100,3 +102,7 @@ class RunConfig(object):
@property
def use_pyreader(self):
return self._use_pyreader
@property
def use_data_parallel(self):
return self._use_data_parallel
......@@ -222,14 +222,19 @@ class BasicTask(object):
else:
share_vars_from = self._base_compile_program
self.env.main_program_compiled = fluid.CompiledProgram(
self.env.main_program).with_data_parallel(
loss_name=loss_name,
share_vars_from=share_vars_from,
build_strategy=self.build_strategy)
if not self.config.use_data_parallel:
if self.config.enable_memory_optim:
fluid.memory_optimize(self.env.main_program)
self.env.main_program_compiled = self.env.main_program
else:
self.env.main_program_compiled = fluid.CompiledProgram(
self.env.main_program).with_data_parallel(
loss_name=loss_name,
share_vars_from=share_vars_from,
build_strategy=self.build_strategy)
if self._base_compile_program is None:
self._base_compile_program = self.env.main_program_compiled
if self._base_compile_program is None:
self._base_compile_program = self.env.main_program_compiled
self.exe.run(self.env.startup_program)
self._build_env_end_event()
......@@ -423,15 +428,6 @@ class BasicTask(object):
startup_program=self._base_startup_program,
load_best_model=load_best_model)
if load_best_model:
model_saved_dir = os.path.join(self.config.checkpoint_dir,
"best_model")
if os.path.exists(model_saved_dir):
fluid.io.load_persistables(
executor=self.exe,
dirname=model_saved_dir,
main_program=self.main_program)
def finetune_and_eval(self):
self.finetune(do_eval=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册