提交 4dd59a1a 编写于 作者: W WuHaobo

Init PaddleClas

上级 9f39da88
...@@ -89,8 +89,8 @@ def print_config(config): ...@@ -89,8 +89,8 @@ def print_config(config):
config: configs config: configs
""" """
copyright = "PaddleCLS is powered by PaddlePaddle" copyright = "PaddleClas is powered by PaddlePaddle"
ad = "https://github.com/PaddlePaddle/PaddleCLS" ad = "https://github.com/PaddlePaddle/PaddleClas"
logger.info("\n" * 2) logger.info("\n" * 2)
logger.info(copyright) logger.info(copyright)
...@@ -193,9 +193,11 @@ def get_config(fname, overrides=[], show=True): ...@@ -193,9 +193,11 @@ def get_config(fname, overrides=[], show=True):
assert os.path.exists(fname), \ assert os.path.exists(fname), \
('config file({}) is not exist'.format(fname)) ('config file({}) is not exist'.format(fname))
config = parse_config(fname) config = parse_config(fname)
if show: print_config(config) if show:
print_config(config)
if len(overrides) > 0: if len(overrides) > 0:
override_config(config, overrides) override_config(config, overrides)
print_config(config) if show:
print_config(config)
check_config(config) check_config(config)
return config return config
...@@ -30,7 +30,7 @@ __all__ = ['init_model', 'save_model'] ...@@ -30,7 +30,7 @@ __all__ = ['init_model', 'save_model']
def _mkdir_if_not_exist(path): def _mkdir_if_not_exist(path):
""" """
mkdir if not exists mkdir if not exists
""" """
if not os.path.exists(os.path.join(path)): if not os.path.exists(os.path.join(path)):
os.makedirs(os.path.join(path)) os.makedirs(os.path.join(path))
...@@ -97,25 +97,27 @@ def load_params(exe, prog, path, ignore_params=[]): ...@@ -97,25 +97,27 @@ def load_params(exe, prog, path, ignore_params=[]):
fluid.io.set_program_state(prog, state) fluid.io.set_program_state(prog, state)
def init_model(config, program, exe): def init_model(config, program, exe, prefix="ppcls"):
""" """
load model from checkpoint or pretrained_model load model from checkpoint or pretrained_model
""" """
checkpoints = config.get('checkpoints') checkpoints = config.get('checkpoints')
if checkpoints and os.path.exists(checkpoints): if checkpoints:
fluid.load(program, checkpoints, exe) path = os.path.join(checkpoints, prefix)
logger.info("Finish initing model from {}".format(checkpoints)) fluid.load(program, path, exe)
logger.info("Finish initing model from {}".format(path))
return return
pretrained_model = config.get('pretrained_model') pretrained_model = config.get('pretrained_model')
if pretrained_model and os.path.exists(pretrained_model): if pretrained_model:
load_params(exe, program, pretrained_model) path = os.path.join(pretrained_model, prefix)
logger.info("Finish initing model from {}".format(pretrained_model)) load_params(exe, program, path)
logger.info("Finish initing model from {}".format(path))
def save_model(program, model_path, epoch_id, prefix='ppcls'): def save_model(program, model_path, epoch_id, prefix='ppcls'):
""" """
save model to the target path save model to the target path
""" """
model_path = os.path.join(model_path, str(epoch_id)) model_path = os.path.join(model_path, str(epoch_id))
_mkdir_if_not_exist(model_path) _mkdir_if_not_exist(model_path)
......
...@@ -357,6 +357,8 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): ...@@ -357,6 +357,8 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
""" """
fetch_list = [f[0] for f in fetchs.values()] fetch_list = [f[0] for f in fetchs.values()]
metric_list = [f[1] for f in fetchs.values()] metric_list = [f[1] for f in fetchs.values()]
for m in metric_list:
m.reset()
batch_time = AverageMeter('cost', ':6.3f') batch_time = AverageMeter('cost', ':6.3f')
tic = time.time() tic = time.time()
for idx, batch in enumerate(dataloader()): for idx, batch in enumerate(dataloader()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册