提交 f2b1fcd8 编写于 作者: D dengkaipeng

move config out model in train.py

上级 d37dc3fa
...@@ -23,8 +23,8 @@ __all__ = ["AttentionCluster"] ...@@ -23,8 +23,8 @@ __all__ = ["AttentionCluster"]
class AttentionCluster(ModelBase): class AttentionCluster(ModelBase):
def __init__(self, name, cfg, mode='train', args=None): def __init__(self, name, cfg, mode='train'):
super(AttentionCluster, self).__init__(name, cfg, mode, args) super(AttentionCluster, self).__init__(name, cfg, mode)
self.get_config() self.get_config()
def get_config(self): def get_config(self):
......
...@@ -105,7 +105,7 @@ class ModelConfig(object): ...@@ -105,7 +105,7 @@ class ModelConfig(object):
class ModelBase(object): class ModelBase(object):
def __init__(self, name, cfg, mode='train', args=None): def __init__(self, name, cfg, mode='train'):
assert mode in ['train', 'valid', 'test', 'infer'], \ assert mode in ['train', 'valid', 'test', 'infer'], \
"Unknown mode type {}".format(mode) "Unknown mode type {}".format(mode)
self.name = name self.name = name
...@@ -114,13 +114,14 @@ class ModelBase(object): ...@@ -114,13 +114,14 @@ class ModelBase(object):
self.py_reader = None self.py_reader = None
# parse config # parse config
assert os.path.exists(cfg), \ # assert os.path.exists(cfg), \
"Config file {} not exists".format(cfg) # "Config file {} not exists".format(cfg)
self._config = ModelConfig(cfg) # self._config = ModelConfig(cfg)
self._config.parse() # self._config.parse()
if args and isinstance(args, dict): # if args and isinstance(args, dict):
self._config.merge_configs(mode, args) # self._config.merge_configs(mode, args)
self.cfg = self._config.get_configs() # self.cfg = self._config.get_configs()
self.cfg = cfg
def build_model(self): def build_model(self):
"build model struct" "build model struct"
...@@ -209,9 +210,9 @@ class ModelBase(object): ...@@ -209,9 +210,9 @@ class ModelBase(object):
fluid.io.load_params(exe, pretrain, main_program=prog) fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None): def get_config_from_sec(self, sec, item, default=None):
cfg_item = self._config.get_config_from_sec(sec.upper(), if sec.upper() not in self.cfg:
item) or default return default
return cfg_item return self.cfg[sec.upper()].get(item, default)
class ModelZoo(object): class ModelZoo(object):
...@@ -223,10 +224,10 @@ class ModelZoo(object): ...@@ -223,10 +224,10 @@ class ModelZoo(object):
type(model)) type(model))
self.model_zoo[name] = model self.model_zoo[name] = model
def get(self, name, cfg, mode='train', args=None): def get(self, name, cfg, mode='train'):
for k, v in self.model_zoo.items(): for k, v in self.model_zoo.items():
if k == name: if k == name:
return v(name, cfg, mode, args) return v(name, cfg, mode)
raise ModelNotFoundError(name, self.model_zoo.keys()) raise ModelNotFoundError(name, self.model_zoo.keys())
...@@ -238,6 +239,6 @@ def regist_model(name, model): ...@@ -238,6 +239,6 @@ def regist_model(name, model):
model_zoo.regist(name, model) model_zoo.regist(name, model)
def get_model(name, cfg, mode='train', args=None): def get_model(name, cfg, mode='train'):
return model_zoo.get(name, cfg, mode, args) return model_zoo.get(name, cfg, mode)
...@@ -156,8 +156,8 @@ class STNET(ModelBase): ...@@ -156,8 +156,8 @@ class STNET(ModelBase):
def load_pretrain_params(self, exe, pretrain, prog): def load_pretrain_params(self, exe, pretrain, prog):
def is_parameter(var): def is_parameter(var):
if isinstance(var, fluid.framework.Parameter): if isinstance(var, fluid.framework.Parameter):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \ return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name)) and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
vars = filter(is_parameter, prog.list_vars()) vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars) fluid.io.load_vars(exe, pretrain, vars=vars)
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from tools.train_utils import train_with_pyreader, train_without_pyreader from tools.train_utils import train_with_pyreader, train_without_pyreader
from config import *
import models import models
logging.root.handlers = [] logging.root.handlers = []
...@@ -98,7 +99,17 @@ def parse_args(): ...@@ -98,7 +99,17 @@ def parse_args():
return args return args
def train(train_model, valid_model, args): def train(args):
# parse config
config = parse_config(args.config)
train_config = merge_configs(config, 'train', vars(args))
valid_config = merge_configs(config, 'valid', vars(args))
train_model = models.get_model(
args.model_name, train_config, mode='train')
valid_model = models.get_model(
args.model_name, valid_config, mode='valid')
# build model
startup = fluid.Program() startup = fluid.Program()
train_prog = fluid.Program() train_prog = fluid.Program()
with fluid.program_guard(train_prog, startup): with fluid.program_guard(train_prog, startup):
...@@ -117,8 +128,6 @@ def train(train_model, valid_model, args): ...@@ -117,8 +128,6 @@ def train(train_model, valid_model, args):
# outputs, loss, label should be fetched, so set persistable to be true # outputs, loss, label should be fetched, so set persistable to be true
optimizer = train_model.optimizer() optimizer = train_model.optimizer()
optimizer.minimize(train_loss) optimizer.minimize(train_loss)
train_reader = train_model.reader()
train_metrics = train_model.metrics()
train_pyreader = train_model.pyreader() train_pyreader = train_model.pyreader()
if not args.no_memory_optimize: if not args.no_memory_optimize:
...@@ -132,7 +141,6 @@ def train(train_model, valid_model, args): ...@@ -132,7 +141,6 @@ def train(train_model, valid_model, args):
valid_feeds = valid_model.feeds() valid_feeds = valid_model.feeds()
valid_outputs = valid_model.outputs() valid_outputs = valid_model.outputs()
valid_loss = valid_model.loss() valid_loss = valid_model.loss()
valid_reader = valid_model.reader()
valid_metrics = valid_model.metrics() valid_metrics = valid_model.metrics()
valid_pyreader = valid_model.pyreader() valid_pyreader = valid_model.pyreader()
...@@ -156,6 +164,18 @@ def train(train_model, valid_model, args): ...@@ -156,6 +164,18 @@ def train(train_model, valid_model, args):
share_vars_from=train_exe, share_vars_from=train_exe,
main_program=valid_prog) main_program=valid_prog)
# get reader
# train_reader = get_reader(train_config)
# valid_reader = get_reader(valid_config)
train_reader = train_model.reader()
valid_reader = valid_model.reader()
# get metrics
# train_metrics = get_metrics(train_config)
# valid_metrics = get_metrics(valid_config)
train_metrics = train_model.metrics()
train_metrics = train_model.metrics()
train_fetch_list = [train_loss.name] + [x.name for x in train_outputs train_fetch_list = [train_loss.name] + [x.name for x in train_outputs
] + [train_feeds[-1].name] ] + [train_feeds[-1].name]
valid_fetch_list = [valid_loss.name] + [x.name for x in valid_outputs valid_fetch_list = [valid_loss.name] + [x.name for x in valid_outputs
...@@ -186,11 +206,8 @@ def train(train_model, valid_model, args): ...@@ -186,11 +206,8 @@ def train(train_model, valid_model, args):
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
logger.info(args) logger.info(args)
if not os.path.exists(args.save_dir): if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir) os.makedirs(args.save_dir)
train_model = models.get_model( train(args)
args.model_name, args.config, mode='train', args=vars(args))
valid_model = models.get_model(
args.model_name, args.config, mode='valid', args=vars(args))
train(train_model, valid_model, args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册