提交 216b26cc 编写于 作者: S SunGaofeng

Merge branch 'video_classification' of https://github.com/SunGaofeng/models...

Merge branch 'video_classification' of https://github.com/SunGaofeng/models into video_classification
......@@ -34,14 +34,16 @@ batch_size = 128
filelist = "./dataset/kinetics/val.list"
[TEST]
seg_num = 25
short_size = 256
target_size = 256
num_reader_threads = 12
buf_size = 1024
batch_size = 16
batch_size = 4
filelist = "./dataset/kinetics/test.list"
[INFER]
seg_num = 25
short_size = 256
target_size = 256
num_reader_threads = 12
......
......@@ -33,11 +33,12 @@ batch_size = 256
filelist = "./dataset/kinetics/val.list"
[TEST]
seg_num = 7
short_size = 256
target_size = 224
num_reader_threads = 12
buf_size = 1024
batch_size = 32
batch_size = 16
filelist = "./dataset/kinetics/test.list"
[INFER]
......
......@@ -54,16 +54,17 @@ class KineticsReader(DataReader):
"""
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
super(KineticsReader, self).__init__(name, mode, cfg)
self.format = cfg.MODEL.format
self.num_classes = cfg.MODEL.num_classes
self.seg_num = cfg.MODEL.seg_num
self.seglen = cfg.MODEL.seglen
self.short_size = cfg[mode.upper()]['short_size']
self.target_size = cfg[mode.upper()]['target_size']
self.num_reader_threads = cfg[mode.upper()]['num_reader_threads']
self.buf_size = cfg[mode.upper()]['buf_size']
self.num_classes = self.get_config_from_sec('model', 'num_classes')
self.seg_num = self.get_config_from_sec('model', 'seg_num')
self.seglen = self.get_config_from_sec('model', 'seglen')
self.seg_num = self.get_config_from_sec(mode, 'seg_num', self.seg_num)
self.short_size = self.get_config_from_sec(mode, 'short_size')
self.target_size = self.get_config_from_sec(mode, 'target_size')
self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads')
self.buf_size = self.get_config_from_sec(mode, 'buf_size')
self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
[3, 1, 1]).astype(np.float32)
......
......@@ -38,13 +38,20 @@ class DataReader(object):
"""data reader for video input"""
def __init__(self, model_name, mode, cfg):
"""Not implemented"""
pass
self.name = model_name
self.mode = mode
self.cfg = cfg
def create_reader(self):
"""Not implemented"""
pass
def get_config_from_sec(self, sec, item, default=None):
if sec.upper() not in self.cfg:
return default
return self.cfg[sec.upper()].get(item, default)
class ReaderZoo(object):
def __init__(self):
......
......@@ -65,17 +65,9 @@ class ModelBase(object):
self.name = name
self.is_training = (mode == 'train')
self.mode = mode
self.cfg = cfg
self.py_reader = None
# parse config
# assert os.path.exists(cfg), \
# "Config file {} not exists".format(cfg)
# self._config = ModelConfig(cfg)
# self._config.parse()
# if args and isinstance(args, dict):
# self._config.merge_configs(mode, args)
# self.cfg = self._config.get_configs()
self.cfg = cfg
def build_model(self):
"build model struct"
......
......@@ -46,6 +46,7 @@ class STNET(ModelBase):
'l2_weight_decay')
self.momentum = self.get_config_from_sec('train', 'momentum')
self.seg_num = self.get_config_from_sec(self.mode, 'seg_num', self.seg_num)
self.target_size = self.get_config_from_sec(self.mode, 'target_size')
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size')
......
......@@ -47,6 +47,7 @@ class TSN(ModelBase):
'l2_weight_decay')
self.momentum = self.get_config_from_sec('train', 'momentum')
self.seg_num = self.get_config_from_sec(self.mode, 'seg_num', self.seg_num)
self.target_size = self.get_config_from_sec(self.mode, 'target_size')
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册