提交 f2b1fcd8 编写于 作者: D dengkaipeng

move config out model in train.py

上级 d37dc3fa
......@@ -23,8 +23,8 @@ __all__ = ["AttentionCluster"]
class AttentionCluster(ModelBase):
def __init__(self, name, cfg, mode='train', args=None):
super(AttentionCluster, self).__init__(name, cfg, mode, args)
def __init__(self, name, cfg, mode='train'):
super(AttentionCluster, self).__init__(name, cfg, mode)
self.get_config()
def get_config(self):
......
......@@ -105,7 +105,7 @@ class ModelConfig(object):
class ModelBase(object):
def __init__(self, name, cfg, mode='train', args=None):
def __init__(self, name, cfg, mode='train'):
assert mode in ['train', 'valid', 'test', 'infer'], \
"Unknown mode type {}".format(mode)
self.name = name
......@@ -114,13 +114,14 @@ class ModelBase(object):
self.py_reader = None
# parse config
assert os.path.exists(cfg), \
"Config file {} not exists".format(cfg)
self._config = ModelConfig(cfg)
self._config.parse()
if args and isinstance(args, dict):
self._config.merge_configs(mode, args)
self.cfg = self._config.get_configs()
# assert os.path.exists(cfg), \
# "Config file {} not exists".format(cfg)
# self._config = ModelConfig(cfg)
# self._config.parse()
# if args and isinstance(args, dict):
# self._config.merge_configs(mode, args)
# self.cfg = self._config.get_configs()
self.cfg = cfg
def build_model(self):
"build model struct"
......@@ -209,9 +210,9 @@ class ModelBase(object):
fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None):
cfg_item = self._config.get_config_from_sec(sec.upper(),
item) or default
return cfg_item
if sec.upper() not in self.cfg:
return default
return self.cfg[sec.upper()].get(item, default)
class ModelZoo(object):
......@@ -223,10 +224,10 @@ class ModelZoo(object):
type(model))
self.model_zoo[name] = model
def get(self, name, cfg, mode='train', args=None):
def get(self, name, cfg, mode='train'):
for k, v in self.model_zoo.items():
if k == name:
return v(name, cfg, mode, args)
return v(name, cfg, mode)
raise ModelNotFoundError(name, self.model_zoo.keys())
......@@ -238,6 +239,6 @@ def regist_model(name, model):
model_zoo.regist(name, model)
def get_model(name, cfg, mode='train', args=None):
return model_zoo.get(name, cfg, mode, args)
def get_model(name, cfg, mode='train'):
return model_zoo.get(name, cfg, mode)
......@@ -156,8 +156,8 @@ class STNET(ModelBase):
def load_pretrain_params(self, exe, pretrain, prog):
def is_parameter(var):
if isinstance(var, fluid.framework.Parameter):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars)
......
......@@ -21,6 +21,7 @@ import numpy as np
import paddle.fluid as fluid
from tools.train_utils import train_with_pyreader, train_without_pyreader
from config import *
import models
logging.root.handlers = []
......@@ -98,7 +99,17 @@ def parse_args():
return args
def train(train_model, valid_model, args):
def train(args):
# parse config
config = parse_config(args.config)
train_config = merge_configs(config, 'train', vars(args))
valid_config = merge_configs(config, 'valid', vars(args))
train_model = models.get_model(
args.model_name, train_config, mode='train')
valid_model = models.get_model(
args.model_name, valid_config, mode='valid')
# build model
startup = fluid.Program()
train_prog = fluid.Program()
with fluid.program_guard(train_prog, startup):
......@@ -117,8 +128,6 @@ def train(train_model, valid_model, args):
# outputs, loss, label should be fetched, so set persistable to be true
optimizer = train_model.optimizer()
optimizer.minimize(train_loss)
train_reader = train_model.reader()
train_metrics = train_model.metrics()
train_pyreader = train_model.pyreader()
if not args.no_memory_optimize:
......@@ -132,7 +141,6 @@ def train(train_model, valid_model, args):
valid_feeds = valid_model.feeds()
valid_outputs = valid_model.outputs()
valid_loss = valid_model.loss()
valid_reader = valid_model.reader()
valid_metrics = valid_model.metrics()
valid_pyreader = valid_model.pyreader()
......@@ -156,6 +164,18 @@ def train(train_model, valid_model, args):
share_vars_from=train_exe,
main_program=valid_prog)
# get reader
# train_reader = get_reader(train_config)
# valid_reader = get_reader(valid_config)
train_reader = train_model.reader()
valid_reader = valid_model.reader()
# get metrics
# train_metrics = get_metrics(train_config)
# valid_metrics = get_metrics(valid_config)
train_metrics = train_model.metrics()
train_metrics = train_model.metrics()
train_fetch_list = [train_loss.name] + [x.name for x in train_outputs
] + [train_feeds[-1].name]
valid_fetch_list = [valid_loss.name] + [x.name for x in valid_outputs
......@@ -186,11 +206,8 @@ def train(train_model, valid_model, args):
if __name__ == "__main__":
args = parse_args()
logger.info(args)
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
train_model = models.get_model(
args.model_name, args.config, mode='train', args=vars(args))
valid_model = models.get_model(
args.model_name, args.config, mode='valid', args=vars(args))
train(train_model, valid_model, args)
train(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册