diff --git a/README.md b/README.md
index 4fb6874af1ea4b49f5a117e6564fc6f37ca6b9c5..cb7ca47084f3574031f1b9eb9665c03a0d721b3f 100644
--- a/README.md
+++ b/README.md
@@ -70,11 +70,15 @@ PaddleClas的安装说明、模型训练、预测、评估以及模型微调(f
### 10万类图像分类预训练模型
在实际应用中,由于训练数据匮乏,往往将ImageNet1K数据集训练的分类模型作为预训练模型,进行图像分类的迁移学习。然而ImageNet1K数据集的类别只有1000种,预训练模型的特征迁移能力有限。因此百度自研了一个有语义体系的、粒度有粗有细的10w级别的Tag体系,通过人工或半监督方式,至今收集到 5500w+图片训练数据;该系统是国内甚至世界范围内最大规模的图片分类体系和训练集合。PaddleClas提供了在该数据集上训练的ResNet50_vd的模型。下表显示了一些实际应用场景中,使用ImageNet预训练模型和上述10万类图像分类预训练模型的效果比对,使用10万类图像分类预训练模型,识别准确率最高可以提升30%。
-
-
-
![]()
-
+
+| 数据集 | 数据统计 | ImageNet预训练模型 | 10万类图像分类预训练模型 |
+|:--:|:--:|:--:|:--:|
+| 花卉 | class_num:102<br/>train/val:5789/2396 | 0.7779 | 0.9892 |
+| 手绘简笔画 | class_num:18<br/>train/val:1007/432 | 0.8785 | 0.9107 |
+| 植物叶子 | class_num:6<br/>train/val:5256/2278 | 0.8212 | 0.8385 |
+| 集装箱车辆 | class_num:115<br/>train/val:4879/2094 | 0.623 | 0.9524 |
+| 椅子 | class_num:5<br/>train/val:169/784 | 0.8557 | 0.9077 |
+| 地质 | class_num:4<br/>train/val:671/296 | 0.5719 | 0.6781 |
10万类图像分类预训练模型下载地址如下,更多的相关内容请参考文档教程中的[**图像分类迁移学习章节**](https://paddleclas.readthedocs.io/zh_CN/latest/application/transfer_learning.html#id1)。
diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py
index 3d30cab706319c1c8f85318d028cf35ba282e12e..524d9a6e3b5b72ebbd8f3068c0de1ea4b306c689 100644
--- a/ppcls/utils/config.py
+++ b/ppcls/utils/config.py
@@ -14,13 +14,12 @@
import os
import yaml
+
from ppcls.utils import check
from ppcls.utils import logger
__all__ = ['get_config']
-CONFIG_SECS = ['ARCHITECTURE', 'TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE']
-
class AttrDict(dict):
def __getattr__(self, key):
@@ -47,13 +46,12 @@ def create_attr_dict(yaml_config):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
- return
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen:
- yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.FullLoader))
+ yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
create_attr_dict(yaml_config)
return yaml_config
@@ -63,10 +61,8 @@ def print_dict(d, delimiter=0):
Recursively visualize a dict and
indenting acrrording by the relationship of keys.
"""
- for k, v in d.items():
- if k in CONFIG_SECS:
- logger.info("-" * 60)
-
+ placeholder = "-" * 60
+ for k, v in sorted(d.items()):
if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ", k))
print_dict(v, delimiter + 4)
@@ -77,8 +73,8 @@ def print_dict(d, delimiter=0):
else:
logger.info("{}{} : {}".format(delimiter * " ", k, v))
- if k in CONFIG_SECS:
- logger.info("-" * 60)
+ if k.isupper():
+ logger.info(placeholder)
def print_config(config):
@@ -88,18 +84,22 @@ def print_config(config):
Arguments:
config: configs
"""
-
- copyright = "PaddleClas is powered by PaddlePaddle"
- ad = "https://github.com/PaddlePaddle/PaddleClas"
-
- logger.info("\n" * 2)
- logger.info(copyright)
- logger.info(ad)
-
+ copyright = "PaddleClas is powered by PaddlePaddle !"
+ info = "For more info please go to the following website."
+ website = "https://github.com/PaddlePaddle/PaddleClas"
+ AD_LEN = 55
+
+ logger.info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
+ "=" * (AD_LEN + 4),
+ "=={}==".format(copyright.center(AD_LEN)),
+ "=" * (AD_LEN + 4),
+ "=={}==".format(' ' * AD_LEN),
+ "=={}==".format(info.center(AD_LEN)),
+ "=={}==".format(' ' * AD_LEN),
+ "=={}==".format(website.center(AD_LEN)),
+ "=" * (AD_LEN + 4), ))
print_dict(config)
- logger.info("-" * 60)
-
def check_config(config):
"""
@@ -157,7 +157,7 @@ def override(dl, ks, v):
override(dl[ks[0]], ks[1:], v)
-def override_config(config, options=[]):
+def override_config(config, options=None):
"""
Recursively override the config
@@ -172,32 +172,31 @@ def override_config(config, options=[]):
Returns:
config(dict): replaced config
"""
- for opt in options:
- assert isinstance(opt, str), \
- ("option({}) should be a str".format(opt))
- assert "=" in opt, ("option({}) should contain " \
- "a = to distinguish between key and value".format(opt))
- pair = opt.split('=')
- assert len(pair) == 2, ("there can be only a = in the option")
- key, value = pair
- keys = key.split('.')
- override(config, keys, value)
+ if options is not None:
+ for opt in options:
+ assert isinstance(opt, str), (
+ "option({}) should be a str".format(opt))
+ assert "=" in opt, (
+ "option({}) should contain a = "
+ "to distinguish between key and value".format(opt))
+ pair = opt.split('=')
+ assert len(pair) == 2, ("there can be only a = in the option")
+ key, value = pair
+ keys = key.split('.')
+ override(config, keys, value)
return config
-def get_config(fname, overrides=[], show=True):
+def get_config(fname, overrides=None, show=True):
"""
Read config from file
"""
- assert os.path.exists(fname), \
- ('config file({}) is not exist'.format(fname))
+ assert os.path.exists(fname), (
+ 'config file({}) is not exist'.format(fname))
config = parse_config(fname)
+ override_config(config, overrides)
if show:
print_config(config)
- if len(overrides) > 0:
- override_config(config, overrides)
- if show:
- print_config(config)
check_config(config)
return config