Commit ee097042 authored by wuzewu

add use_data_parallel config

Parent ef3a1931
@@ -8,7 +8,7 @@ import numpy as np
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--use_gpu", type=bool, default=True, help="Whether use GPU for predict.")
+parser.add_argument("--use_gpu", type=bool, default=False, help="Whether use GPU for predict.")
 parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.")
 parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
 parser.add_argument("--module", type=str, default="resnet50", help="Module used as a feature extractor.")
@@ -55,6 +55,7 @@ def predict(args):
     feed_list = [img.name]
     config = hub.RunConfig(
+        use_data_parallel=False,
         use_cuda=args.use_gpu,
         batch_size=args.batch_size,
         enable_memory_optim=False,
...
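Note: the demo pins use_data_parallel=False because prediction runs a single forward pass and does not benefit from multi-card replication. A minimal sketch of the resulting predict-time setup, assuming the PaddleHub 1.x API shown in this diff (the dataset/task wiring is omitted):

    import argparse
    import paddlehub as hub

    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("--use_gpu", type=bool, default=False, help="Whether use GPU for predict.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for prediction.")
    args = parser.parse_args()

    # Single-card prediction: data parallelism off, memory optim off.
    config = hub.RunConfig(
        use_data_parallel=False,
        use_cuda=args.use_gpu,
        batch_size=args.batch_size,
        enable_memory_optim=False)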
@@ -41,14 +41,22 @@ def load_checkpoint(checkpoint_dir,
         ckpt.ParseFromString(f.read())
     current_epoch = 1
     global_step = 0
+    pretrained_model = ""
     best_model_path = os.path.join(checkpoint_dir, "best_model")
+
+    def if_exist(var):
+        return os.path.exists(os.path.join(pretrained_model, var.name))

     if load_best_model and os.path.exists(best_model_path):
-        fluid.io.load_persistables(exe, best_model_path, main_program)
+        pretrained_model = best_model_path
+        fluid.io.load_vars(
+            exe, best_model_path, main_program, predicate=if_exist)
         logger.info("PaddleHub model best model loaded.")
         return current_epoch, global_step
     elif ckpt.latest_model_dir:
-        fluid.io.load_persistables(exe, ckpt.latest_model_dir, main_program)
+        pretrained_model = ckpt.latest_model_dir
+        fluid.io.load_vars(
+            exe, ckpt.latest_model_dir, main_program, predicate=if_exist)
         logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
                     "global_step={}".format(ckpt.current_epoch,
...
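The loader here switches from fluid.io.load_persistables to fluid.io.load_vars with an existence predicate: only variables that actually have a serialized file in the checkpoint directory are restored, and the rest keep their initialized values instead of triggering a load error. A standalone sketch of the pattern, assuming Paddle 1.x fluid (load_matching_vars is an illustrative helper, not part of this commit):

    import os
    import paddle.fluid as fluid

    def load_matching_vars(exe, pretrained_dir, main_program):
        # Restore only variables whose files exist under pretrained_dir;
        # anything missing from the checkpoint is silently skipped.
        def if_exist(var):
            return os.path.exists(os.path.join(pretrained_dir, var.name))

        fluid.io.load_vars(
            exe, pretrained_dir, main_program=main_program, predicate=if_exist)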
@@ -31,6 +31,7 @@ class RunConfig(object):
                  log_interval=10,
                  eval_interval=100,
                  use_pyreader=False,
+                 use_data_parallel=True,
                  save_ckpt_interval=None,
                  use_cuda=True,
                  checkpoint_dir=None,
@@ -47,6 +48,7 @@ class RunConfig(object):
         self._num_epoch = num_epoch
         self._batch_size = batch_size
         self._use_pyreader = use_pyreader
+        self._use_data_parallel = use_data_parallel
         if strategy is None:
             self._strategy = DefaultStrategy()
         else:
@@ -100,3 +102,7 @@ class RunConfig(object):
     @property
     def use_pyreader(self):
         return self._use_pyreader
+
+    @property
+    def use_data_parallel(self):
+        return self._use_data_parallel
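With the new keyword argument and read-only property, callers can opt out of multi-card execution per run while the default stays True for training. A minimal usage sketch, assuming the RunConfig interface as defined above:

    import paddlehub as hub

    config = hub.RunConfig(use_data_parallel=False, use_cuda=True, batch_size=32)
    assert config.use_data_parallel is False  # exposed via the new property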
@@ -222,14 +222,19 @@ class BasicTask(object):
             else:
                 share_vars_from = self._base_compile_program

-            self.env.main_program_compiled = fluid.CompiledProgram(
-                self.env.main_program).with_data_parallel(
-                    loss_name=loss_name,
-                    share_vars_from=share_vars_from,
-                    build_strategy=self.build_strategy)
+            if not self.config.use_data_parallel:
+                if self.config.enable_memory_optim:
+                    fluid.memory_optimize(self.env.main_program)
+                self.env.main_program_compiled = self.env.main_program
+            else:
+                self.env.main_program_compiled = fluid.CompiledProgram(
+                    self.env.main_program).with_data_parallel(
+                        loss_name=loss_name,
+                        share_vars_from=share_vars_from,
+                        build_strategy=self.build_strategy)

         if self._base_compile_program is None:
             self._base_compile_program = self.env.main_program_compiled

         self.exe.run(self.env.startup_program)
         self._build_env_end_event()
@@ -423,15 +428,6 @@ class BasicTask(object):
             startup_program=self._base_startup_program,
             load_best_model=load_best_model)

-        if load_best_model:
-            model_saved_dir = os.path.join(self.config.checkpoint_dir,
-                                           "best_model")
-            if os.path.exists(model_saved_dir):
-                fluid.io.load_persistables(
-                    executor=self.exe,
-                    dirname=model_saved_dir,
-                    main_program=self.main_program)
-
     def finetune_and_eval(self):
         self.finetune(do_eval=True)
...
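The compile step now branches on the flag: with data parallelism off, the executor runs the plain Program directly (optionally after fluid.memory_optimize), and the multi-card CompiledProgram path is taken only when the flag is on. A simplified sketch of the branch, assuming Paddle 1.x (compile_for_run and its parameters are illustrative, not this file's actual signature):

    import paddle.fluid as fluid

    def compile_for_run(program, loss, use_data_parallel, enable_memory_optim):
        if not use_data_parallel:
            if enable_memory_optim:
                # In-place pass that reuses variable memory within the program.
                fluid.memory_optimize(program)
            return program  # a plain Program is runnable by the executor as-is
        # Multi-card path: replicate the program across visible devices.
        return fluid.CompiledProgram(program).with_data_parallel(
            loss_name=loss.name)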