提交 216b26cc 编写于 作者: S SunGaofeng

Merge branch 'video_classification' of https://github.com/SunGaofeng/models...

Merge branch 'video_classification' of https://github.com/SunGaofeng/models into video_classification
...@@ -34,14 +34,16 @@ batch_size = 128 ...@@ -34,14 +34,16 @@ batch_size = 128
filelist = "./dataset/kinetics/val.list" filelist = "./dataset/kinetics/val.list"
[TEST] [TEST]
seg_num = 25
short_size = 256 short_size = 256
target_size = 256 target_size = 256
num_reader_threads = 12 num_reader_threads = 12
buf_size = 1024 buf_size = 1024
batch_size = 16 batch_size = 4
filelist = "./dataset/kinetics/test.list" filelist = "./dataset/kinetics/test.list"
[INFER] [INFER]
seg_num = 25
short_size = 256 short_size = 256
target_size = 256 target_size = 256
num_reader_threads = 12 num_reader_threads = 12
......
...@@ -33,11 +33,12 @@ batch_size = 256 ...@@ -33,11 +33,12 @@ batch_size = 256
filelist = "./dataset/kinetics/val.list" filelist = "./dataset/kinetics/val.list"
[TEST] [TEST]
seg_num = 7
short_size = 256 short_size = 256
target_size = 224 target_size = 224
num_reader_threads = 12 num_reader_threads = 12
buf_size = 1024 buf_size = 1024
batch_size = 32 batch_size = 16
filelist = "./dataset/kinetics/test.list" filelist = "./dataset/kinetics/test.list"
[INFER] [INFER]
......
...@@ -54,16 +54,17 @@ class KineticsReader(DataReader): ...@@ -54,16 +54,17 @@ class KineticsReader(DataReader):
""" """
def __init__(self, name, mode, cfg): def __init__(self, name, mode, cfg):
self.name = name super(KineticsReader, self).__init__(name, mode, cfg)
self.mode = mode
self.format = cfg.MODEL.format self.format = cfg.MODEL.format
self.num_classes = cfg.MODEL.num_classes self.num_classes = self.get_config_from_sec('model', 'num_classes')
self.seg_num = cfg.MODEL.seg_num self.seg_num = self.get_config_from_sec('model', 'seg_num')
self.seglen = cfg.MODEL.seglen self.seglen = self.get_config_from_sec('model', 'seglen')
self.short_size = cfg[mode.upper()]['short_size']
self.target_size = cfg[mode.upper()]['target_size'] self.seg_num = self.get_config_from_sec(mode, 'seg_num', self.seg_num)
self.num_reader_threads = cfg[mode.upper()]['num_reader_threads'] self.short_size = self.get_config_from_sec(mode, 'short_size')
self.buf_size = cfg[mode.upper()]['buf_size'] self.target_size = self.get_config_from_sec(mode, 'target_size')
self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads')
self.buf_size = self.get_config_from_sec(mode, 'buf_size')
self.img_mean = np.array(cfg.MODEL.image_mean).reshape( self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
[3, 1, 1]).astype(np.float32) [3, 1, 1]).astype(np.float32)
......
...@@ -38,13 +38,20 @@ class DataReader(object): ...@@ -38,13 +38,20 @@ class DataReader(object):
"""data reader for video input""" """data reader for video input"""
def __init__(self, model_name, mode, cfg): def __init__(self, model_name, mode, cfg):
"""Not implemented""" self.name = model_name
pass self.mode = mode
self.cfg = cfg
def create_reader(self): def create_reader(self):
"""Not implemented""" """Not implemented"""
pass pass
def get_config_from_sec(self, sec, item, default=None):
if sec.upper() not in self.cfg:
return default
return self.cfg[sec.upper()].get(item, default)
class ReaderZoo(object): class ReaderZoo(object):
def __init__(self): def __init__(self):
......
...@@ -65,17 +65,9 @@ class ModelBase(object): ...@@ -65,17 +65,9 @@ class ModelBase(object):
self.name = name self.name = name
self.is_training = (mode == 'train') self.is_training = (mode == 'train')
self.mode = mode self.mode = mode
self.cfg = cfg
self.py_reader = None self.py_reader = None
# parse config
# assert os.path.exists(cfg), \
# "Config file {} not exists".format(cfg)
# self._config = ModelConfig(cfg)
# self._config.parse()
# if args and isinstance(args, dict):
# self._config.merge_configs(mode, args)
# self.cfg = self._config.get_configs()
self.cfg = cfg
def build_model(self): def build_model(self):
"build model struct" "build model struct"
......
...@@ -46,6 +46,7 @@ class STNET(ModelBase): ...@@ -46,6 +46,7 @@ class STNET(ModelBase):
'l2_weight_decay') 'l2_weight_decay')
self.momentum = self.get_config_from_sec('train', 'momentum') self.momentum = self.get_config_from_sec('train', 'momentum')
self.seg_num = self.get_config_from_sec(self.mode, 'seg_num', self.seg_num)
self.target_size = self.get_config_from_sec(self.mode, 'target_size') self.target_size = self.get_config_from_sec(self.mode, 'target_size')
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size') self.batch_size = self.get_config_from_sec(self.mode, 'batch_size')
......
...@@ -47,6 +47,7 @@ class TSN(ModelBase): ...@@ -47,6 +47,7 @@ class TSN(ModelBase):
'l2_weight_decay') 'l2_weight_decay')
self.momentum = self.get_config_from_sec('train', 'momentum') self.momentum = self.get_config_from_sec('train', 'momentum')
self.seg_num = self.get_config_from_sec(self.mode, 'seg_num', self.seg_num)
self.target_size = self.get_config_from_sec(self.mode, 'target_size') self.target_size = self.get_config_from_sec(self.mode, 'target_size')
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size') self.batch_size = self.get_config_from_sec(self.mode, 'batch_size')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册