提交 9ba1db46 编写于 作者: D dengkaipeng

move config out in test/infer.py

上级 f2b1fcd8
data data
checkpoints checkpoints
output* output*
*.py *.pyc
*.swp *.swp
*_result *_result
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
try:
from configparser import ConfigParser
except:
from ConfigParser import ConfigParser
from utils import AttrDict
CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]
def parse_config(cfg_file):
parser = ConfigParser()
cfg = AttrDict()
parser.read(cfg_file)
for sec in parser.sections():
sec_dict = AttrDict()
for k, v in parser.items(sec):
try:
v = eval(v)
except:
pass
setattr(sec_dict, k, v)
setattr(cfg, sec.upper(), sec_dict)
return cfg
def merge_configs(cfg, sec, args_dict):
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
sec_dict = getattr(cfg, sec.upper())
for k, v in args_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except:
pass
return cfg
...@@ -24,6 +24,7 @@ except: ...@@ -24,6 +24,7 @@ except:
import pickle import pickle
import paddle.fluid as fluid import paddle.fluid as fluid
from config import *
import models import models
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
...@@ -76,7 +77,13 @@ def parse_args(): ...@@ -76,7 +77,13 @@ def parse_args():
return args return args
def infer(infer_model, args): def infer(args):
# parse config
config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args))
infer_model = models.get_model(
args.model_name, infer_config, mode='infer')
infer_model.build_input(use_pyreader=False) infer_model.build_input(use_pyreader=False)
infer_model.build_model() infer_model.build_model()
infer_feeds = infer_model.feeds() infer_feeds = infer_model.feeds()
...@@ -146,6 +153,4 @@ if __name__ == "__main__": ...@@ -146,6 +153,4 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
logger.info(args) logger.info(args)
infer_model = models.get_model( infer(args)
args.model_name, args.config, mode='infer', args=vars(args))
infer(infer_model, args)
...@@ -22,8 +22,8 @@ __all__ = ["AttentionLSTM"] ...@@ -22,8 +22,8 @@ __all__ = ["AttentionLSTM"]
class AttentionLSTM(ModelBase): class AttentionLSTM(ModelBase):
def __init__(self, name, cfg, mode='train', args=None): def __init__(self, name, cfg, mode='train'):
super(AttentionLSTM, self).__init__(name, cfg, mode, args=args) super(AttentionLSTM, self).__init__(name, cfg, mode)
self.get_config() self.get_config()
def get_config(self): def get_config(self):
......
...@@ -23,8 +23,8 @@ __all__ = ["NEXTVLAD"] ...@@ -23,8 +23,8 @@ __all__ = ["NEXTVLAD"]
class NEXTVLAD(ModelBase): class NEXTVLAD(ModelBase):
def __init__(self, name, cfg, mode='train', args=None): def __init__(self, name, cfg, mode='train'):
super(NEXTVLAD, self).__init__(name, cfg, mode=mode, args=args) super(NEXTVLAD, self).__init__(name, cfg, mode=mode)
self.get_config() self.get_config()
def get_config(self): def get_config(self):
......
...@@ -21,8 +21,8 @@ __all__ = ["STNET"] ...@@ -21,8 +21,8 @@ __all__ = ["STNET"]
class STNET(ModelBase): class STNET(ModelBase):
def __init__(self, name, cfg, mode='train', args=None): def __init__(self, name, cfg, mode='train'):
super(STNET, self).__init__(name, cfg, mode=mode, args=args) super(STNET, self).__init__(name, cfg, mode=mode)
self.get_config() self.get_config()
def get_config(self): def get_config(self):
......
...@@ -22,8 +22,8 @@ __all__ = ["TSN"] ...@@ -22,8 +22,8 @@ __all__ = ["TSN"]
class TSN(ModelBase): class TSN(ModelBase):
def __init__(self, name, cfg, mode='train', args=None): def __init__(self, name, cfg, mode='train'):
super(TSN, self).__init__(name, cfg, mode=mode, args=args) super(TSN, self).__init__(name, cfg, mode=mode)
self.get_config() self.get_config()
def get_config(self): def get_config(self):
......
...@@ -19,6 +19,8 @@ import logging ...@@ -19,6 +19,8 @@ import logging
import argparse import argparse
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from config import *
import models import models
logging.root.handlers = [] logging.root.handlers = []
...@@ -60,13 +62,18 @@ def parse_args(): ...@@ -60,13 +62,18 @@ def parse_args():
return args return args
def test(test_model, args): def test(args):
# parse config
config = parse_config(args.config)
test_config = merge_configs(config, 'test', vars(args))
# build model
test_model = models.get_model(
args.model_name, test_config, mode='test')
test_model.build_input(use_pyreader=False) test_model.build_input(use_pyreader=False)
test_model.build_model() test_model.build_model()
test_feeds = test_model.feeds() test_feeds = test_model.feeds()
test_outputs = test_model.outputs() test_outputs = test_model.outputs()
test_reader = test_model.reader()
test_metrics = test_model.metrics()
loss = test_model.loss() loss = test_model.loss()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
...@@ -82,6 +89,10 @@ def test(test_model, args): ...@@ -82,6 +89,10 @@ def test(test_model, args):
fluid.io.load_vars(exe, weights, predicate=if_exist) fluid.io.load_vars(exe, weights, predicate=if_exist)
# get reader and metrics
test_reader = test_model.reader()
test_metrics = test_model.metrics()
test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds) test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds)
fetch_list = [loss.name] + [x.name fetch_list = [loss.name] + [x.name
for x in test_outputs] + [test_feeds[-1].name] for x in test_outputs] + [test_feeds[-1].name]
...@@ -111,7 +122,6 @@ def test(test_model, args): ...@@ -111,7 +122,6 @@ def test(test_model, args):
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
logger.info(args)
test_model = models.get_model( test(args)
args.model_name, args.config, mode='test', args=vars(args))
test(test_model, args)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
__all__ = ['AttrDict']
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册