提交 dd6126ed 编写于 作者: D dengkaipeng

remove reader and metrics in ModelBase

上级 ebb8f1ee
......@@ -143,21 +143,3 @@ class AttentionCluster(ModelBase):
"https://paddlemodels.bj.bcebos.com/video_classification/attention_cluster_youtube8m.tar.gz"
)
def create_dataset_args(self):
dataset_args = {}
dataset_args['num_classes'] = self.class_num
dataset_args['seg_num'] = self.seg_num
dataset_args['list'] = self.get_config_from_sec(self.mode, 'filelist')
if self.use_gpu and self.py_reader:
dataset_args['batch_size'] = int(self.batch_size / self.gpu_num)
else:
dataset_args['batch_size'] = self.batch_size
return dataset_args
def create_metrics_args(self):
metrics_args = {}
metrics_args['num_classes'] = self.class_num
metrics_args['topk'] = 20
return metrics_args
......@@ -58,52 +58,6 @@ class ModelNotFoundError(Exception):
return msg
class ModelConfig(object):
def __init__(self, cfg_file):
self.cfg_file = cfg_file
self.parser = ConfigParser()
self.cfg = AttrDict()
def parse(self):
self.parser.read(self.cfg_file)
for sec in self.parser.sections():
sec_dict = AttrDict()
for k, v in self.parser.items(sec):
try:
v = eval(v)
except:
pass
setattr(sec_dict, k, v)
setattr(self.cfg, sec.upper(), sec_dict)
def merge_configs(self, sec, cfg_dict):
sec_dict = getattr(self.cfg, sec.upper())
for k, v in cfg_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except:
pass
def get_config_from_sec(self, sec, item):
try:
if hasattr(self.cfg, sec):
sec_dict = getattr(self.cfg, sec)
except:
return None
try:
if hasattr(sec_dict, item):
return getattr(sec_dict, item)
except:
return None
def get_configs(self):
return self.cfg
class ModelBase(object):
def __init__(self, name, cfg, mode='train'):
assert mode in ['train', 'valid', 'test', 'infer'], \
......@@ -147,22 +101,6 @@ class ModelBase(object):
"get feed inputs list"
raise NotImplementError(self, self.feeds)
def create_dataset_args(self):
"get model reader"
raise NotImplementError(self, self.create_dataset_args)
def reader(self):
dataset_args = self.create_dataset_args()
return get_reader(self.name.upper(), self.mode, **dataset_args)
def create_metrics_args(self):
"get model reader"
raise NotImplementError(self, self.create_metrics_args)
def metrics(self):
metrics_args = self.create_metrics_args()
return get_metrics(self.name.upper(), self.mode, **metrics_args)
def weights_info(self):
"get model weight default path and download url"
raise NotImplementError(self, self.weights_info)
......
......@@ -25,7 +25,6 @@ import models
from datareader import get_reader
from metrics import get_metrics
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
......
......@@ -26,7 +26,6 @@ from config import *
from datareader import get_reader
from metrics import get_metrics
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册