diff --git a/README.md b/README.md index 4fb6874af1ea4b49f5a117e6564fc6f37ca6b9c5..cb7ca47084f3574031f1b9eb9665c03a0d721b3f 100644 --- a/README.md +++ b/README.md @@ -70,11 +70,15 @@ PaddleClas的安装说明、模型训练、预测、评估以及模型微调(f ### 10万类图像分类预训练模型 在实际应用中,由于训练数据匮乏,往往将ImageNet1K数据集训练的分类模型作为预训练模型,进行图像分类的迁移学习。然而ImageNet1K数据集的类别只有1000种,预训练模型的特征迁移能力有限。因此百度自研了一个有语义体系的、粒度有粗有细的10w级别的Tag体系,通过人工或半监督方式,至今收集到 5500w+图片训练数据;该系统是国内甚至世界范围内最大规模的图片分类体系和训练集合。PaddleClas提供了在该数据集上训练的ResNet50_vd的模型。下表显示了一些实际应用场景中,使用ImageNet预训练模型和上述10万类图像分类预训练模型的效果比对,使用10万类图像分类预训练模型,识别准确率最高可以提升30%。 - -
- -
+ +| 数据集 | 数据统计 | ImageNet预训练模型 | 10万类图像分类预训练模型 | +|:--:|:--:|:--:|:--:| +| 花卉 | class_num:102
train/val:5789/2396 | 0.7779 | 0.9892 | +| 手绘简笔画 | class_num:18
train/val:1007/432 | 0.8785 | 0.9107 | +| 植物叶子 | class_num:6
train/val:5256/2278 | 0.8212 | 0.8385 | +| 集装箱车辆 | class_num:115
train/val:4879/2094 | 0.623 | 0.9524 | +| 椅子 | class_num:5
train/val:169/784 | 0.8557 | 0.9077 | +| 地质 | class_num:4
train/val:671/296 | 0.5719 | 0.6781 | 10万类图像分类预训练模型下载地址如下,更多的相关内容请参考文档教程中的[**图像分类迁移学习章节**](https://paddleclas.readthedocs.io/zh_CN/latest/application/transfer_learning.html#id1)。 diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py index 3d30cab706319c1c8f85318d028cf35ba282e12e..524d9a6e3b5b72ebbd8f3068c0de1ea4b306c689 100644 --- a/ppcls/utils/config.py +++ b/ppcls/utils/config.py @@ -14,13 +14,12 @@ import os import yaml + from ppcls.utils import check from ppcls.utils import logger __all__ = ['get_config'] -CONFIG_SECS = ['ARCHITECTURE', 'TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE'] - class AttrDict(dict): def __getattr__(self, key): @@ -47,13 +46,12 @@ def create_attr_dict(yaml_config): create_attr_dict(yaml_config[key]) else: yaml_config[key] = value - return def parse_config(cfg_file): """Load a config file into AttrDict""" with open(cfg_file, 'r') as fopen: - yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.FullLoader)) + yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader)) create_attr_dict(yaml_config) return yaml_config @@ -63,10 +61,8 @@ def print_dict(d, delimiter=0): Recursively visualize a dict and indenting acrrording by the relationship of keys. """ - for k, v in d.items(): - if k in CONFIG_SECS: - logger.info("-" * 60) - + placeholder = "-" * 60 + for k, v in sorted(d.items()): if isinstance(v, dict): logger.info("{}{} : ".format(delimiter * " ", k)) print_dict(v, delimiter + 4) @@ -77,8 +73,8 @@ def print_dict(d, delimiter=0): else: logger.info("{}{} : {}".format(delimiter * " ", k, v)) - if k in CONFIG_SECS: - logger.info("-" * 60) + if k.isupper(): + logger.info(placeholder) def print_config(config): @@ -88,18 +84,22 @@ def print_config(config): Arguments: config: configs """ - - copyright = "PaddleClas is powered by PaddlePaddle" - ad = "https://github.com/PaddlePaddle/PaddleClas" - - logger.info("\n" * 2) - logger.info(copyright) - logger.info(ad) - + copyright = "PaddleClas is powered by PaddlePaddle !" + info = "For more info please go to the following website." + website = "https://github.com/PaddlePaddle/PaddleClas" + AD_LEN = 55 + + logger.info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format( + "=" * (AD_LEN + 4), + "=={}==".format(copyright.center(AD_LEN)), + "=" * (AD_LEN + 4), + "=={}==".format(' ' * AD_LEN), + "=={}==".format(info.center(AD_LEN)), + "=={}==".format(' ' * AD_LEN), + "=={}==".format(website.center(AD_LEN)), + "=" * (AD_LEN + 4), )) print_dict(config) - logger.info("-" * 60) - def check_config(config): """ @@ -157,7 +157,7 @@ def override(dl, ks, v): override(dl[ks[0]], ks[1:], v) -def override_config(config, options=[]): +def override_config(config, options=None): """ Recursively override the config @@ -172,32 +172,31 @@ def override_config(config, options=[]): Returns: config(dict): replaced config """ - for opt in options: - assert isinstance(opt, str), \ - ("option({}) should be a str".format(opt)) - assert "=" in opt, ("option({}) should contain " \ - "a = to distinguish between key and value".format(opt)) - pair = opt.split('=') - assert len(pair) == 2, ("there can be only a = in the option") - key, value = pair - keys = key.split('.') - override(config, keys, value) + if options is not None: + for opt in options: + assert isinstance(opt, str), ( + "option({}) should be a str".format(opt)) + assert "=" in opt, ( + "option({}) should contain a =" + "to distinguish between key and value".format(opt)) + pair = opt.split('=') + assert len(pair) == 2, ("there can be only a = in the option") + key, value = pair + keys = key.split('.') + override(config, keys, value) return config -def get_config(fname, overrides=[], show=True): +def get_config(fname, overrides=None, show=True): """ Read config from file """ - assert os.path.exists(fname), \ - ('config file({}) is not exist'.format(fname)) + assert os.path.exists(fname), ( + 'config file({}) is not exist'.format(fname)) config = parse_config(fname) + override_config(config, overrides) if show: print_config(config) - if len(overrides) > 0: - override_config(config, overrides) - if show: - print_config(config) check_config(config) return config