From c632a4c24ace568b1a9154bb1228ea9a45b036bc Mon Sep 17 00:00:00 2001 From: LutaoChu <30695251+LutaoChu@users.noreply.github.com> Date: Tue, 17 Sep 2019 17:21:17 +0800 Subject: [PATCH] modify collect.py (#32) * del reset_dataset --- pdseg/check.py | 2 +- pdseg/train.py | 2 +- pdseg/utils/collect.py | 20 ++++++++------------ 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/pdseg/check.py b/pdseg/check.py index f1b70a6a..1affbdd2 100644 --- a/pdseg/check.py +++ b/pdseg/check.py @@ -507,7 +507,7 @@ def check_test_dataset(): def main(args): if args.cfg_file is not None: cfg.update_from_file(args.cfg_file) - cfg.check_and_infer(reset_dataset=True) + cfg.check_and_infer() logger.info(pprint.pformat(cfg)) init_global_variable() diff --git a/pdseg/train.py b/pdseg/train.py index 22a430f7..9db3648c 100644 --- a/pdseg/train.py +++ b/pdseg/train.py @@ -449,7 +449,7 @@ def main(args): cfg.update_from_file(args.cfg_file) if args.opts is not None: cfg.update_from_list(args.opts) - cfg.check_and_infer(reset_dataset=True) + cfg.check_and_infer() print(pprint.pformat(cfg)) train(cfg) diff --git a/pdseg/utils/collect.py b/pdseg/utils/collect.py index 6b8f2f4e..02618381 100644 --- a/pdseg/utils/collect.py +++ b/pdseg/utils/collect.py @@ -88,7 +88,7 @@ class SegConfig(dict): except KeyError: raise KeyError('Non-existent config key: {}'.format(key)) - def check_and_infer(self, reset_dataset=False): + def check_and_infer(self): if self.DATASET.IMAGE_TYPE in ['rgb', 'gray']: self.DATASET.DATA_DIM = 3 elif self.DATASET.IMAGE_TYPE in ['rgba']: @@ -110,17 +110,13 @@ class SegConfig(dict): 'EVAL_CROP_SIZE is empty! Please set a pair of values in format (width, height)' ) - if reset_dataset: - # Ensure file list is use UTF-8 encoding - train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', - 'utf-8').readlines() - val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', - 'utf-8').readlines() - test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', - 'utf-8').readlines() - self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets) - self.DATASET.VAL_TOTAL_IMAGES = len(val_sets) - self.DATASET.TEST_TOTAL_IMAGES = len(test_sets) + # Ensure file list is use UTF-8 encoding + train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', 'utf-8').readlines() + val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', 'utf-8').readlines() + test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', 'utf-8').readlines() + self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets) + self.DATASET.VAL_TOTAL_IMAGES = len(val_sets) + self.DATASET.TEST_TOTAL_IMAGES = len(test_sets) if self.MODEL.MODEL_NAME == 'icnet' and \ len(self.MODEL.MULTI_LOSS_WEIGHT) != 3: -- GitLab