diff --git a/fluid/PaddleCV/video/.gitignore b/fluid/PaddleCV/video/.gitignore index f601ca035a20ddccbe8ee725ec6787668b32572f..5bd3b1700740dd01bc966a0233e2e87cd5e5c9f6 100644 --- a/fluid/PaddleCV/video/.gitignore +++ b/fluid/PaddleCV/video/.gitignore @@ -1,6 +1,6 @@ data checkpoints output* -*.py +*.pyc *.swp *_result diff --git a/fluid/PaddleCV/video/config.py b/fluid/PaddleCV/video/config.py new file mode 100755 index 0000000000000000000000000000000000000000..a534536c35c9446ed7dd4139c831757654e02222 --- /dev/null +++ b/fluid/PaddleCV/video/config.py @@ -0,0 +1,58 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +try: + from configparser import ConfigParser +except: + from ConfigParser import ConfigParser + +from utils import AttrDict + +CONFIG_SECS = [ + 'train', + 'valid', + 'test', + 'infer', + ] + + +def parse_config(cfg_file): + parser = ConfigParser() + cfg = AttrDict() + parser.read(cfg_file) + for sec in parser.sections(): + sec_dict = AttrDict() + for k, v in parser.items(sec): + try: + v = eval(v) + except: + pass + setattr(sec_dict, k, v) + setattr(cfg, sec.upper(), sec_dict) + + return cfg + +def merge_configs(cfg, sec, args_dict): + assert sec in CONFIG_SECS, "invalid config section {}".format(sec) + sec_dict = getattr(cfg, sec.upper()) + for k, v in args_dict.items(): + if v is None: + continue + try: + if hasattr(sec_dict, k): + setattr(sec_dict, k, v) + except: + pass + return cfg + diff --git a/fluid/PaddleCV/video/infer.py b/fluid/PaddleCV/video/infer.py index 70539dc251fb0580434951b2881f1c989efc6727..373242189589ce6e0e42bb4ee48a945daa997d5f 100755 --- a/fluid/PaddleCV/video/infer.py +++ b/fluid/PaddleCV/video/infer.py @@ -24,6 +24,7 @@ except: import pickle import paddle.fluid as fluid +from config import * import models FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' @@ -76,7 +77,13 @@ def parse_args(): return args -def infer(infer_model, args): +def infer(args): + # parse config + config = parse_config(args.config) + infer_config = merge_configs(config, 'infer', vars(args)) + infer_model = models.get_model( + args.model_name, infer_config, mode='infer') + infer_model.build_input(use_pyreader=False) infer_model.build_model() infer_feeds = infer_model.feeds() @@ -146,6 +153,4 @@ if __name__ == "__main__": args = parse_args() logger.info(args) - infer_model = models.get_model( - args.model_name, args.config, mode='infer', args=vars(args)) - infer(infer_model, args) + infer(args) diff --git a/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py b/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py index ca3d8a4b46e22b7e3b420a47b4366dedeab300ed..5a125366709b7022b8fb36e4721a9cb14beb42d0 100755 --- a/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py +++ b/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py @@ -22,8 +22,8 @@ __all__ = ["AttentionLSTM"] class AttentionLSTM(ModelBase): - def __init__(self, name, cfg, mode='train', args=None): - super(AttentionLSTM, self).__init__(name, cfg, mode, args=args) + def __init__(self, name, cfg, mode='train'): + super(AttentionLSTM, self).__init__(name, cfg, mode) self.get_config() def get_config(self): diff --git a/fluid/PaddleCV/video/models/nextvlad/nextvlad.py b/fluid/PaddleCV/video/models/nextvlad/nextvlad.py index cdd2fedf5cabb028fb5fc4a745cb14c590df9a6b..b9dc851f29edb1d8022cec82991d9ed402c9e6d4 100755 --- a/fluid/PaddleCV/video/models/nextvlad/nextvlad.py +++ b/fluid/PaddleCV/video/models/nextvlad/nextvlad.py @@ -23,8 +23,8 @@ __all__ = ["NEXTVLAD"] class NEXTVLAD(ModelBase): - def __init__(self, name, cfg, mode='train', args=None): - super(NEXTVLAD, self).__init__(name, cfg, mode=mode, args=args) + def __init__(self, name, cfg, mode='train'): + super(NEXTVLAD, self).__init__(name, cfg, mode=mode) self.get_config() def get_config(self): diff --git a/fluid/PaddleCV/video/models/stnet/stnet.py b/fluid/PaddleCV/video/models/stnet/stnet.py index daf3ce5b28b6537b1051c1ef6c2671dcf2604be1..06637bb37fc3bfe328b0ac9872e0d4d1127770ad 100644 --- a/fluid/PaddleCV/video/models/stnet/stnet.py +++ b/fluid/PaddleCV/video/models/stnet/stnet.py @@ -21,8 +21,8 @@ __all__ = ["STNET"] class STNET(ModelBase): - def __init__(self, name, cfg, mode='train', args=None): - super(STNET, self).__init__(name, cfg, mode=mode, args=args) + def __init__(self, name, cfg, mode='train'): + super(STNET, self).__init__(name, cfg, mode=mode) self.get_config() def get_config(self): diff --git a/fluid/PaddleCV/video/models/tsn/tsn.py b/fluid/PaddleCV/video/models/tsn/tsn.py index 8bda42e5975cda64c2403e4c0fafd7efbf5439f1..7ad9818e6b4dee3dd16cf5e163bdb11db80e5912 100644 --- a/fluid/PaddleCV/video/models/tsn/tsn.py +++ b/fluid/PaddleCV/video/models/tsn/tsn.py @@ -22,8 +22,8 @@ __all__ = ["TSN"] class TSN(ModelBase): - def __init__(self, name, cfg, mode='train', args=None): - super(TSN, self).__init__(name, cfg, mode=mode, args=args) + def __init__(self, name, cfg, mode='train'): + super(TSN, self).__init__(name, cfg, mode=mode) self.get_config() def get_config(self): diff --git a/fluid/PaddleCV/video/test.py b/fluid/PaddleCV/video/test.py index 3111fba7a62b09b96b136a738b30eb397d880123..96e19ae6dcfee2edbaf25802d6d24057dabd5d70 100755 --- a/fluid/PaddleCV/video/test.py +++ b/fluid/PaddleCV/video/test.py @@ -19,6 +19,8 @@ import logging import argparse import numpy as np import paddle.fluid as fluid + +from config import * import models logging.root.handlers = [] @@ -60,13 +62,18 @@ def parse_args(): return args -def test(test_model, args): +def test(args): + # parse config + config = parse_config(args.config) + test_config = merge_configs(config, 'test', vars(args)) + + # build model + test_model = models.get_model( + args.model_name, test_config, mode='test') test_model.build_input(use_pyreader=False) test_model.build_model() test_feeds = test_model.feeds() test_outputs = test_model.outputs() - test_reader = test_model.reader() - test_metrics = test_model.metrics() loss = test_model.loss() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() @@ -82,6 +89,10 @@ def test(test_model, args): fluid.io.load_vars(exe, weights, predicate=if_exist) + # get reader and metrics + test_reader = test_model.reader() + test_metrics = test_model.metrics() + test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds) fetch_list = [loss.name] + [x.name for x in test_outputs] + [test_feeds[-1].name] @@ -111,7 +122,6 @@ def test(test_model, args): if __name__ == "__main__": args = parse_args() + logger.info(args) - test_model = models.get_model( - args.model_name, args.config, mode='test', args=vars(args)) - test(test_model, args) + test(args) diff --git a/fluid/PaddleCV/video/utils.py b/fluid/PaddleCV/video/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..3b07d606c60b9834429fef94d43c0a5619cd1db5 --- /dev/null +++ b/fluid/PaddleCV/video/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +__all__ = ['AttrDict'] + +class AttrDict(dict): + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + if key in self.__dict__: + self.__dict__[key] = value + else: + self[key] = value