From 9ba1db46c4fcae793b1f72d3829a81ce34524c11 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Wed, 30 Jan 2019 13:20:50 +0000 Subject: [PATCH] move config out in test/infer.py --- fluid/PaddleCV/video/.gitignore | 2 +- fluid/PaddleCV/video/config.py | 58 +++++++++++++++++++ fluid/PaddleCV/video/infer.py | 13 +++-- .../models/attention_lstm/attention_lstm.py | 4 +- .../video/models/nextvlad/nextvlad.py | 4 +- fluid/PaddleCV/video/models/stnet/stnet.py | 4 +- fluid/PaddleCV/video/models/tsn/tsn.py | 4 +- fluid/PaddleCV/video/test.py | 22 +++++-- fluid/PaddleCV/video/utils.py | 25 ++++++++ 9 files changed, 117 insertions(+), 19 deletions(-) create mode 100755 fluid/PaddleCV/video/config.py create mode 100755 fluid/PaddleCV/video/utils.py diff --git a/fluid/PaddleCV/video/.gitignore b/fluid/PaddleCV/video/.gitignore index f601ca03..5bd3b170 100644 --- a/fluid/PaddleCV/video/.gitignore +++ b/fluid/PaddleCV/video/.gitignore @@ -1,6 +1,6 @@ data checkpoints output* -*.py +*.pyc *.swp *_result diff --git a/fluid/PaddleCV/video/config.py b/fluid/PaddleCV/video/config.py new file mode 100755 index 00000000..a534536c --- /dev/null +++ b/fluid/PaddleCV/video/config.py @@ -0,0 +1,58 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +try: + from configparser import ConfigParser +except: + from ConfigParser import ConfigParser + +from utils import AttrDict + +CONFIG_SECS = [ + 'train', + 'valid', + 'test', + 'infer', + ] + + +def parse_config(cfg_file): + parser = ConfigParser() + cfg = AttrDict() + parser.read(cfg_file) + for sec in parser.sections(): + sec_dict = AttrDict() + for k, v in parser.items(sec): + try: + v = eval(v) + except: + pass + setattr(sec_dict, k, v) + setattr(cfg, sec.upper(), sec_dict) + + return cfg + +def merge_configs(cfg, sec, args_dict): + assert sec in CONFIG_SECS, "invalid config section {}".format(sec) + sec_dict = getattr(cfg, sec.upper()) + for k, v in args_dict.items(): + if v is None: + continue + try: + if hasattr(sec_dict, k): + setattr(sec_dict, k, v) + except: + pass + return cfg + diff --git a/fluid/PaddleCV/video/infer.py b/fluid/PaddleCV/video/infer.py index 70539dc2..37324218 100755 --- a/fluid/PaddleCV/video/infer.py +++ b/fluid/PaddleCV/video/infer.py @@ -24,6 +24,7 @@ except: import pickle import paddle.fluid as fluid +from config import * import models FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' @@ -76,7 +77,13 @@ def parse_args(): return args -def infer(infer_model, args): +def infer(args): + # parse config + config = parse_config(args.config) + infer_config = merge_configs(config, 'infer', vars(args)) + infer_model = models.get_model( + args.model_name, infer_config, mode='infer') + infer_model.build_input(use_pyreader=False) infer_model.build_model() infer_feeds = infer_model.feeds() @@ -146,6 +153,4 @@ if __name__ == "__main__": args = parse_args() logger.info(args) - infer_model = models.get_model( - args.model_name, args.config, mode='infer', args=vars(args)) - infer(infer_model, args) + infer(args) diff --git a/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py b/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py index ca3d8a4b..5a125366 100755 --- a/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py +++ b/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py @@ -22,8 +22,8 @@ __all__ = ["AttentionLSTM"] class AttentionLSTM(ModelBase): - def __init__(self, name, cfg, mode='train', args=None): - super(AttentionLSTM, self).__init__(name, cfg, mode, args=args) + def __init__(self, name, cfg, mode='train'): + super(AttentionLSTM, self).__init__(name, cfg, mode) self.get_config() def get_config(self): diff --git a/fluid/PaddleCV/video/models/nextvlad/nextvlad.py b/fluid/PaddleCV/video/models/nextvlad/nextvlad.py index cdd2fedf..b9dc851f 100755 --- a/fluid/PaddleCV/video/models/nextvlad/nextvlad.py +++ b/fluid/PaddleCV/video/models/nextvlad/nextvlad.py @@ -23,8 +23,8 @@ __all__ = ["NEXTVLAD"] class NEXTVLAD(ModelBase): - def __init__(self, name, cfg, mode='train', args=None): - super(NEXTVLAD, self).__init__(name, cfg, mode=mode, args=args) + def __init__(self, name, cfg, mode='train'): + super(NEXTVLAD, self).__init__(name, cfg, mode=mode) self.get_config() def get_config(self): diff --git a/fluid/PaddleCV/video/models/stnet/stnet.py b/fluid/PaddleCV/video/models/stnet/stnet.py index daf3ce5b..06637bb3 100644 --- a/fluid/PaddleCV/video/models/stnet/stnet.py +++ b/fluid/PaddleCV/video/models/stnet/stnet.py @@ -21,8 +21,8 @@ __all__ = ["STNET"] class STNET(ModelBase): - def __init__(self, name, cfg, mode='train', args=None): - super(STNET, self).__init__(name, cfg, mode=mode, args=args) + def __init__(self, name, cfg, mode='train'): + super(STNET, self).__init__(name, cfg, mode=mode) self.get_config() def get_config(self): diff --git a/fluid/PaddleCV/video/models/tsn/tsn.py b/fluid/PaddleCV/video/models/tsn/tsn.py index 8bda42e5..7ad9818e 100644 --- a/fluid/PaddleCV/video/models/tsn/tsn.py +++ b/fluid/PaddleCV/video/models/tsn/tsn.py @@ -22,8 +22,8 @@ __all__ = ["TSN"] class TSN(ModelBase): - def __init__(self, name, cfg, mode='train', args=None): - super(TSN, self).__init__(name, cfg, mode=mode, args=args) + def __init__(self, name, cfg, mode='train'): + super(TSN, self).__init__(name, cfg, mode=mode) self.get_config() def get_config(self): diff --git a/fluid/PaddleCV/video/test.py b/fluid/PaddleCV/video/test.py index 3111fba7..96e19ae6 100755 --- a/fluid/PaddleCV/video/test.py +++ b/fluid/PaddleCV/video/test.py @@ -19,6 +19,8 @@ import logging import argparse import numpy as np import paddle.fluid as fluid + +from config import * import models logging.root.handlers = [] @@ -60,13 +62,18 @@ def parse_args(): return args -def test(test_model, args): +def test(args): + # parse config + config = parse_config(args.config) + test_config = merge_configs(config, 'test', vars(args)) + + # build model + test_model = models.get_model( + args.model_name, test_config, mode='test') test_model.build_input(use_pyreader=False) test_model.build_model() test_feeds = test_model.feeds() test_outputs = test_model.outputs() - test_reader = test_model.reader() - test_metrics = test_model.metrics() loss = test_model.loss() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() @@ -82,6 +89,10 @@ def test(test_model, args): fluid.io.load_vars(exe, weights, predicate=if_exist) + # get reader and metrics + test_reader = test_model.reader() + test_metrics = test_model.metrics() + test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds) fetch_list = [loss.name] + [x.name for x in test_outputs] + [test_feeds[-1].name] @@ -111,7 +122,6 @@ def test(test_model, args): if __name__ == "__main__": args = parse_args() + logger.info(args) - test_model = models.get_model( - args.model_name, args.config, mode='test', args=vars(args)) - test(test_model, args) + test(args) diff --git a/fluid/PaddleCV/video/utils.py b/fluid/PaddleCV/video/utils.py new file mode 100755 index 00000000..3b07d606 --- /dev/null +++ b/fluid/PaddleCV/video/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +__all__ = ['AttrDict'] + +class AttrDict(dict): + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + if key in self.__dict__: + self.__dict__[key] = value + else: + self[key] = value -- GitLab