Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
13d7b57a
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
13d7b57a
编写于
11月 23, 2021
作者:
S
sibo2rr
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/PaddleClas
into develop
上级
5b6f61a2
5fb8f0d3
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
173 addition
and
26 deletion
+173
-26
docs/zh_CN/data_preparation/recognition_dataset.md
docs/zh_CN/data_preparation/recognition_dataset.md
+3
-3
docs/zh_CN/image_recognition_pipeline/mainbody_detection.md
docs/zh_CN/image_recognition_pipeline/mainbody_detection.md
+3
-3
docs/zh_CN/inference_deployment/paddle_lite_deploy.md
docs/zh_CN/inference_deployment/paddle_lite_deploy.md
+3
-5
docs/zh_CN/others/paddle_mobile_inference.md
docs/zh_CN/others/paddle_mobile_inference.md
+119
-0
docs/zh_CN/quick_start/quick_start_classification_new_user.md
.../zh_CN/quick_start/quick_start_classification_new_user.md
+27
-5
docs/zh_CN/quick_start/quick_start_recognition.md
docs/zh_CN/quick_start/quick_start_recognition.md
+4
-2
ppcls/configs/quick_start/ResNet50_vd.yaml
ppcls/configs/quick_start/ResNet50_vd.yaml
+3
-3
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+11
-5
未找到文件。
docs/zh_CN/data_preparation/recognition_dataset.md
浏览文件 @
13d7b57a
# 图像
分类
任务数据集说明
# 图像
识别
任务数据集说明
本文档将介绍 PaddleClas 所使用的图像识别任务数据集格式,以及图像识别领域的常见数据集介绍。
...
...
@@ -15,8 +15,8 @@
-
[
2.2.2 商品识别
](
#商品识别
)
-
[
2.2.3 Logo识别
](
#Logo识别
)
-
[
2.2.4 车辆识别
](
#车辆识别
)
<a
name=
"数据集格式说明"
></a>
## 一、数据集格式说明
...
...
docs/zh_CN/image_recognition_pipeline/mainbody_detection.md
浏览文件 @
13d7b57a
...
...
@@ -29,11 +29,11 @@
| 模型 | 模型结构 | 预训练模型下载地址 | inference模型下载地址 | mAP | inference模型大小(MB) | 单张图片预测耗时(不包含预处理)(ms) |
| :------------: | :-------------: | :------: | :-------: | :--------: | :-------: | :--------: |
| 轻量级主体检测模型 | PicoDet |
[
地址
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_pretrained.pdparams
)
|
[
地址
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
)
| 40.1% | 30.1 | 29.8 |
| 服务端主体检测模型 | PP-YOLOv2 |
[
地址
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/ppyolov2_r50vd_dcn_mainbody_v1.0_pretrained.pdparams
)
|
[
地址
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/ppyolov2_r50vd_dcn_mainbody_v1.0_infer.tar
)
| 42.5% | 210.5 | 466.6 |
| 轻量级主体检测模型 | PicoDet |
[
地址
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_pretrained.pdparams
)
|
[
tar 格式文件地址
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
)
[
zip 格式文件地址
]
(https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.zip) | 40.1% | 30.1 | 29.8 |
| 服务端主体检测模型 | PP-YOLOv2 |
[
地址
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/ppyolov2_r50vd_dcn_mainbody_v1.0_pretrained.pdparams
)
|
[
tar 格式文件地址
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/ppyolov2_r50vd_dcn_mainbody_v1.0_infer.tar
)
[
zip 格式文件地址
]
(https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/ppyolov2_r50vd_dcn_mainbody_v1.0_infer.zip) | 42.5% | 210.5 | 466.6 |
*
注意
*
由于部分解压缩软件在解压上述
`tar`
格式文件时存在问题,建议非命令行用户下载
`zip`
格式文件并解压。
`tar`
格式文件建议使用命令
`tar xf xxx.tar`
解压。
*
速度评测机器的CPU具体信息为:
`Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz`
,速度指标为开启 mkldnn ,线程数设置为 10 测试得到。
*
主体检测的预处理过程较为耗时,平均每张图在上述机器上的时间在 40~55 ms 左右,没有包含在上述的预测耗时统计中。
...
...
docs/zh_CN/inference_deployment/paddle_lite_deploy.md
浏览文件 @
13d7b57a
...
...
@@ -4,8 +4,8 @@
本教程将介绍基于
[
Paddle Lite
](
https://github.com/PaddlePaddle/Paddle-Lite
)
在移动端部署 PaddleClas 分类模型的详细步骤。识别模型的部署将在近期支持,敬请期待。
Paddle Lite是飞桨轻量化推理引擎,为手机、IOT端提供高效推理能力,并广泛整合跨平台硬件,为端侧部署及应用落地问题提供轻量化的部署方案。
<!-- TODO(gaotingquan): 下述 benchmark 文档在新文档结构中缺失 -->
<!-- 如果希望直接测试速度,可以参考[Paddle-Lite移动端benchmark测试教程](../../docs/zh_CN/extension/paddle_mobile_inference.md)。 -->
如果希望直接测试速度,可以参考
[
Paddle-Lite移动端benchmark测试教程
](
../others/paddle_mobile_inference.md
)
。
---
...
...
@@ -58,9 +58,7 @@ cd Paddle-Lite
git checkout develop
./lite/tools/build_android.sh
--arch
=
armv8
--with_cv
=
ON
--with_extra
=
ON
```
<!-- TODO(gaotingquan): 需要与lite同学确认,该编译选项是否需要更新:with_cv with_extra, -->
<!-- https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_options.html -->
<!-- **注意**:编译Paddle-Lite获得预测库时,需要打开`--with_cv=ON --with_extra=ON`两个选项,`--arch`表示`arm`版本,这里指定为armv8,更多编译命令介绍请参考[链接](https://paddle-lite.readthedocs.io/zh/latest/user_guides/Compile/Android.html#id2)。 -->
**注意**
:编译Paddle-Lite获得预测库时,需要打开
`--with_cv=ON --with_extra=ON`
两个选项,
`--arch`
表示
`arm`
版本,这里指定为armv8,更多编译命令介绍请参考
[
Linux x86 环境下编译适用于 Android 的库
](
https://paddle-lite.readthedocs.io/zh/latest/source_compile/linux_x86_compile_android.html
)
,关于其他平台的编译操作,具体请参考
[
PaddleLite
](
https://paddle-lite.readthedocs.io/zh/latest/
)
中
`源码编译`
部分。
直接下载预测库并解压后,可以得到
`inference_lite_lib.android.armv8/`
文件夹,通过编译Paddle-Lite得到的预测库位于
`Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/`
文件夹下。
预测库的文件目录如下:
...
...
docs/zh_CN/others/paddle_mobile_inference.md
0 → 100644
浏览文件 @
13d7b57a
# Paddle-Lite
## 一、简介
[
Paddle-Lite
](
https://github.com/PaddlePaddle/Paddle-Lite
)
是飞桨推出的一套功能完善、易用性强且性能卓越的轻量化推理引擎。
轻量化体现在使用较少比特数用于表示神经网络的权重和激活,能够大大降低模型的体积,解决终端设备存储空间有限的问题,推理性能也整体优于其他框架。
[
PaddleClas
](
https://github.com/PaddlePaddle/PaddleClas
)
使用 Paddle-Lite 进行了
[
移动端模型的性能评估
](
../models/Mobile.md
)
,本部分以
`ImageNet1k`
数据集的
`MobileNetV1`
模型为例,介绍怎样使用
`Paddle-Lite`
,在移动端(基于骁龙855的安卓开发平台)对进行模型速度评估。
## 二、评估步骤
### 2.1 导出inference模型
*
首先需要将训练过程中保存的模型存储为用于预测部署的固化模型,可以使用
`tools/export_model.py`
导出inference模型,具体使用方法如下。
```
shell
python tools/export_model.py
\
-c
./ppcls/configs/ImageNet/MobileNetV1/MobileNetV1.yaml
\
-o
Arch.pretrained
=
./pretrained/MobileNetV1_pretrained/
\
-o
Global.save_inference_dir
=
./inference/MobileNetV1/
```
在上述命令中,通过参数
`Arch.pretrained`
指定训练过程中保存的模型参数文件,也可以指定参数
`Arch.pretrained=True`
加载 PaddleClas 提供的基于 ImageNet1k 的预训练模型参数,最终在
`inference/MobileNetV1`
文件夹下会保存得到
`inference.pdmodel`
与
`inference.pdiparmas`
文件。
### 2.2 benchmark二进制文件下载
*
使用adb(Android Debug Bridge)工具可以连接Android手机与PC端,并进行开发调试等。安装好adb,并确保PC端和手机连接成功后,使用以下命令可以查看手机的ARM版本,并基于此选择合适的预编译库。
```
shell
adb shell getprop ro.product.cpu.abi
```
*
下载benchmark_bin文件
请根据所用Android手机的ARM版本选择,ARM版本为v8,则使用以下命令下载:
```
shell
wget
-c
https://paddle-inference-dist.bj.bcebos.com/PaddleLite/benchmark_0/benchmark_bin_v8
```
如果查看的ARM版本为v7,则需要下载v7版本的benchmark_bin文件,下载命令如下:
```
shell
wget
-c
https://paddle-inference-dist.bj.bcebos.com/PaddleLite/benchmark_0/benchmark_bin_v7
```
### 2.3 模型速度benchmark
PC端和手机连接成功后,使用下面的命令开始模型评估。
```
sh deploy/lite/benchmark/benchmark.sh ./benchmark_bin_v8 ./inference result_armv8.txt true
```
其中
`./benchmark_bin_v8`
为benchmark二进制文件路径,
`./inference`
为所有需要评测的模型的路径,
`result_armv8.txt`
为保存的结果文件,最后的参数
`true`
表示在评估之后会首先进行模型优化。最终在当前文件夹下会输出
`result_armv8.txt`
的评估结果文件,具体信息如下。
```
PaddleLite Benchmark
Threads=1 Warmup=10 Repeats=30
MobileNetV1 min = 30.89100 max = 30.73600 average = 30.79750
Threads=2 Warmup=10 Repeats=30
MobileNetV1 min = 18.26600 max = 18.14000 average = 18.21637
Threads=4 Warmup=10 Repeats=30
MobileNetV1 min = 10.03200 max = 9.94300 average = 9.97627
```
这里给出了不同线程数下的模型预测速度,单位为FPS,以线程数为1为例,MobileNetV1在骁龙855上的平均速度为
`30.79750FPS`
。
### 2.4 模型优化与速度评估
*
在2.3节中提到了在模型评估之前对其进行优化,在这里也可以首先对模型进行优化,再直接加载优化后的模型进行速度评估。
*
Paddle-Lite 提供了多种策略来自动优化原始的训练模型,其中包括量化、子图融合、混合调度、Kernel优选等等方法。为了使优化过程更加方便易用,Paddle-Lite提供了opt 工具来自动完成优化步骤,输出一个轻量的、最优的可执行模型。可以在
[
Paddle-Lite模型优化工具页面
](
https://paddle-lite.readthedocs.io/zh/latest/user_guides/model_optimize_tool.html
)
下载。在这里以
`macOS`
开发环境为例,下载
[
opt_mac
](
https://paddlelite-data.bj.bcebos.com/model_optimize_tool/opt_mac
)
模型优化工具,并使用下面的命令对模型进行优化。
```
shell
model_file
=
"../MobileNetV1/inference.pdmodel"
param_file
=
"../MobileNetV1/inference.pdiparams"
opt_models_dir
=
"./opt_models"
mkdir
${
opt_models_dir
}
./opt_mac
--model_file
=
${
model_file
}
\
--param_file
=
${
param_file
}
\
--valid_targets
=
arm
\
--optimize_out_type
=
naive_buffer
\
--prefer_int8_kernel
=
false
\
--optimize_out
=
${
opt_models_dir
}
/MobileNetV1
```
其中
`model_file`
与
`param_file`
分别是导出的inference模型结构文件与参数文件地址,转换成功后,会在
`opt_models`
文件夹下生成
`MobileNetV1.nb`
文件。
使用benchmark_bin文件加载优化后的模型进行评估,具体的命令如下。
```
shell
bash benchmark.sh ./benchmark_bin_v8 ./opt_models result_armv8.txt
```
最终
`result_armv8.txt`
中结果如下:
```
PaddleLite Benchmark
Threads=1 Warmup=10 Repeats=30
MobileNetV1_lite min = 30.89500 max = 30.78500 average = 30.84173
Threads=2 Warmup=10 Repeats=30
MobileNetV1_lite min = 18.25300 max = 18.11000 average = 18.18017
Threads=4 Warmup=10 Repeats=30
MobileNetV1_lite min = 10.00600 max = 9.90000 average = 9.96177
```
以线程数为1为例,MobileNetV1在骁龙855上的平均速度为
`30.84173 ms`
。
更加具体的参数解释与Paddle-Lite使用方法可以参考
[
Paddle-Lite 文档
](
https://paddle-lite.readthedocs.io/zh/latest/
)
。
docs/zh_CN/quick_start/quick_start_classification_new_user.md
浏览文件 @
13d7b57a
...
...
@@ -3,7 +3,21 @@
此教程主要针对初级用户,即深度学习相关理论知识处于入门阶段,具有一定的 Python 基础,能够阅读简单代码的用户。此内容主要包括使用 PaddleClas 进行图像分类网络训练及模型预测。
---
## 目录
-
[
1. 基础知识
](
#1
)
-
[
2. 环境安装与配置
](
#2
)
-
[
3. 数据的准备与处理
](
#3
)
-
[
4. 模型训练
](
#4
)
-
[
4.1 使用CPU进行模型训练
](
#4.1
)
-
[
4.1.1 不使用预训练模型进行训练
](
#4.1.1
)
-
[
4.1.2 不使用预训练模型进行训练
](
#4.1.2
)
-
[
4.2 使用GPU进行模型训练
](
#4.2
)
-
[
4.2.1 不使用预训练模型进行训练
](
#4.2.1
)
-
[
4.2.2 不使用预训练模型进行训练
](
#4.2.2
)
-
[
5. 模型预测
](
#5
)
<a
name=
"1"
></a>
## 1. 基础知识
图像分类顾名思义就是一个模式分类问题,是计算机视觉中最基础的任务,它的目标是将不同的图像,划分到不同的类别。以下会对整个模型训练过程中需要了解到的一些概念做简单的解释,希望能够对初次体验 PaddleClas 的你有所帮助:
...
...
@@ -30,10 +44,12 @@
-
Top1 Acc:预测结果中概率最大的所在分类正确,则判定为正确;
-
Top5 Acc:预测结果中概率排名前 5 中有分类正确,则判定为正确;
<a
name=
"2"
></a>
## 2. 环境安装与配置
具体安装步骤可详看
[
Paddle 安装文档
](
../installation/install_paddle.md
)
,
[
PaddleClas 安装文档
](
../installation/install_paddleclas.md
)
。
<a
name=
"3"
></a>
## 3. 数据的准备与处理
进入PaddleClas目录:
...
...
@@ -72,15 +88,16 @@ cd ../../
# windoes直接打开PaddleClas根目录即可
```
<a
name=
"4"
></a>
## 4. 模型训练
<a
name=
"4.1"
></a>
### 4.1 使用CPU进行模型训练
由于使用CPU来进行模型训练,计算速度较慢,因此,此处以 ShuffleNetV2_x0_25 为例。此模型计算量较小,在 CPU 上计算速度较快。但是也因为模型较小,训练好的模型精度也不会太高。
#### 4.1.1 不使用预训练模型
<a
name=
"4.1.1"
></a>
#### 4.1.1 不使用预训练模型进行训练
```
shell
# windows在cmd中进入PaddleClas根目录,执行此命令
...
...
@@ -91,7 +108,8 @@ python tools/train.py -c ./ppcls/configs/quick_start/new_user/ShuffleNetV2_x0_25
-
`yaml`
文
`Global.device`
参数设置为
`cpu`
,即使用CPU进行训练(若不设置,此参数默认为
`True`
)
-
`yaml`
文件中
`epochs`
参数设置为20,说明对整个数据集进行20个epoch迭代,预计训练20分钟左右(不同CPU,训练时间略有不同),此时训练模型不充分。若提高训练模型精度,请将此参数设大,如
**40**
,训练时间也会相应延长
#### 4.1.2 使用预训练模型
<a
name=
"4.1.2"
></a>
#### 4.1.2 使用预训练模型进行训练
```
shell
python tools/train.py
-c
./ppcls/configs/quick_start/new_user/ShuffleNetV2_x0_25.yaml
-o
Arch.pretrained
=
True
...
...
@@ -101,6 +119,7 @@ python tools/train.py -c ./ppcls/configs/quick_start/new_user/ShuffleNetV2_x0_25
可以使用将使用与不使用预训练模型训练进行对比,观察 loss 的下降情况。
<a
name=
"4.2"
></a>
### 4.2 使用GPU进行模型训练
由于 GPU 训练速度更快,可以使用更复杂模型,因此以 ResNet50_vd 为例。与 ShuffleNetV2_x0_25 相比,此模型计算量较大,训练好的模型精度也会更高。
...
...
@@ -119,7 +138,8 @@ python tools/train.py -c ./ppcls/configs/quick_start/new_user/ShuffleNetV2_x0_25
set
CUDA_VISIBLE_DEVICES
=
0
```
#### 不使用预训练模型
<a
name=
"4.2.1"
></a>
#### 4.2.1 不使用预训练模型进行训练
```
shell
python tools/train.py
-c
./ppcls/configs/quick_start/ResNet50_vd.yaml
...
...
@@ -131,6 +151,7 @@ python tools/train.py -c ./ppcls/configs/quick_start/ResNet50_vd.yaml
<img
src=
"../../images/quick_start/r50_vd_acc.png"
width =
"800"
/>
</div>
<a
name=
"4.2.1"
></a>
#### 4.2.1 使用预训练模型进行训练
基于 ImageNet1k 分类预训练模型进行微调,训练脚本如下所示
...
...
@@ -147,6 +168,7 @@ python tools/train.py -c ./ppcls/configs/quick_start/ResNet50_vd.yaml -o Arch.pr
<img
src=
"../../images/quick_start/r50_vd_pretrained_acc.png"
width =
"800"
/>
</div>
<a
name=
"5"
></a>
## 5. 模型预测
训练完成后,可以使用训练好的模型进行预测,以训练的 ResNet50_vd 模型为例,预测代码如下:
...
...
docs/zh_CN/quick_start/quick_start_recognition.md
浏览文件 @
13d7b57a
...
...
@@ -40,8 +40,10 @@
| 模型简介 | 推荐场景 | inference模型 | 预测配置文件 | 构建索引库的配置文件 |
| ------------ | ------------- | -------- | ------- | -------- |
| 轻量级通用主体检测模型 | 通用场景 |
[
模型下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
)
| - | - |
| 轻量级通用识别模型 | 通用场景 |
[
模型下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
)
|
[
inference_general.yaml
](
../../../deploy/configs/inference_general.yaml
)
|
[
build_general.yaml
](
../../../deploy/configs/build_general.yaml
)
|
| 轻量级通用主体检测模型 | 通用场景 |
[
tar 格式文件下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
)
[
zip 格式文件下载链接
]
(https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.zip) | - | - |
| 轻量级通用识别模型 | 通用场景 |
[
tar 格式下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
)
[
zip 格式文件下载链接
]
(https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.zip) |
[
inference_general.yaml
](
../../../deploy/configs/inference_general.yaml
)
|
[
build_general.yaml
](
../../../deploy/configs/build_general.yaml
)
|
注意:由于部分解压缩软件在解压上述
`tar`
格式文件时存在问题,建议非命令行用户下载
`zip`
格式文件并解压。
`tar`
格式文件建议使用命令
`tar xf xxx.tar`
解压。
本章节 demo 数据下载地址如下:
[
瓶装饮料数据下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar
)
。
...
...
ppcls/configs/quick_start/ResNet50_vd.yaml
浏览文件 @
13d7b57a
...
...
@@ -4,7 +4,6 @@ Global:
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
class_num
:
102
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
...
...
@@ -17,8 +16,9 @@ Global:
# model architecture
Arch
:
name
:
ResNet50_vd
name
:
ResNet50_vd
class_num
:
102
# loss function config for traing/eval process
Loss
:
Train
:
...
...
ppcls/loss/__init__.py
浏览文件 @
13d7b57a
...
...
@@ -44,12 +44,18 @@ class CombinedLoss(nn.Layer):
def
__call__
(
self
,
input
,
batch
):
loss_dict
=
{}
for
idx
,
loss_func
in
enumerate
(
self
.
loss_func
):
loss
=
loss_func
(
input
,
batch
)
weight
=
self
.
loss_weight
[
idx
]
loss
=
{
key
:
loss
[
key
]
*
weight
for
key
in
loss
}
# just for accelerate classification traing speed
if
len
(
self
.
loss_func
)
==
1
:
loss
=
self
.
loss_func
[
0
](
input
,
batch
)
loss_dict
.
update
(
loss
)
loss_dict
[
"loss"
]
=
paddle
.
add_n
(
list
(
loss_dict
.
values
()))
loss_dict
[
"loss"
]
=
list
(
loss
.
values
())[
0
]
else
:
for
idx
,
loss_func
in
enumerate
(
self
.
loss_func
):
loss
=
loss_func
(
input
,
batch
)
weight
=
self
.
loss_weight
[
idx
]
loss
=
{
key
:
loss
[
key
]
*
weight
for
key
in
loss
}
loss_dict
.
update
(
loss
)
loss_dict
[
"loss"
]
=
paddle
.
add_n
(
list
(
loss_dict
.
values
()))
return
loss_dict
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录