From dd6126ed6450b5bcda4b2db025ccbb1e8b0bc0a5 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 31 Jan 2019 02:55:03 +0000 Subject: [PATCH] remove reader and metrics in ModelBase --- .../attention_cluster/attention_cluster.py | 18 ------ fluid/PaddleCV/video/models/model.py | 62 ------------------- fluid/PaddleCV/video/test.py | 1 - fluid/PaddleCV/video/train.py | 1 - 4 files changed, 82 deletions(-) diff --git a/fluid/PaddleCV/video/models/attention_cluster/attention_cluster.py b/fluid/PaddleCV/video/models/attention_cluster/attention_cluster.py index 6f0340b6..2af98543 100755 --- a/fluid/PaddleCV/video/models/attention_cluster/attention_cluster.py +++ b/fluid/PaddleCV/video/models/attention_cluster/attention_cluster.py @@ -143,21 +143,3 @@ class AttentionCluster(ModelBase): "https://paddlemodels.bj.bcebos.com/video_classification/attention_cluster_youtube8m.tar.gz" ) - def create_dataset_args(self): - dataset_args = {} - dataset_args['num_classes'] = self.class_num - dataset_args['seg_num'] = self.seg_num - dataset_args['list'] = self.get_config_from_sec(self.mode, 'filelist') - - if self.use_gpu and self.py_reader: - dataset_args['batch_size'] = int(self.batch_size / self.gpu_num) - else: - dataset_args['batch_size'] = self.batch_size - - return dataset_args - - def create_metrics_args(self): - metrics_args = {} - metrics_args['num_classes'] = self.class_num - metrics_args['topk'] = 20 - return metrics_args diff --git a/fluid/PaddleCV/video/models/model.py b/fluid/PaddleCV/video/models/model.py index 21436194..c5b04708 100755 --- a/fluid/PaddleCV/video/models/model.py +++ b/fluid/PaddleCV/video/models/model.py @@ -58,52 +58,6 @@ class ModelNotFoundError(Exception): return msg -class ModelConfig(object): - def __init__(self, cfg_file): - self.cfg_file = cfg_file - self.parser = ConfigParser() - self.cfg = AttrDict() - - def parse(self): - self.parser.read(self.cfg_file) - for sec in self.parser.sections(): - sec_dict = AttrDict() - for k, v in self.parser.items(sec): - try: - v = eval(v) - except: - pass - setattr(sec_dict, k, v) - setattr(self.cfg, sec.upper(), sec_dict) - - def merge_configs(self, sec, cfg_dict): - sec_dict = getattr(self.cfg, sec.upper()) - for k, v in cfg_dict.items(): - if v is None: - continue - try: - if hasattr(sec_dict, k): - setattr(sec_dict, k, v) - except: - pass - - def get_config_from_sec(self, sec, item): - try: - if hasattr(self.cfg, sec): - sec_dict = getattr(self.cfg, sec) - except: - return None - - try: - if hasattr(sec_dict, item): - return getattr(sec_dict, item) - except: - return None - - def get_configs(self): - return self.cfg - - class ModelBase(object): def __init__(self, name, cfg, mode='train'): assert mode in ['train', 'valid', 'test', 'infer'], \ @@ -147,22 +101,6 @@ class ModelBase(object): "get feed inputs list" raise NotImplementError(self, self.feeds) - def create_dataset_args(self): - "get model reader" - raise NotImplementError(self, self.create_dataset_args) - - def reader(self): - dataset_args = self.create_dataset_args() - return get_reader(self.name.upper(), self.mode, **dataset_args) - - def create_metrics_args(self): - "get model reader" - raise NotImplementError(self, self.create_metrics_args) - - def metrics(self): - metrics_args = self.create_metrics_args() - return get_metrics(self.name.upper(), self.mode, **metrics_args) - def weights_info(self): "get model weight default path and download url" raise NotImplementError(self, self.weights_info) diff --git a/fluid/PaddleCV/video/test.py b/fluid/PaddleCV/video/test.py index 28dffdc4..88ab509c 100755 --- a/fluid/PaddleCV/video/test.py +++ b/fluid/PaddleCV/video/test.py @@ -25,7 +25,6 @@ import models from datareader import get_reader from metrics import get_metrics -logging.root.handlers = [] FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) logger = logging.getLogger(__name__) diff --git a/fluid/PaddleCV/video/train.py b/fluid/PaddleCV/video/train.py index 1731cdcc..3713d86c 100755 --- a/fluid/PaddleCV/video/train.py +++ b/fluid/PaddleCV/video/train.py @@ -26,7 +26,6 @@ from config import * from datareader import get_reader from metrics import get_metrics -logging.root.handlers = [] FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) logger = logging.getLogger(__name__) -- GitLab