diff --git a/PaddleCV/video/configs/stnet.txt b/PaddleCV/video/configs/stnet.txt index ff3e4ddd25202b0d75c4fb53425dfe41a8f4222a..7be17834d18ec157aec5738b11b5d68871430892 100644 --- a/PaddleCV/video/configs/stnet.txt +++ b/PaddleCV/video/configs/stnet.txt @@ -34,14 +34,16 @@ batch_size = 128 filelist = "./dataset/kinetics/val.list" [TEST] +seg_num = 25 short_size = 256 target_size = 256 num_reader_threads = 12 buf_size = 1024 -batch_size = 16 +batch_size = 4 filelist = "./dataset/kinetics/test.list" [INFER] +seg_num = 25 short_size = 256 target_size = 256 num_reader_threads = 12 diff --git a/PaddleCV/video/configs/tsn.txt b/PaddleCV/video/configs/tsn.txt index bca5ff349a9792bb07b18c815d7f994419cb82f5..d19353228f7c779092665552d8ae945c666d4882 100644 --- a/PaddleCV/video/configs/tsn.txt +++ b/PaddleCV/video/configs/tsn.txt @@ -33,11 +33,12 @@ batch_size = 256 filelist = "./dataset/kinetics/val.list" [TEST] +seg_num = 7 short_size = 256 target_size = 224 num_reader_threads = 12 buf_size = 1024 -batch_size = 32 +batch_size = 16 filelist = "./dataset/kinetics/test.list" [INFER] diff --git a/PaddleCV/video/datareader/kinetics_reader.py b/PaddleCV/video/datareader/kinetics_reader.py index ed1de044fdaef39f5ccda6c4684733c8dcd8c8ef..236ff8205c15b0d51a9251e619473509cd31b443 100644 --- a/PaddleCV/video/datareader/kinetics_reader.py +++ b/PaddleCV/video/datareader/kinetics_reader.py @@ -54,16 +54,17 @@ class KineticsReader(DataReader): """ def __init__(self, name, mode, cfg): - self.name = name - self.mode = mode + super(KineticsReader, self).__init__(name, mode, cfg) self.format = cfg.MODEL.format - self.num_classes = cfg.MODEL.num_classes - self.seg_num = cfg.MODEL.seg_num - self.seglen = cfg.MODEL.seglen - self.short_size = cfg[mode.upper()]['short_size'] - self.target_size = cfg[mode.upper()]['target_size'] - self.num_reader_threads = cfg[mode.upper()]['num_reader_threads'] - self.buf_size = cfg[mode.upper()]['buf_size'] + self.num_classes = self.get_config_from_sec('model', 'num_classes') + self.seg_num = self.get_config_from_sec('model', 'seg_num') + self.seglen = self.get_config_from_sec('model', 'seglen') + + self.seg_num = self.get_config_from_sec(mode, 'seg_num', self.seg_num) + self.short_size = self.get_config_from_sec(mode, 'short_size') + self.target_size = self.get_config_from_sec(mode, 'target_size') + self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads') + self.buf_size = self.get_config_from_sec(mode, 'buf_size') self.img_mean = np.array(cfg.MODEL.image_mean).reshape( [3, 1, 1]).astype(np.float32) diff --git a/PaddleCV/video/datareader/reader_utils.py b/PaddleCV/video/datareader/reader_utils.py index 4c8b436a74335c9b3b3361947123b1a3bb3d43dd..b21bed3df012f0f64cc4a4af9296c3d569925bc9 100644 --- a/PaddleCV/video/datareader/reader_utils.py +++ b/PaddleCV/video/datareader/reader_utils.py @@ -38,13 +38,20 @@ class DataReader(object): """data reader for video input""" def __init__(self, model_name, mode, cfg): - """Not implemented""" - pass + self.name = model_name + self.mode = mode + self.cfg = cfg def create_reader(self): """Not implemented""" pass + def get_config_from_sec(self, sec, item, default=None): + if sec.upper() not in self.cfg: + return default + return self.cfg[sec.upper()].get(item, default) + + class ReaderZoo(object): def __init__(self): diff --git a/PaddleCV/video/models/model.py b/PaddleCV/video/models/model.py index 5bc354be28aaae7ce54db82e79182e9ae78ed74f..bf5947f44f05c96d47b912edfbdc9b4c28f6321b 100644 --- a/PaddleCV/video/models/model.py +++ b/PaddleCV/video/models/model.py @@ -65,17 +65,9 @@ class ModelBase(object): self.name = name self.is_training = (mode == 'train') self.mode = mode + self.cfg = cfg self.py_reader = None - # parse config - # assert os.path.exists(cfg), \ - # "Config file {} not exists".format(cfg) - # self._config = ModelConfig(cfg) - # self._config.parse() - # if args and isinstance(args, dict): - # self._config.merge_configs(mode, args) - # self.cfg = self._config.get_configs() - self.cfg = cfg def build_model(self): "build model struct" diff --git a/PaddleCV/video/models/stnet/stnet.py b/PaddleCV/video/models/stnet/stnet.py index c408aa0894265b2e7b8eccb8b49cae4c799018cc..f82bba50a5f0fd8b565d0906946c9322a45f9002 100644 --- a/PaddleCV/video/models/stnet/stnet.py +++ b/PaddleCV/video/models/stnet/stnet.py @@ -46,6 +46,7 @@ class STNET(ModelBase): 'l2_weight_decay') self.momentum = self.get_config_from_sec('train', 'momentum') + self.seg_num = self.get_config_from_sec(self.mode, 'seg_num', self.seg_num) self.target_size = self.get_config_from_sec(self.mode, 'target_size') self.batch_size = self.get_config_from_sec(self.mode, 'batch_size') diff --git a/PaddleCV/video/models/tsn/tsn.py b/PaddleCV/video/models/tsn/tsn.py index 82fdb3279376a15e796df09818c343db24f048b4..5c18d0d7d34e53e102a72e13b2d1061d3b51b6c9 100644 --- a/PaddleCV/video/models/tsn/tsn.py +++ b/PaddleCV/video/models/tsn/tsn.py @@ -47,6 +47,7 @@ class TSN(ModelBase): 'l2_weight_decay') self.momentum = self.get_config_from_sec('train', 'momentum') + self.seg_num = self.get_config_from_sec(self.mode, 'seg_num', self.seg_num) self.target_size = self.get_config_from_sec(self.mode, 'target_size') self.batch_size = self.get_config_from_sec(self.mode, 'batch_size')