提交 33c96900 编写于 作者: W WuHaobo
......@@ -3,49 +3,43 @@
**文档教程**:https://paddleclas.readthedocs.io (正在持续更新中)
## 简介
PaddleClas的目的是为工业界和学术界提供一个图像分类任务相关的百宝箱,特色如下:
- 模型库:ResNet_vd、MobileNetV3等23种系列的分类网络结构和训练技巧,以及对应的117个分类预训练模型和性能评估
- 高阶支持:SSLD知识蒸馏方案(准确率82.39%的ResNet50_vd和78.9%的MobileNetV3)、8种数据增广方法的复现和验证
- 应用拓展:常见视觉任务的特色方案,包括图像分类领域的迁移学习(百度自研的10万类图像分类预训练模型)和通用目标检测(mAP 47.8%的实用检测方案)等
- 实用工具:便于工业应用部署的实用工具,包括TensorRT预测、移动端预测、模型服务化部署等
- 赛事支持:助力多个视觉全球挑战赛取得领先成绩,包括2018年Kaggle Open Images V4图像目标检测挑战赛冠军、2019年Kaggle地标检索挑战赛亚军等
<div align="center">
<img src="docs/images/main_features.png" width="700">
</div>
## 模型库
## 丰富的模型库
基于ImageNet1k分类数据集,PaddleClas提供ResNet、ResNet_vd、EfficientNet、Res2Net、HRNet、MobileNetV3等23种系列的分类网络结构的简单介绍、论文指标复现配置,以及在复现过程中的训练技巧。与此同时,PaddleClas也提供了对应的117个图像分类预训练模型,并且基于TensorRT评估了所有模型的GPU预测时间,以及在骁龙855(SD855)上评估了移动端模型的CPU预测时间和存储大小。支持的***预训练模型列表、下载地址以及更多信息***请见文档教程中的[**模型库章节**](https://paddleclas.readthedocs.io/zh_CN/latest/models/models_intro.html)。
基于ImageNet1k分类数据集,PaddleClas提供ResNet、ResNet_vd、EfficientNet、Res2Net、HRNet、MobileNetV3等23种系列的分类网络结构的简单介绍、论文指标复现配置,以及在复现过程中的训练技巧。与此同时,也提供了对应的117个图像分类预训练模型,并且基于TensorRT评估了服务器端模型的GPU预测时间,以及在骁龙855(SD855)上评估了移动端模型的CPU预测时间和存储大小。支持的***预训练模型列表、下载地址以及更多信息***请见文档教程中的[**模型库章节**](https://paddleclas.readthedocs.io/zh_CN/latest/models/models_intro.html)。
<div align="center">
<img src="docs/images/models/main_fps_top1.png" width="700">
</div>
上图展示了一些适合服务器端应用的模型,使用V100,FP32和TensorRT预测一张图像的时间,图中ResNet50_vd_ssld和ResNet101_vd_ssld,是采用PaddleClas提供的SSLD蒸馏方法训练的模型。图中相同颜色和符号的点代表同一系列不同规模的模型。不同模型的简介、FLOPS、Parameters以及详细GPU预测时间请参考文档教程中的[**模型库章节**](https://paddleclas.readthedocs.io/zh_CN/latest/models/models_intro.html)
上图对比了一些最新的面向服务器端应用场景的模型,在使用V100,FP32和TensorRT预测一张图像的时间和其准确率,图中准确率82.4%的ResNet50_vd_ssld和83.7%的ResNet101_vd_ssld,是采用PaddleClas提供的SSLD知识蒸馏方案训练的模型。图中相同颜色和符号的点代表同一系列不同规模的模型。不同模型的简介、FLOPS、Parameters以及详细的GPU预测时间请参考文档教程中的[**模型库章节**](https://paddleclas.readthedocs.io/zh_CN/latest/models/models_intro.html)
<div align="center">
<img
src="docs/images/models/mobile_arm_top1.png" width="700">
</div>
上图展示了一些适合移动端应用的模型在SD855上预测一张图像的时间。图中MV3_large_x1_0_ssld(M是MobileNet的简称),MV3_small_x1_0_ssld、MV2_ssld和MV1_ssld,是采用PaddleClas提供的SSLD蒸馏方法训练的模型。MV3_large_x1_0_ssld_int8是进一步进行INT8量化的模型。不同模型的简介、FLOPS、Parameters和模型存储大小请参考文档教程中的[**模型库章节**](https://paddleclas.readthedocs.io/zh_CN/latest/models/models_intro.html)
上图对比了一些最新的面向移动端应用场景的模型,在骁龙855(SD855)上预测一张图像的时间和其准确率,包括MobileNetV1系列、MobileNetV2系列、MobileNetV3系列和ShuffleNetV2系列。图中准确率79%的MV3_large_x1_0_ssld(M是MobileNet的简称),71.3%的MV3_small_x1_0_ssld、76.74%的MV2_ssld和77.89%的MV1_ssld,是采用PaddleClas提供的SSLD蒸馏方法训练的模型。MV3_large_x1_0_ssld_int8是进一步进行INT8量化的模型。不同模型的简介、FLOPS、Parameters和模型存储大小请参考文档教程中的[**模型库章节**](https://paddleclas.readthedocs.io/zh_CN/latest/models/models_intro.html)
- TODO
- [ ] EfficientLite、GhostNet、RegNet论文指标复现和性能评估
## 高阶支持
## 高阶优化支持
除了提供丰富的分类网络结构和预训练模型,PaddleClas也支持了一系列有助于图像分类任务效果和效率提升的算法或工具。
### 知识蒸馏
### SSLD知识蒸馏
知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的效果提升,甚至获得与大模型相似的精度指标。
知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的效果提升,甚至获得与大模型相似的精度指标。PaddleClas提供了一种简单的半监督标签知识蒸馏方案(SSLD,Simple Semi-supervised Label Distillation),使用该方案,模型效果普遍提升3%以上,一些蒸馏模型提升效果如下图所示:
<div align="center">
<img
src="docs/images/distillation/distillation_perform.png" width="700">
</div>
PaddleClas提供了一种简单的半监督标签知识蒸馏方案(SSLD,Simple Semi-supervised Label Distillation),使用该方案大幅提升了ResNet101_vd,ResNet50_vd、MobileNetV1、MobileNetV2和MobileNetV3在ImageNet数据集上分类效果,如上图所示。该知识蒸馏方案的框架图如下,详细的知识蒸馏方法介绍请参考文档教程中的[**知识蒸馏章节**](https://paddleclas.readthedocs.io/zh_CN/latest/advanced_tutorials/distillation/index.html)
以在ImageNet1K蒸馏模型为例,SSLD知识蒸馏方案框架图如下,该方案的核心关键点包括教师模型的选择、loss计算方式、迭代轮数、无标签数据的使用、以及ImageNet1k蒸馏finetune,每部分的详细介绍以及实验介绍请参考文档教程中的[**知识蒸馏章节**](https://paddleclas.readthedocs.io/zh_CN/latest/advanced_tutorials/distillation/index.html)
<div align="center">
<img
......@@ -54,14 +48,14 @@ src="docs/images/distillation/ppcls_distillation.png" width="700">
### 数据增广
在图像分类任务中,图像数据的增广是一种常用的正则化方法,可以有效提升图像分类的效果,尤其对于数据量不足或者模型网络较大的场景。PaddleClas支持了最新的8种数据增广算法的复现和在统一实验环境下的效果评估,变换效果示例如下
在图像分类任务中,图像数据的增广是一种常用的正则化方法,可以有效提升图像分类的效果,尤其对于数据量不足或者模型网络较大的场景。常用的数据增广可以分为3类,图像变换类、图像裁剪类和图像混叠类,如下图所示。图像变换类是指对全图进行一些变换,例如AutoAugment,RandAugment。图像裁剪类是指对图像以一定的方式遮挡部分区域的变换,例如CutOut,RandErasing,HideAndSeek,GridMask。图像混叠类是指多张图进行混叠一张新图的变换,例如Mixup,Cutmix
<div align="center">
<img
src="docs/images/image_aug/image_aug_samples.png" width="800">
</div>
下图展示了不同数据增广方式在ResNet50上的表现。每种数据增广方法的详细介绍、对比的实验环境请参考文档教程中的[**数据增广章节**](https://paddleclas.readthedocs.io/zh_CN/latest/advanced_tutorials/image_augmentation/index.html)
PaddleClas提供了上述8种数据增广算法的复现和在统一实验环境下的效果评估。下图展示了不同数据增广方式在ResNet50上的表现, 与标准变换相比,采用数据增广,识别准确率最高可以提升1%。每种数据增广方法的详细介绍、对比的实验环境请参考文档教程中的[**数据增广章节**](https://paddleclas.readthedocs.io/zh_CN/latest/advanced_tutorials/image_augmentation/index.html)
<div align="center">
<img
......@@ -70,23 +64,27 @@ src="docs/images/image_aug/main_image_aug.png" width="600">
- TODO
- [ ] 更多的优化器支持和效果验证
- [ ] 支持模型可解释性工具
## 开始使用
PaddleClas的安装说明、模型训练、预测、评估以及模型微调(finetune)请参考文档教程中的[**初级使用章节**](https://paddleclas.readthedocs.io/zh_CN/latest/tutorials/index.html),SSLD知识蒸馏和数据增广的高阶使用正在持续更新中。
## 应用拓展
效果更优的图像分类网络结构和预训练模型往往有助于提升其他视觉任务的效果,PaddleClas提供了一系列在常见视觉任务中的特色方案。
## 特色拓展应用
### 图像分类的迁移学习
### 10万类图像分类预训练模型
在实际应用中,由于训练数据匮乏,往往将ImageNet1K数据集训练的分类模型作为预训练模型,进行图像分类的迁移学习。然而ImageNet1K数据集的类别只有1000种,预训练模型的特征迁移能力有限。因此百度自研了一个有语义体系的、粒度有粗有细的10w级别的Tag体系,通过人工或半监督方式,至今收集到 5500w+图片训练数据;该系统是国内甚至世界范围内最大规模的图片分类体系和训练集合。PaddleClas提供了在该数据集上训练的ResNet50_vd的模型。下表显示了一些实际应用场景中,使用ImageNet预训练模型和上述10万类图像分类预训练模型的效果比对,使用10万类图像分类预训练模型,识别准确率最高可以提升30%。
<div align="center">
<img
src="docs/images/10w_cls.png" width="450">
</div>
在实际应用中,由于训练数据的匮乏,往往将ImageNet1K数据集训练的分类模型作为预训练模型,进行图像分类的迁移学习。为了进一步助力实际问题的解决,PaddleClas提供百度自研的基于10万种类别、4千多万的有标签数据训练的预训练模型,预训练模型下载地址如下,更多的相关内容请参考文档教程中的[**图像分类迁移学习章节**](https://paddleclas.readthedocs.io/zh_CN/latest/application/transfer_learning.html#id1)
10万类图像分类预训练模型下载地址如下,更多的相关内容请参考文档教程中的[**图像分类迁移学习章节**](https://paddleclas.readthedocs.io/zh_CN/latest/application/transfer_learning.html#id1)
[**10万类预训练模型下载地址**](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_10w_pretrained.tar)
- [**10万类预训练模型下载地址**](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_10w_pretrained.tar)
### 通用目标检测
近年来,学术界和工业界广泛关注图像中目标检测任务。PaddleClas基于82.39%的ResNet50_vd的预训练模型,结合PaddleDetection中丰富的检测算子,提供了一种面向服务器端应用的目标检测方案,PSS-DET (Practical Server Side Detection),在COCO目标检测数据集上,当V100单卡预测速度为61FPS时,mAP是41.6%,预测速度为20FPS时,mAP是47.8%。详情请参考[**通用目标检测章节**](https://paddleclas.readthedocs.io/zh_CN/latest/application/object_detection.html)
近年来,学术界和工业界广泛关注图像中目标检测任务,而图像分类的网络结构以及预训练模型效果直接影响目标检测的效果。PaddleClas基于82.39%的ResNet50_vd的预训练模型,结合PaddleDetection中丰富的检测算子,提供了一种面向服务器端应用的目标检测方案,PSS-DET (Practical Server Side Detection)。该方案融合了多种只增加少许计算量,但是可以有效提升两阶段Faster RCNN目标检测效果的策略,包括检测模型剪裁、使用分类效果更优的预训练模型、DCNv2、Cascade RCNN、AutoAugment、Libra sampling以及多尺度训练。其中基于82.39%的R50_vd_ssld预训练模型,与79.12%的R50_vd的预训练模型相比,检测效果可以提升1.5%。在COCO目标检测数据集上测试PSS-DET,当V100单卡预测速度为61FPS时,mAP是41.6%,预测速度为20FPS时,mAP是47.8%。详情请参考[**通用目标检测章节**](https://paddleclas.readthedocs.io/zh_CN/latest/application/object_detection.html)
<div align="center">
<img
......@@ -94,10 +92,10 @@ src="docs/images/det/pssdet.png" width="500">
</div>
- TODO
- [ ] PaddleClas在OCR任务中的特色应用
- [ ] PaddleClas在人脸检测和识别中的特色应用
- [ ] PaddleClas在OCR任务中的应用
- [ ] PaddleClas在人脸检测和识别中的应用
## 实用工具
## 工业级应用部署工具
PaddlePaddle提供了一系列实用工具,便于工业应用部署PaddleClas,具体请参考文档教程中的[**实用工具章节**](https://paddleclas.readthedocs.io/zh_CN/latest/extension/index.html)
- TensorRT预测
......@@ -107,7 +105,7 @@ PaddlePaddle提供了一系列实用工具,便于工业应用部署PaddleClas
- 多机训练
- Paddle Hub
## 赛事支持
## 护航视觉挑战赛
PaddleClas的建设源于百度实际视觉业务应用的淬炼和视觉前沿能力的探索,助力多个视觉重点赛事取得领先成绩,并且持续推进更多的前沿视觉问题的解决和落地应用。更多内容请关注文档教程中的[**赛事支持章节**](https://paddleclas.readthedocs.io/zh_CN/latest/competition_support.html)
- 2018年Kaggle Open Images V4图像目标检测挑战赛冠军
......
......@@ -206,6 +206,8 @@ def mp_reader(params):
check_params(params)
full_lines = get_file_list(params)
if params["mode"] == "train":
full_lines = shuffle_lines(full_lines, seed=None)
part_num = 1 if 'num_workers' not in params else params['num_workers']
......@@ -254,11 +256,10 @@ class Reader:
self.batch_ops = create_operators(self.params['mix'])
def __call__(self):
reader = mp_reader(self.params)
batch_size = int(self.params['batch_size']) // trainers_num
def wrapper():
reader = mp_reader(self.params)
batch = []
for idx, sample in enumerate(reader()):
img, label = sample
......
......@@ -42,3 +42,6 @@ from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_1
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
from .darts_gs import DARTS_GS_6M, DARTS_GS_4M
from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet
# distillation model
from .distillation_models import ResNet50_vd_distill_MobileNetV3_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from .resnet_vd import ResNet50_vd
from .mobilenet_v3 import MobileNetV3_large_x1_0
from .resnext101_wsl import ResNeXt101_32x16d_wsl
__all__ = [
'ResNet50_vd_distill_MobileNetV3_x1_0',
'ResNeXt101_32x16d_wsl_distill_ResNet50_vd'
]
class ResNet50_vd_distill_MobileNetV3_x1_0():
def net(self, input, class_dim=1000):
# student
student = MobileNetV3_large_x1_0()
out_student = student.net(input, class_dim=class_dim)
# teacher
teacher = ResNet50_vd()
out_teacher = teacher.net(input, class_dim=class_dim)
out_teacher.stop_gradient = True
return out_teacher, out_student
class ResNeXt101_32x16d_wsl_distill_ResNet50_vd():
def net(self, input, class_dim=1000):
# student
student = ResNet50_vd()
out_student = student.net(input, class_dim=class_dim)
# teacher
teacher = ResNeXt101_32x16d_wsl()
out_teacher = teacher.net(input, class_dim=class_dim)
out_teacher.stop_gradient = True
return out_teacher, out_student
......@@ -15,7 +15,7 @@
import paddle
import paddle.fluid as fluid
__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss']
__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss']
class Loss(object):
......@@ -34,8 +34,11 @@ class Loss(object):
self._label_smoothing = False
def _labelsmoothing(self, target):
one_hot_target = fluid.layers.one_hot(
input=target, depth=self._class_dim)
if target.shape[-1] != self._class_dim:
one_hot_target = fluid.layers.one_hot(
input=target, depth=self._class_dim)
else:
one_hot_target = target
soft_target = fluid.layers.label_smooth(
label=one_hot_target, epsilon=self._epsilon, dtype="float32")
return soft_target
......@@ -49,6 +52,19 @@ class Loss(object):
avg_cost = fluid.layers.mean(cost)
return avg_cost
def _kldiv(self, input, target):
cost = target * fluid.layers.log(target / input) * self._class_dim
cost = fluid.layers.sum(cost)
return cost
def _jsdiv(self, input, target):
input = fluid.layers.softmax(input, use_cudnn=False)
target = fluid.layers.softmax(target, use_cudnn=False)
cost = self._kldiv(input, target) + self._kldiv(target, input)
cost = cost / 2
avg_cost = fluid.layers.mean(cost)
return avg_cost
def __call__(self, input, target):
pass
......@@ -97,3 +113,16 @@ class GoogLeNetLoss(Loss):
cost = cost0 + 0.3 * cost1 + 0.3 * cost2
avg_cost = fluid.layers.mean(cost)
return avg_cost
class JSDivLoss(Loss):
"""
JSDiv loss
"""
def __init__(self, class_dim=1000, epsilon=None):
super(JSDivLoss, self).__init__(class_dim, epsilon)
def __call__(self, input, target):
cost = self._jsdiv(input, target)
return cost
......@@ -14,6 +14,7 @@
import os
import logging
logging.basicConfig()
import random
DEBUG = logging.DEBUG #10
......
......@@ -106,22 +106,20 @@ def load_params(exe, prog, path, ignore_params=[]):
fluid.io.set_program_state(prog, state)
def init_model(config, program, exe, prefix=""):
def init_model(config, program, exe):
"""
load model from checkpoint or pretrained_model
"""
checkpoints = config.get('checkpoints')
if checkpoints:
path = os.path.join(checkpoints, prefix)
fluid.load(program, path, exe)
logger.info("Finish initing model from {}".format(path))
fluid.load(program, checkpoints, exe)
logger.info("Finish initing model from {}".format(checkpoints))
return
pretrained_model = config.get('pretrained_model')
if pretrained_model:
path = os.path.join(pretrained_model, prefix)
load_params(exe, program, path)
logger.info("Finish initing model from {}".format(path))
load_params(exe, program, pretrained_model)
logger.info("Finish initing model from {}".format(pretrained_model))
def save_model(program, model_path, epoch_id, prefix='ppcls'):
......
......@@ -24,6 +24,7 @@ def parse_args():
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("-o", "--output_path", type=str)
parser.add_argument("--class_dim", type=int)
return parser.parse_args()
......@@ -57,7 +58,7 @@ def main():
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
image = create_input()
out = create_model(args, model, image)
out = create_model(args, model, image, class_dim=args.class_dim)
infer_prog = infer_prog.clone(for_test=True)
fluid.load(
......
......@@ -31,6 +31,7 @@ from ppcls.optimizer import OptimizerBuilder
from ppcls.modeling import architectures
from ppcls.modeling.loss import CELoss
from ppcls.modeling.loss import MixCELoss
from ppcls.modeling.loss import JSDivLoss
from ppcls.modeling.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
......@@ -39,13 +40,13 @@ from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.collective import DistributedStrategy
def create_feeds(image_shape, mix=None):
def create_feeds(image_shape, use_mix=None):
"""
Create feeds as model input
Args:
image_shape(list[int]): model input shape, such as [3, 224, 224]
mix(bool): whether to use mix(include mixup, cutmix, fmix)
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
feeds(dict): dict of model input variables
......@@ -53,7 +54,7 @@ def create_feeds(image_shape, mix=None):
feeds = OrderedDict()
feeds['image'] = fluid.data(
name="feed_image", shape=[None] + image_shape, dtype="float32")
if mix:
if use_mix:
feeds['feed_y_a'] = fluid.data(
name="feed_y_a", shape=[None, 1], dtype="int64")
feeds['feed_y_b'] = fluid.data(
......@@ -112,7 +113,8 @@ def create_loss(out,
architecture,
classes_num=1000,
epsilon=None,
mix=False):
use_mix=False,
use_distillation=False):
"""
Create a loss for optimization, such as:
1. CrossEnotry loss
......@@ -127,7 +129,7 @@ def create_loss(out,
architecture(dict): architecture information, name(such as ResNet50) is needed
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
mix(bool): whether to use mix(include mixup, cutmix, fmix)
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
loss(variable): loss variable
......@@ -138,7 +140,14 @@ def create_loss(out,
target = feeds['label']
return loss(out[0], out[1], out[2], target)
if mix:
if use_distillation:
assert len(
out) == 2, "distillation output length must be 2 but got {}".format(
len(out))
loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
return loss(out[1], out[0])
if use_mix:
loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
feed_y_a = feeds['feed_y_a']
feed_y_b = feeds['feed_y_b']
......@@ -150,7 +159,8 @@ def create_loss(out,
return loss(out, target)
def create_metric(out, feeds, topk=5, classes_num=1000):
def create_metric(out, feeds, topk=5, classes_num=1000,
use_distillation=False):
"""
Create measures of model accuracy, such as top1 and top5
......@@ -163,6 +173,9 @@ def create_metric(out, feeds, topk=5, classes_num=1000):
Returns:
fetchs(dict): dict of measures
"""
# just need student label to get metrics
if use_distillation:
out = out[1]
fetchs = OrderedDict()
label = feeds['label']
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
......@@ -182,10 +195,11 @@ def create_fetchs(out,
topk=5,
classes_num=1000,
epsilon=None,
mix=False):
use_mix=False,
use_distillation=False):
"""
Create fetchs as model outputs(included loss and measures),
will call create_loss and create_metric(if mix).
will call create_loss and create_metric(if use_mix).
Args:
out(variable): model output variable
......@@ -194,16 +208,17 @@ def create_fetchs(out,
topk(int): usually top5
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
mix(bool): whether to use mix(include mixup, cutmix, fmix)
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
fetchs(dict): dict of model outputs(included loss and measures)
"""
fetchs = OrderedDict()
loss = create_loss(out, feeds, architecture, classes_num, epsilon, mix)
loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
use_distillation)
fetchs['loss'] = (loss, AverageMeter('loss', ':2.4f', True))
if not mix:
metric = create_metric(out, feeds, topk, classes_num)
if not use_mix:
metric = create_metric(out, feeds, topk, classes_num, use_distillation)
fetchs.update(metric)
return fetchs
......@@ -293,7 +308,8 @@ def build(config, main_prog, startup_prog, is_train=True):
with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard():
use_mix = config.get('use_mix') and is_train
feeds = create_feeds(config.image_shape, mix=use_mix)
use_distillation = config.get('use_distillation')
feeds = create_feeds(config.image_shape, use_mix=use_mix)
dataloader = create_dataloader(feeds.values())
out = create_model(config.ARCHITECTURE, feeds['image'],
config.classes_num)
......@@ -304,7 +320,8 @@ def build(config, main_prog, startup_prog, is_train=True):
config.topk,
config.classes_num,
epsilon=config.get('ls_epsilon'),
mix=use_mix)
use_mix=use_mix,
use_distillation=use_distillation)
if is_train:
optimizer = create_optimizer(config)
lr = optimizer._global_learning_rate()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册