提交 c632a4c2 编写于 作者: L LutaoChu 提交者: wuzewu

modify collect.py (#32)

* del reset_dataset
上级 ac5cba98
...@@ -507,7 +507,7 @@ def check_test_dataset(): ...@@ -507,7 +507,7 @@ def check_test_dataset():
def main(args): def main(args):
if args.cfg_file is not None: if args.cfg_file is not None:
cfg.update_from_file(args.cfg_file) cfg.update_from_file(args.cfg_file)
cfg.check_and_infer(reset_dataset=True) cfg.check_and_infer()
logger.info(pprint.pformat(cfg)) logger.info(pprint.pformat(cfg))
init_global_variable() init_global_variable()
......
...@@ -449,7 +449,7 @@ def main(args): ...@@ -449,7 +449,7 @@ def main(args):
cfg.update_from_file(args.cfg_file) cfg.update_from_file(args.cfg_file)
if args.opts is not None: if args.opts is not None:
cfg.update_from_list(args.opts) cfg.update_from_list(args.opts)
cfg.check_and_infer(reset_dataset=True) cfg.check_and_infer()
print(pprint.pformat(cfg)) print(pprint.pformat(cfg))
train(cfg) train(cfg)
......
...@@ -88,7 +88,7 @@ class SegConfig(dict): ...@@ -88,7 +88,7 @@ class SegConfig(dict):
except KeyError: except KeyError:
raise KeyError('Non-existent config key: {}'.format(key)) raise KeyError('Non-existent config key: {}'.format(key))
def check_and_infer(self, reset_dataset=False): def check_and_infer(self):
if self.DATASET.IMAGE_TYPE in ['rgb', 'gray']: if self.DATASET.IMAGE_TYPE in ['rgb', 'gray']:
self.DATASET.DATA_DIM = 3 self.DATASET.DATA_DIM = 3
elif self.DATASET.IMAGE_TYPE in ['rgba']: elif self.DATASET.IMAGE_TYPE in ['rgba']:
...@@ -110,17 +110,13 @@ class SegConfig(dict): ...@@ -110,17 +110,13 @@ class SegConfig(dict):
'EVAL_CROP_SIZE is empty! Please set a pair of values in format (width, height)' 'EVAL_CROP_SIZE is empty! Please set a pair of values in format (width, height)'
) )
if reset_dataset: # Ensure file list is use UTF-8 encoding
# Ensure file list is use UTF-8 encoding train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', 'utf-8').readlines()
train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', 'utf-8').readlines()
'utf-8').readlines() test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', 'utf-8').readlines()
val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets)
'utf-8').readlines() self.DATASET.VAL_TOTAL_IMAGES = len(val_sets)
test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', self.DATASET.TEST_TOTAL_IMAGES = len(test_sets)
'utf-8').readlines()
self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets)
self.DATASET.VAL_TOTAL_IMAGES = len(val_sets)
self.DATASET.TEST_TOTAL_IMAGES = len(test_sets)
if self.MODEL.MODEL_NAME == 'icnet' and \ if self.MODEL.MODEL_NAME == 'icnet' and \
len(self.MODEL.MULTI_LOSS_WEIGHT) != 3: len(self.MODEL.MULTI_LOSS_WEIGHT) != 3:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册