Commit 8f6106cb authored by littletomatodonkey, committed by GitHub

add industrial code and doc (#5493)

* add industrial code and doc
Parent 1df0f648
# Industrial-Grade Model Development Tutorial
PaddlePaddle (飞桨) is an open-source deep learning platform born from industrial practice, dedicated to making it easier to innovate with and apply deep learning. Developing an industrial-grade model involves the three steps below.
<div align="center">
<img src="images/intrstrial_sota_model_pipeline.png" width = "800" />
</div>
Specifically:
* For the paper reproduction workflow and method, see the [Paper Reproduction Guide](./article-implementation/ArticleReproduction_CV.md)
* For optimizing model speed and accuracy, see the [Industrial-Grade SOTA Model Optimization Guide](./pp-series/README.md)
* For developing and testing the full training-and-inference pipeline (TIPC), see the [PaddlePaddle TIPC development documentation](./tipc/README.md)
# [Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019)](https://arxiv.org/abs/1902.09212)
## 1 Introduction
This is the PaddlePaddle implementation of [Deep High-Resolution Representation Learning for Human Pose Estimation](https://arxiv.org/abs/1902.09212).
In this work, we are interested in the human pose estimation problem with a focus on learning reliable high-resolution representations. Most existing methods recover high-resolution representations from low-resolution representations produced by a high-to-low resolution network. Instead, our proposed network maintains high-resolution representations through the whole process. We start from a high-resolution subnetwork as the first stage, gradually add high-to-low resolution subnetworks one by one to form more stages, and connect the multi-resolution subnetworks in parallel. We conduct repeated multi-scale fusions such that each of the high-to-low resolution representations receives information from other parallel representations over and over, leading to rich high-resolution representations. As a result, the predicted keypoint heatmap is potentially more accurate and spatially more precise. We empirically demonstrate the effectiveness of our network through superior pose estimation results on two benchmark datasets: the COCO keypoint detection dataset and the MPII Human Pose dataset.
## 2 How to use
### 2.1 Environment
### Requirements:
- PaddlePaddle 2.2
- OS 64 bit
- Python 3 (3.5.1+/3.6/3.7/3.8/3.9), 64 bit
- pip/pip3(9.0.1+), 64 bit
- CUDA >= 10.1
- cuDNN >= 7.6
### Installation
#### 1. Install PaddlePaddle
```
# CUDA10.1
python -m pip install paddlepaddle-gpu==2.2.0.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
- For quick installation with other CUDA versions or environments, please refer to the [PaddlePaddle Quick Installation document](https://www.paddlepaddle.org.cn/install/quick)
- For other installation methods, such as conda or compiling from source, please refer to the [installation document](https://www.paddlepaddle.org.cn/documentation/docs/en/install/index_en.html)
Please make sure that your PaddlePaddle is installed successfully and the version is not lower than the required version. Use the following command to verify.
```
# check (run inside a Python interpreter)
>>> import paddle
>>> paddle.utils.run_check()
# confirm the installed paddle version (run in a shell)
python -c "import paddle; print(paddle.__version__)"
```
**Note**
1. If you want to use PaddleDetection on multiple GPUs, please install NCCL first.
#### 2. Clone this repo; we'll refer to the cloned directory as ${POSE_ROOT}.
#### 3. Install dependencies:
```
pip install -r requirements.txt
```
#### 4. Initialize the output (training model output) and log (tensorboard log) directories:
```
mkdir output
mkdir log
```
Your directory tree should look like this:
```
${POSE_ROOT}
├── config
├── dataset
├── figures
├── lib
├── log
├── output
├── tools
├── README.md
└── requirements.txt
```
### 2.2 Data preparation
#### COCO Data Download
- The COCO dataset can be downloaded automatically by the script below. The dataset is large, so the download takes a while.
```
# automatically download coco datasets by executing code
python dataset/download_coco.py
```
After the script finishes, the COCO dataset files are organized as follows:
```
>>cd dataset
>>tree
├── annotations
│ ├── instances_train2017.json
│ ├── instances_val2017.json
│ | ...
├── train2017
│ ├── 000000000009.jpg
│ ├── 000000580008.jpg
│ | ...
├── val2017
│ ├── 000000000139.jpg
│ ├── 000000000285.jpg
│ | ...
| ...
```
- If the COCO dataset has already been downloaded, the files can be organized according to the structure shown above; a quick layout check is sketched below.
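This minimal sketch (not part of this repo) assumes the `dataset/coco` layout and the `person_keypoints_*` annotation files referenced by the config files; adjust `dataset_dir` if your copy lives elsewhere.
```python
# Check that a pre-downloaded COCO copy matches the layout the configs expect.
import os

dataset_dir = "dataset/coco"  # assumed: matches dataset_dir in the configs
expected = [
    "annotations/person_keypoints_train2017.json",
    "annotations/person_keypoints_val2017.json",
    "train2017",
    "val2017",
]
for rel in expected:
    status = "OK     " if os.path.exists(os.path.join(dataset_dir, rel)) else "MISSING"
    print(status, os.path.join(dataset_dir, rel))
```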
### 2.3 Training & Evaluation & Inference
We provide scripts for training, evaluation, and inference, with various features selected via the configuration files.
```bash
# training on single-GPU
export CUDA_VISIBLE_DEVICES=0
python tools/train.py -c configs/hrnet_w32_256x192.yml
# training on multi-GPU
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/hrnet_w32_256x192.yml
# GPU evaluation
export CUDA_VISIBLE_DEVICES=0
python tools/eval.py -c configs/hrnet_w32_256x192.yml -o weights=https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_256x192.pdparams
# Inference
python tools/infer.py -c configs/hrnet_w32_256x192.yml --infer_img=dataset/test_image/hrnet_demo.jpg -o weights=https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_256x192.pdparams
# training with distillation
python tools/train.py -c configs/lite_hrnet_30_256x192_coco.yml --distill_config=./configs/hrnet_w32_256x192_teacher.yml
# training with PACT quantization on single-GPU
export CUDA_VISIBLE_DEVICES=0
python tools/train.py -c configs/lite_hrnet_30_256x192_coco_pact.yml
# training with PACT quantization on multi-GPU
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/lite_hrnet_30_256x192_coco_pact.yml
# GPU evaluation with PACT quantization
export CUDA_VISIBLE_DEVICES=0
python tools/eval.py -c configs/lite_hrnet_30_256x192_coco_pact.yml -o weights=https://paddledet.bj.bcebos.com/models/keypoint/lite_hrnet_30_256x192_coco_pact.pdparams
# Inference with PACT quantization
python tools/infer.py -c configs/lite_hrnet_30_256x192_coco_pact.yml \
    --infer_img=dataset/test_image/hrnet_demo.jpg -o weights=https://paddledet.bj.bcebos.com/models/keypoint/lite_hrnet_30_256x192_coco_pact.pdparams
```
## 3 Result
COCO Dataset
| Model | Input Size | AP(coco val) | Model Download | Config File |
| :---------------- | -------- | :----------: | :----------------------------------------------------------: | ----------------------------------------------------------- |
| HRNet-w32 | 256x192 | 76.9 | [hrnet_w32_256x192.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/hrnet_w32_256x192.pdparams) | [config](./configs/hrnet_w32_256x192.yml) |
| LiteHRNet-30 | 256x192 | 69.4 | [lite_hrnet_30_256x192_coco.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco.yml) |
| LiteHRNet-30-PACT | 256x192 | 68.9 | [lite_hrnet_30_256x192_coco_pact.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco_pact.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco_pact.yml) |
| LiteHRNet-30-DIST | 256x192  |     69.9     | [lite_hrnet_30_256x192_coco_dist.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco_dist.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco.yml) |
![](/dataset/test_image/hrnet_demo.jpg)
![](/deploy/output/hrnet_demo_vis.jpg)
## Citation
```
@inproceedings{sun2019deep,
  title={Deep High-Resolution Representation Learning for Human Pose Estimation},
  author={Ke Sun and Bin Xiao and Dong Liu and Jingdong Wang},
  booktitle={CVPR},
  year={2019}
}
```
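# ===== file boundary (inferred from content): configs/hrnet_w32_256x192.yml =====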
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/hrnet_w32_256x192/model_final
epoch: 210
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 192
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
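# Note: the &name / *name pairs above are standard YAML anchors and aliases,
# so values defined once here (e.g. trainsize, flip_perm) are reused below.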
#####model
architecture: TopDownHRNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams
TopDownHRNet:
backbone: HRNet
post_process: HRNetPostProcess
flip_perm: *flip_perm
num_joints: *num_joints
width: &width 32
loss: KeyPointMSELoss
use_dark: False
HRNet:
width: *width
freeze_at: -1
freeze_norm: false
return_idx: [0]
KeyPointMSELoss:
use_target_weight: true
#####optimizer
LearningRate:
base_lr: 0.0005
schedulers:
- !PiecewiseDecay
milestones: [170, 200]
gamma: 0.1
- !LinearWarmup
start_factor: 0.001
steps: 1000
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2
#####data
TrainDataset:
!KeypointTopDownCocoDataset
image_dir: train2017
anno_path: annotations/person_keypoints_train2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
EvalDataset:
!KeypointTopDownCocoDataset
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
bbox_file: bbox.json
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
image_thre: 0.0
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 2
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- RandomFlipHalfBodyTransform:
scale: 0.5
rot: 40
num_joints_half_body: 8
prob_half_body: 0.3
pixel_std: *pixel_std
trainsize: *trainsize
upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
flip_pairs: *flip_perm
- TopDownAffine:
trainsize: *trainsize
- ToHeatmapsTopDown:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 64
shuffle: true
drop_last: false
EvalReader:
sample_transforms:
- TopDownAffine:
trainsize: *trainsize
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 16
TestReader:
inputs_def:
image_shape: [3, *train_height, *train_width]
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
fuse_normalize: false # whether to fuse the normalize layer into the model when exporting
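# ===== file boundary (inferred): configs/hrnet_w32_256x192_teacher.yml, the teacher config used by the distillation command above =====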
pretrain_weights:
weights: "https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_256x192.pdparams"
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
train_height: &train_height 256
train_width: &train_width 192
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
# distillation config and loss
freeze_parameters: True
distill_loss:
name: DistMSELoss
weight: 1.0
key: output
# model
architecture: TopDownHRNet
TopDownHRNet:
backbone: HRNet
post_process: HRNetPostProcess
flip_perm: *flip_perm
num_joints: *num_joints
width: &width 32
loss: KeyPointMSELoss
use_dark: False
HRNet:
width: *width
freeze_at: -1
freeze_norm: false
return_idx: [0]
KeyPointMSELoss:
use_target_weight: true
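# ===== file boundary (inferred): configs/lite_hrnet_30_256x192_coco.yml =====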
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/lite_hrnet_30_256x192_coco/model_final
epoch: 210
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 192
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
#####model
architecture: TopDownHRNet
TopDownHRNet:
backbone: LiteHRNet
post_process: HRNetPostProcess
flip_perm: *flip_perm
num_joints: *num_joints
width: &width 40
loss: KeyPointMSELoss
use_dark: false
LiteHRNet:
network_type: lite_30
freeze_at: -1
freeze_norm: false
return_idx: [0]
KeyPointMSELoss:
use_target_weight: true
loss_scale: 1.0
#####optimizer
LearningRate:
base_lr: 0.002
schedulers:
- !PiecewiseDecay
milestones: [170, 200]
gamma: 0.1
- !LinearWarmup
start_factor: 0.001
steps: 500
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2
#####data
TrainDataset:
!KeypointTopDownCocoDataset
image_dir: train2017
anno_path: annotations/person_keypoints_train2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
EvalDataset:
!KeypointTopDownCocoDataset
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
image_thre: 0.0
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 4
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- RandomFlipHalfBodyTransform:
scale: 0.25
rot: 30
num_joints_half_body: 8
prob_half_body: 0.3
pixel_std: *pixel_std
trainsize: *trainsize
upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
flip_pairs: *flip_perm
- TopDownAffine:
trainsize: *trainsize
- ToHeatmapsTopDown:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 64
shuffle: true
drop_last: false
EvalReader:
sample_transforms:
- TopDownAffine:
trainsize: *trainsize
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 16
TestReader:
inputs_def:
image_shape: [3, *train_height, *train_width]
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
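# ===== file boundary (inferred): configs/lite_hrnet_30_256x192_coco_pact.yml (PACT quantization-aware training) =====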
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/lite_hrnet_30_256x192_coco/model_final
epoch: 50
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 192
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/keypoint/lite_hrnet_30_256x192_coco.pdparams
slim: QAT
QAT:
quant_config: {
'activation_preprocess_type': 'PACT',
'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max',
'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9,
'quantizable_layer_type': ['Conv2D', 'Linear']}
print_model: True
architecture: TopDownHRNet
TopDownHRNet:
backbone: LiteHRNet
post_process: HRNetPostProcess
flip_perm: *flip_perm
num_joints: *num_joints
width: &width 40
loss: KeyPointMSELoss
use_dark: false
LiteHRNet:
network_type: lite_30
freeze_at: -1
freeze_norm: false
return_idx: [0]
KeyPointMSELoss:
use_target_weight: true
loss_scale: 1.0
# optimizer
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
milestones: [40, 45]
gamma: 0.1
- !LinearWarmup
start_factor: 0.001
steps: 500
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2
#####data
TrainDataset:
!KeypointTopDownCocoDataset
image_dir: train2017
anno_path: annotations/person_keypoints_train2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
EvalDataset:
!KeypointTopDownCocoDataset
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
image_thre: 0.0
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 4
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- RandomFlipHalfBodyTransform:
scale: 0.25
rot: 30
num_joints_half_body: 8
prob_half_body: 0.3
pixel_std: *pixel_std
trainsize: *trainsize
upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
flip_pairs: *flip_perm
- TopDownAffine:
trainsize: *trainsize
- ToHeatmapsTopDown:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 64
shuffle: true
drop_last: false
EvalReader:
sample_transforms:
- TopDownAffine:
trainsize: *trainsize
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 16
TestReader:
inputs_def:
image_shape: [3, *train_height, *train_width]
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
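# ===== file boundary (inferred): dataset/download_coco.py, the download script referenced in section 2.2 =====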
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os.path as osp
import logging
# add python path of PadleDetection to sys.path
parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3)))
if parent_path not in sys.path:
sys.path.append(parent_path)
from ppdet.utils.download import download_dataset
logging.basicConfig(level=logging.INFO)
download_path = osp.split(osp.realpath(sys.argv[0]))[0]
download_dataset(download_path, 'coco')
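# ===== file boundary (inferred): benchmark_utils.py, imported by the deploy inference script below =====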
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import paddle
import paddle.inference as paddle_infer
from pathlib import Path
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
LOG_PATH_ROOT = f"{CUR_DIR}/../../output"
class PaddleInferBenchmark(object):
def __init__(self,
config,
model_info: dict={},
data_info: dict={},
perf_info: dict={},
resource_info: dict={},
**kwargs):
"""
Construct PaddleInferBenchmark Class to format logs.
args:
config(paddle.inference.Config): paddle inference config
            model_info(dict): basic model info
                {'model_name': 'resnet50',
                 'precision': 'fp32'}
            data_info(dict): input data info
                {'batch_size': 1,
                 'shape': '3,224,224',
                 'data_num': 1000}
            perf_info(dict): performance result
                {'preprocess_time_s': 1.0,
                 'inference_time_s': 2.0,
                 'postprocess_time_s': 1.0,
                 'total_time_s': 4.0}
            resource_info(dict): cpu and gpu resources
                {'cpu_rss': 100,
                 'gpu_rss': 100,
                 'gpu_util': 60}
"""
# PaddleInferBenchmark Log Version
self.log_version = "1.0.3"
# Paddle Version
self.paddle_version = paddle.__version__
self.paddle_commit = paddle.__git_commit__
paddle_infer_info = paddle_infer.get_version()
self.paddle_branch = paddle_infer_info.strip().split(': ')[-1]
# model info
self.model_info = model_info
# data info
self.data_info = data_info
# perf info
self.perf_info = perf_info
try:
# required value
self.model_name = model_info['model_name']
self.precision = model_info['precision']
self.batch_size = data_info['batch_size']
self.shape = data_info['shape']
self.data_num = data_info['data_num']
self.inference_time_s = round(perf_info['inference_time_s'], 4)
except:
self.print_help()
raise ValueError(
"Set argument wrong, please check input argument and its type")
self.preprocess_time_s = perf_info.get('preprocess_time_s', 0)
self.postprocess_time_s = perf_info.get('postprocess_time_s', 0)
self.total_time_s = perf_info.get('total_time_s', 0)
self.inference_time_s_90 = perf_info.get("inference_time_s_90", "")
self.inference_time_s_99 = perf_info.get("inference_time_s_99", "")
self.succ_rate = perf_info.get("succ_rate", "")
self.qps = perf_info.get("qps", "")
# conf info
self.config_status = self.parse_config(config)
# mem info
if isinstance(resource_info, dict):
self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0))
self.cpu_vms_mb = int(resource_info.get('cpu_vms_mb', 0))
self.cpu_shared_mb = int(resource_info.get('cpu_shared_mb', 0))
self.cpu_dirty_mb = int(resource_info.get('cpu_dirty_mb', 0))
self.cpu_util = round(resource_info.get('cpu_util', 0), 2)
self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0))
self.gpu_util = round(resource_info.get('gpu_util', 0), 2)
self.gpu_mem_util = round(resource_info.get('gpu_mem_util', 0), 2)
else:
self.cpu_rss_mb = 0
self.cpu_vms_mb = 0
self.cpu_shared_mb = 0
self.cpu_dirty_mb = 0
self.cpu_util = 0
self.gpu_rss_mb = 0
self.gpu_util = 0
self.gpu_mem_util = 0
# init benchmark logger
self.benchmark_logger()
def benchmark_logger(self):
"""
benchmark logger
"""
# remove other logging handler
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
# Init logger
FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
log_output = f"{LOG_PATH_ROOT}/{self.model_name}.log"
Path(f"{LOG_PATH_ROOT}").mkdir(parents=True, exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format=FORMAT,
handlers=[
logging.FileHandler(
filename=log_output, mode='w'),
logging.StreamHandler(),
])
self.logger = logging.getLogger(__name__)
self.logger.info(
f"Paddle Inference benchmark log will be saved to {log_output}")
def parse_config(self, config) -> dict:
"""
parse paddle predictor config
args:
config(paddle.inference.Config): paddle inference config
return:
config_status(dict): dict style config info
"""
        config_status = {}
        if isinstance(config, paddle_infer.Config):
config_status['runtime_device'] = "gpu" if config.use_gpu(
) else "cpu"
config_status['ir_optim'] = config.ir_optim()
config_status['enable_tensorrt'] = config.tensorrt_engine_enabled()
config_status['precision'] = self.precision
config_status['enable_mkldnn'] = config.mkldnn_enabled()
config_status[
'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads(
)
elif isinstance(config, dict):
config_status['runtime_device'] = config.get('runtime_device', "")
config_status['ir_optim'] = config.get('ir_optim', "")
config_status['enable_tensorrt'] = config.get('enable_tensorrt',
"")
config_status['precision'] = config.get('precision', "")
config_status['enable_mkldnn'] = config.get('enable_mkldnn', "")
config_status['cpu_math_library_num_threads'] = config.get(
'cpu_math_library_num_threads', "")
else:
self.print_help()
raise ValueError(
"Set argument config wrong, please check input argument and its type"
)
return config_status
def report(self, identifier=None):
"""
print log report
args:
identifier(string): identify log
"""
if identifier:
identifier = f"[{identifier}]"
else:
identifier = ""
self.logger.info("\n")
self.logger.info(
"---------------------- Paddle info ----------------------")
self.logger.info(f"{identifier} paddle_version: {self.paddle_version}")
self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}")
self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}")
self.logger.info(f"{identifier} log_api_version: {self.log_version}")
self.logger.info(
"----------------------- Conf info -----------------------")
self.logger.info(
f"{identifier} runtime_device: {self.config_status['runtime_device']}"
)
self.logger.info(
f"{identifier} ir_optim: {self.config_status['ir_optim']}")
self.logger.info(f"{identifier} enable_memory_optim: {True}")
self.logger.info(
f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}"
)
self.logger.info(
f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}"
)
self.logger.info(
f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}"
)
self.logger.info(
"----------------------- Model info ----------------------")
self.logger.info(f"{identifier} model_name: {self.model_name}")
self.logger.info(f"{identifier} precision: {self.precision}")
self.logger.info(
"----------------------- Data info -----------------------")
self.logger.info(f"{identifier} batch_size: {self.batch_size}")
self.logger.info(f"{identifier} input_shape: {self.shape}")
self.logger.info(f"{identifier} data_num: {self.data_num}")
self.logger.info(
"----------------------- Perf info -----------------------")
self.logger.info(
f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, cpu_vms: {self.cpu_vms_mb}, cpu_shared_mb: {self.cpu_shared_mb}, cpu_dirty_mb: {self.cpu_dirty_mb}, cpu_util: {self.cpu_util}%"
)
self.logger.info(
f"{identifier} gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%, gpu_mem_util: {self.gpu_mem_util}%"
)
self.logger.info(
f"{identifier} total time spent(s): {self.total_time_s}")
self.logger.info(
f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
)
if self.inference_time_s_90:
            self.logger.info(
f"{identifier} 90%_cost: {self.inference_time_s_90}, 99%_cost: {self.inference_time_s_99}, succ_rate: {self.succ_rate}"
)
if self.qps:
self.logger.info(f"{identifier} QPS: {self.qps}")
def print_help(self):
"""
print function help
"""
print("""Usage:
==== Print inference benchmark logs. ====
config = paddle.inference.Config()
        model_info = {'model_name': 'resnet50',
                      'precision': 'fp32'}
        data_info = {'batch_size': 1,
                     'shape': '3,224,224',
                     'data_num': 1000}
        perf_info = {'preprocess_time_s': 1.0,
                     'inference_time_s': 2.0,
                     'postprocess_time_s': 1.0,
                     'total_time_s': 4.0}
        resource_info = {'cpu_rss_mb': 100,
                         'gpu_rss_mb': 100,
                         'gpu_util': 60}
log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
log('Test')
""")
def __call__(self, identifier=None):
"""
__call__
args:
identifier(string): identify log
"""
self.report(identifier)
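# ===== file boundary (inferred): the deploy inference script (likely deploy/infer.py) =====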
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import glob
from functools import reduce
import cv2
import numpy as np
import math
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from benchmark_utils import PaddleInferBenchmark
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, WarpAffine, TopDownEvalAffine, expand_crop
from postprocess import HRNetPostProcess
from visualize import draw_pose
from utils import argsparser, Timer, get_current_memory_mb
class Detector(object):
"""
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size for inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): if the model is produced by TRT offline
            quantization calibration, trt_calib_mode needs to be set to True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self,
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
use_dark=True):
self.pred_config = pred_config
self.predictor, self.config = load_predictor(
model_dir,
run_mode=run_mode,
batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
self.use_dark = use_dark
def preprocess(self, image_list):
preprocess_ops = []
for op_info in self.pred_config.preprocess_infos:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
input_im_lst = []
input_im_info_lst = []
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
input_im_lst.append(im)
input_im_info_lst.append(im_info)
inputs = create_inputs(input_im_lst, input_im_info_lst)
return inputs
def postprocess(self, np_boxes, inputs, threshold=0.5):
# postprocess output of predictor
results = {}
imshape = inputs['im_shape'][:, ::-1]
center = np.round(imshape / 2.)
scale = imshape / 200.
postprocess = HRNetPostProcess(use_dark=self.use_dark)
results['keypoint'] = postprocess(np_boxes, center, scale)
return results
def predict(self, image_list, threshold=0.5, repeats=1, add_timer=True):
'''
Args:
            image_list (list): list of input images
            threshold (float): threshold for the predicted keypoint scores
            repeats (int): repeat number for prediction
            add_timer (bool): whether to time the prediction stages
        Returns:
            results (dict): contains 'keypoint': (kpts, scores), the decoded
                keypoint coordinates and confidences from HRNetPostProcess
        '''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list)
np_boxes = None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if add_timer:
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
# postprocess
results = self.postprocess(np_boxes, inputs, threshold=threshold)
if add_timer:
self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(image_list)
return results
def get_timer(self):
return self.det_times
def create_inputs(imgs, im_info):
"""generate input for different model type
Args:
imgs (list(numpy)): list of images (np.ndarray)
im_info (list(dict)): list of image info
Returns:
inputs (dict): input of model
"""
inputs = {}
inputs['image'] = np.stack(imgs, axis=0)
im_shape = []
for e in im_info:
im_shape.append(np.array((e['im_shape'])).astype('float32'))
inputs['im_shape'] = np.stack(im_shape, axis=0)
return inputs
class PredictConfig():
"""set config of preprocess, postprocess and visualize
Args:
model_dir (str): root path of model.yml
"""
def __init__(self, model_dir):
# parsing Yaml config for Preprocess
deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.labels = yml_conf['label_list']
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
self.print_config()
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def load_predictor(model_dir,
run_mode='paddle',
batch_size=1,
device='CPU',
min_subgraph_size=3,
use_dynamic_shape=False,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): if the model is produced by TRT offline
            quantization calibration, trt_calib_mode needs to be set to True
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
ValueError: predict by TensorRT need device == 'GPU'.
"""
if device != 'GPU' and run_mode != 'paddle':
raise ValueError(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
.format(run_mode, device))
config = Config(
os.path.join(model_dir, 'model.pdmodel'),
os.path.join(model_dir, 'model.pdiparams'))
if device == 'GPU':
# initial GPU memory(M), device ID
config.enable_use_gpu(200, 0)
# optimize graph and fuse op
config.switch_ir_optim(True)
elif device == 'XPU':
config.enable_lite_engine()
config.enable_xpu(10 * 1024 * 1024)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads)
if enable_mkldnn:
try:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
except Exception as e:
print(
"The current environment does not support `mkldnn`, so disable mkldnn."
)
pass
precision_map = {
'trt_int8': Config.Precision.Int8,
'trt_fp32': Config.Precision.Float32,
'trt_fp16': Config.Precision.Half
}
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=1 << 25,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
use_static=False,
use_calib_mode=trt_calib_mode)
if use_dynamic_shape:
min_input_shape = {
'image': [batch_size, 3, trt_min_shape, trt_min_shape]
}
max_input_shape = {
'image': [batch_size, 3, trt_max_shape, trt_max_shape]
}
opt_input_shape = {
'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
opt_input_shape)
print('trt set dynamic shape done!')
# disable print log when predict
config.disable_glog_info()
# enable shared memory
config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False)
predictor = create_predictor(config)
return predictor, config
def get_test_images(infer_dir, infer_img):
"""
Get image path list in TEST mode
"""
assert infer_img is not None or infer_dir is not None, \
"--infer_img or --infer_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), \
"{} is not a file".format(infer_img)
assert infer_dir is None or os.path.isdir(infer_dir), \
"{} is not a directory".format(infer_dir)
# infer_img has a higher priority
if infer_img and os.path.isfile(infer_img):
return [infer_img]
images = set()
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir)
exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts]
for ext in exts:
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
images = list(images)
assert len(images) > 0, "no image found in {}".format(infer_dir)
print("Found {} inference images in total.".format(len(images)))
return images
def print_arguments(args):
print('----------- Running Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------')
def predict_image(detector, image_list, batch_size=1):
for i, img_file in enumerate(image_list):
if FLAGS.run_benchmark:
# warmup
detector.predict(
image_list, FLAGS.threshold, repeats=10, add_timer=False)
# run benchmark
detector.predict(
image_list, FLAGS.threshold, repeats=10, add_timer=True)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}'.format(i))
else:
results = detector.predict(image_list, FLAGS.threshold)
draw_pose(
img_file,
results,
visual_thread=FLAGS.threshold,
save_dir=FLAGS.output_dir)
def predict_video(detector, camera_id):
video_out_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_out_name = os.path.split(FLAGS.video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
    while True:
ret, frame = capture.read()
if not ret:
break
print('detect frame: %d' % (index))
index += 1
results = detector.predict([frame], FLAGS.threshold)
im = draw_pose(
frame, results, visual_thread=FLAGS.threshold, returnimg=True)
writer.write(im)
if camera_id != -1:
            cv2.imshow('Pose Estimation', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
def main():
pred_config = PredictConfig(FLAGS.model_dir)
detector = Detector(
pred_config,
FLAGS.model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
use_dark=FLAGS.use_dark)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
else:
mems = {
'cpu_rss_mb': detector.cpu_mem / len(img_list),
'gpu_rss_mb': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
perf_info = detector.det_times.report(average=True)
model_dir = FLAGS.model_dir
mode = FLAGS.run_mode
model_info = {
'model_name': model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
data_info = {
'batch_size': 1,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
det_log = PaddleInferBenchmark(detector.config, model_info,
data_info, perf_info, mems)
det_log('Det')
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
print_arguments(FLAGS)
FLAGS.device = FLAGS.device.upper()
assert FLAGS.device in ['CPU', 'GPU', 'XPU'
], "device should be CPU, GPU or XPU"
assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
main()
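# ===== file boundary (inferred): a ppdet-style logger module (likely lib/utils/logger.py) =====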
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import os
import sys
import paddle.distributed as dist
__all__ = ['setup_logger']
logger_initialized = []
def setup_logger(name="ppdet", output=None):
"""
Initialize logger and set its verbosity level to INFO.
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
Returns:
logging.Logger: a logger
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
logger.setLevel(logging.INFO)
logger.propagate = False
formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s",
datefmt="%m/%d %H:%M:%S")
# stdout logging: master only
local_rank = dist.get_rank()
if local_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
ch.setFormatter(formatter)
logger.addHandler(ch)
# file logging: all workers
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if local_rank > 0:
filename = filename + ".rank{}".format(local_rank)
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
fh = logging.FileHandler(filename, mode='a')
fh.setLevel(logging.DEBUG)
fh.setFormatter(logging.Formatter())
logger.addHandler(fh)
logger_initialized.append(name)
return logger
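# ===== file boundary (inferred): postprocess.py, providing HRNetPostProcess for deployment =====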
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from scipy.optimize import linear_sum_assignment
from collections import abc, defaultdict
import cv2
import numpy as np
import math
import paddle
import paddle.nn as nn
from preprocess import get_affine_mat_kernel, get_affine_transform
class HRNetPostProcess(object):
def __init__(self, use_dark=True):
self.use_dark = use_dark
def flip_back(self, output_flipped, matched_parts):
assert output_flipped.ndim == 4,\
'output_flipped should be [batch_size, num_joints, height, width]'
output_flipped = output_flipped[:, :, :, ::-1]
for pair in matched_parts:
tmp = output_flipped[:, pair[0], :, :].copy()
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
output_flipped[:, pair[1], :, :] = tmp
return output_flipped
def get_max_preds(self, heatmaps):
"""get predictions from score maps
Args:
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
            maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
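        Example (illustrative): for a 1x1x4x4 heatmap whose peak sits at
            row 2, col 1, the flattened argmax index is 2 * 4 + 1 = 9, so
            x = 9 % 4 = 1 and y = floor(9 / 4) = 2, giving preds = [[[1., 2.]]].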
"""
assert isinstance(heatmaps,
np.ndarray), 'heatmaps should be numpy.ndarray'
assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'
batch_size = heatmaps.shape[0]
num_joints = heatmaps.shape[1]
width = heatmaps.shape[3]
heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1))
idx = np.argmax(heatmaps_reshaped, 2)
maxvals = np.amax(heatmaps_reshaped, 2)
maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
preds[:, :, 0] = (preds[:, :, 0]) % width
preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
pred_mask = pred_mask.astype(np.float32)
preds *= pred_mask
return preds, maxvals
def gaussian_blur(self, heatmap, kernel):
border = (kernel - 1) // 2
batch_size = heatmap.shape[0]
num_joints = heatmap.shape[1]
height = heatmap.shape[2]
width = heatmap.shape[3]
for i in range(batch_size):
for j in range(num_joints):
origin_max = np.max(heatmap[i, j])
dr = np.zeros((height + 2 * border, width + 2 * border))
dr[border:-border, border:-border] = heatmap[i, j].copy()
dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
heatmap[i, j] = dr[border:-border, border:-border].copy()
heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
return heatmap
def dark_parse(self, hm, coord):
heatmap_height = hm.shape[0]
heatmap_width = hm.shape[1]
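        # DARK-style refinement: approximate the log-heatmap around the argmax
        # with a second-order Taylor expansion and take one Newton step,
        # offset = -Hessian^{-1} * gradient, using the finite differences below.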
px = int(coord[0])
py = int(coord[1])
if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
dxy = 0.25 * (hm[py+1][px+1] - hm[py-1][px+1] - hm[py+1][px-1] \
+ hm[py-1][px-1])
            dyy = 0.25 * (
                hm[py + 2][px] - 2 * hm[py][px] + hm[py - 2][px])
derivative = np.matrix([[dx], [dy]])
hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
if dxx * dyy - dxy**2 != 0:
hessianinv = hessian.I
offset = -hessianinv * derivative
offset = np.squeeze(np.array(offset.T), axis=0)
coord += offset
return coord
def dark_postprocess(self, hm, coords, kernelsize):
"""
refer to https://github.com/ilovepose/DarkPose/lib/core/inference.py
"""
hm = self.gaussian_blur(hm, kernelsize)
hm = np.maximum(hm, 1e-10)
hm = np.log(hm)
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
return coords
def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
"""the highest heatvalue location with a quarter offset in the
direction from the highest response to the second highest response.
Args:
heatmaps (numpy.ndarray): The predicted heatmaps
center (numpy.ndarray): The boxes center
scale (numpy.ndarray): The scale factor
Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
"""
coords, maxvals = self.get_max_preds(heatmaps)
heatmap_height = heatmaps.shape[2]
heatmap_width = heatmaps.shape[3]
if self.use_dark:
coords = self.dark_postprocess(heatmaps, coords, kernelsize)
else:
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = heatmaps[n][p]
px = int(math.floor(coords[n][p][0] + 0.5))
py = int(math.floor(coords[n][p][1] + 0.5))
if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
diff = np.array([
hm[py][px + 1] - hm[py][px - 1],
hm[py + 1][px] - hm[py - 1][px]
])
coords[n][p] += np.sign(diff) * .25
preds = coords.copy()
# Transform back
for i in range(coords.shape[0]):
preds[i] = transform_preds(coords[i], center[i], scale[i],
[heatmap_width, heatmap_height])
return preds, maxvals
def __call__(self, output, center, scale):
preds, maxvals = self.get_final_preds(output, center, scale)
return np.concatenate(
(preds, maxvals), axis=-1), np.mean(
maxvals, axis=1)
def transform_preds(coords, center, scale, output_size):
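    # Map heatmap-space coordinates back to the original image by applying the
    # inverse (inv=1) affine crop transform; scale is multiplied by 200 because
    # person boxes are normalized by pixel_std = 200 (see the configs above).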
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale * 200, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def translate_to_ori_images(keypoint_result, batch_records):
kpts, scores = keypoint_result['keypoint']
kpts[..., 0] += batch_records[:, 0:1]
kpts[..., 1] += batch_records[:, 1:2]
return kpts, scores
[One file's diff was collapsed in the original view and is omitted here.]
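# ===== file boundary (inferred): utils.py, providing argsparser, Timer and memory helpers =====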
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import ast
import argparse
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--model_dir",
type=str,
default=None,
help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."),
required=True)
parser.add_argument(
"--image_file", type=str, default=None, help="Path of image file.")
parser.add_argument(
"--image_dir",
type=str,
default=None,
help="Dir of image file, `image_file` has a higher priority.")
parser.add_argument(
"--batch_size", type=int, default=1, help="batch_size for inference.")
parser.add_argument(
"--video_file",
type=str,
default=None,
help="Path of video file, `video_file` or `camera_id` has a highest priority."
)
parser.add_argument(
"--camera_id",
type=int,
default=-1,
help="device id of camera to predict.")
parser.add_argument(
"--threshold", type=float, default=0.5, help="Threshold of score.")
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Directory of output visualization files.")
parser.add_argument(
"--run_mode",
type=str,
default='paddle',
help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."
)
parser.add_argument(
"--use_gpu",
type=ast.literal_eval,
default=False,
help="Deprecated, please use `--device`.")
parser.add_argument(
"--run_benchmark",
type=ast.literal_eval,
default=False,
help="Whether to predict a image_file repeatedly for benchmark")
parser.add_argument(
"--enable_mkldnn",
type=ast.literal_eval,
default=False,
help="Whether use mkldnn with CPU.")
parser.add_argument(
"--cpu_threads", type=int, default=1, help="Num of threads with CPU.")
parser.add_argument(
"--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.")
parser.add_argument(
"--trt_max_shape",
type=int,
default=1280,
help="max_shape for TensorRT.")
parser.add_argument(
"--trt_opt_shape",
type=int,
default=640,
help="opt_shape for TensorRT.")
parser.add_argument(
"--trt_calib_mode",
type=bool,
default=False,
help="If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True.")
parser.add_argument(
'--save_images',
action='store_true',
help='Save visualization image results.')
parser.add_argument(
'--use_dark',
type=bool,
default=True,
        help='Whether to use DARK post-processing for more accurate keypoint predictions.')
return parser
class Times(object):
def __init__(self):
self.time = 0.
# start time
self.st = 0.
# end time
self.et = 0.
def start(self):
self.st = time.time()
def end(self, repeats=1, accumulative=True):
self.et = time.time()
if accumulative:
self.time += (self.et - self.st) / repeats
else:
self.time = (self.et - self.st) / repeats
def reset(self):
self.time = 0.
self.st = 0.
self.et = 0.
def value(self):
return round(self.time, 4)
class Timer(Times):
def __init__(self):
super(Timer, self).__init__()
self.preprocess_time_s = Times()
self.inference_time_s = Times()
self.postprocess_time_s = Times()
self.img_num = 0
def info(self, average=False):
total_time = self.preprocess_time_s.value(
) + self.inference_time_s.value() + self.postprocess_time_s.value()
total_time = round(total_time, 4)
print("------------------ Inference Time Info ----------------------")
print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
self.img_num))
preprocess_time = round(
self.preprocess_time_s.value() / max(1, self.img_num),
4) if average else self.preprocess_time_s.value()
postprocess_time = round(
self.postprocess_time_s.value() / max(1, self.img_num),
4) if average else self.postprocess_time_s.value()
inference_time = round(self.inference_time_s.value() /
max(1, self.img_num),
4) if average else self.inference_time_s.value()
average_latency = total_time / max(1, self.img_num)
qps = 0
if total_time > 0:
qps = 1 / average_latency
print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
average_latency * 1000, qps))
print(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000,
postprocess_time * 1000))
def report(self, average=False):
dic = {}
dic['preprocess_time_s'] = round(
self.preprocess_time_s.value() / max(1, self.img_num),
4) if average else self.preprocess_time_s.value()
dic['postprocess_time_s'] = round(
self.postprocess_time_s.value() / max(1, self.img_num),
4) if average else self.postprocess_time_s.value()
dic['inference_time_s'] = round(
self.inference_time_s.value() / max(1, self.img_num),
4) if average else self.inference_time_s.value()
dic['img_num'] = self.img_num
total_time = self.preprocess_time_s.value(
) + self.inference_time_s.value() + self.postprocess_time_s.value()
dic['total_time_s'] = round(total_time, 4)
return dic
def get_current_memory_mb():
"""
    Obtain the CPU and GPU memory usage of the current process.
    Note that this function is itself time-consuming.
"""
import pynvml
import psutil
import GPUtil
gpu_id = int(os.environ.get('CUDA_VISIBLE_DEVICES', 0))
pid = os.getpid()
p = psutil.Process(pid)
info = p.memory_full_info()
cpu_mem = info.uss / 1024. / 1024.
gpu_mem = 0
gpu_percent = 0
gpus = GPUtil.getGPUs()
if gpu_id is not None and len(gpus) > 0:
gpu_percent = gpus[gpu_id].load
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem = meminfo.used / 1024. / 1024.
return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
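# ===== file boundary (inferred): visualize.py, providing draw_pose =====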
# coding: utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import os
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import math
def get_color(idx):
idx = idx * 3
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
return color
def draw_pose(imgfile,
results,
visual_thread=0.6,
save_name='pose.jpg',
save_dir='output',
returnimg=False,
ids=None):
try:
import matplotlib.pyplot as plt
import matplotlib
plt.switch_backend('agg')
    except Exception as e:
        print('Matplotlib not found, please install matplotlib, '
              'for example: `pip install matplotlib`.')
        raise e
skeletons, scores = results['keypoint']
skeletons = np.array(skeletons)
kpt_nums = 17
if len(skeletons) > 0:
kpt_nums = skeletons.shape[1]
if kpt_nums == 17: #plot coco keypoint
EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7),
(6, 8), (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14),
(13, 15), (14, 16), (11, 12)]
else: #plot mpii keypoint
EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7),
(7, 8), (8, 9), (10, 11), (11, 12), (13, 14), (14, 15),
(8, 12), (8, 13)]
NUM_EDGES = len(EDGES)
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
cmap = matplotlib.cm.get_cmap('hsv')
plt.figure()
img = cv2.imread(imgfile) if type(imgfile) == str else imgfile
color_set = results['colors'] if 'colors' in results else None
if 'bbox' in results and ids is None:
bboxs = results['bbox']
for j, rect in enumerate(bboxs):
xmin, ymin, xmax, ymax = rect
color = colors[0] if color_set is None else colors[color_set[j] %
len(colors)]
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
canvas = img.copy()
for i in range(kpt_nums):
for j in range(len(skeletons)):
if skeletons[j][i, 2] < visual_thread:
continue
if ids is None:
color = colors[i] if color_set is None else colors[color_set[j]
%
len(colors)]
else:
color = get_color(ids[j])
cv2.circle(
canvas,
tuple(skeletons[j][i, 0:2].astype('int32')),
2,
color,
thickness=-1)
to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0)
fig = matplotlib.pyplot.gcf()
stickwidth = 2
for i in range(NUM_EDGES):
for j in range(len(skeletons)):
edge = EDGES[i]
            if skeletons[j][edge[0], 2] < visual_thresh or skeletons[j][edge[
                    1], 2] < visual_thresh:
continue
cur_canvas = canvas.copy()
X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]]
Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]]
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)),
(int(length / 2), stickwidth),
int(angle), 0, 360, 1)
if ids is None:
color = colors[i] if color_set is None else colors[color_set[j]
%
len(colors)]
else:
color = get_color(ids[j])
cv2.fillConvexPoly(cur_canvas, polygon, color)
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
if returnimg:
return canvas
save_name = os.path.join(
save_dir, os.path.splitext(os.path.basename(imgfile))[0] + '_vis.jpg')
plt.imsave(save_name, canvas[:, :, ::-1])
print("keypoint visualize image saved to: " + save_name)
plt.close()
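# A minimal usage sketch ('demo.jpg' is a hypothetical image path; the fake
# result below mimics a single 17-keypoint COCO-style skeleton where each
# keypoint row is (x, y, score)):
if __name__ == '__main__':
    fake_skeletons = np.random.rand(1, 17, 3)
    fake_skeletons[:, :, 0] *= 640  # x coordinates
    fake_skeletons[:, :, 1] *= 480  # y coordinates
    fake_results = {'keypoint': (fake_skeletons, [0.9])}
    draw_pose('demo.jpg', fake_results, visual_thresh=0.1)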
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import lib.utils
import lib.models
import lib.metrics
import lib.dataset
import lib.core
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import callbacks
from . import optimizer
from . import trainer
from .callbacks import *
from .optimizer import *
from .trainer import *
__all__ = callbacks.__all__ \
+ optimizer.__all__ + trainer.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import datetime
import six
import copy
import json
import paddle
import paddle.distributed as dist
from lib.utils.checkpoint import save_model
from lib.metrics.coco_utils import get_infer_results
from lib.utils.logger import setup_logger
logger = setup_logger('hrnet')
__all__ = [
'Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer',
'VisualDLWriter'
]
class Callback(object):
def __init__(self, model):
self.model = model
def on_step_begin(self, status):
pass
def on_step_end(self, status):
pass
def on_epoch_begin(self, status):
pass
def on_epoch_end(self, status):
pass
def on_train_begin(self, status):
pass
def on_train_end(self, status):
pass
class ComposeCallback(object):
def __init__(self, callbacks):
callbacks = [c for c in list(callbacks) if c is not None]
for c in callbacks:
assert isinstance(
c, Callback), "callback should be subclass of Callback"
self._callbacks = callbacks
def on_step_begin(self, status):
for c in self._callbacks:
c.on_step_begin(status)
def on_step_end(self, status):
for c in self._callbacks:
c.on_step_end(status)
def on_epoch_begin(self, status):
for c in self._callbacks:
c.on_epoch_begin(status)
def on_epoch_end(self, status):
for c in self._callbacks:
c.on_epoch_end(status)
def on_train_begin(self, status):
for c in self._callbacks:
c.on_train_begin(status)
def on_train_end(self, status):
for c in self._callbacks:
c.on_train_end(status)
class LogPrinter(Callback):
def __init__(self, model):
super(LogPrinter, self).__init__(model)
def on_step_end(self, status):
if dist.get_world_size() < 2 or dist.get_rank() == 0:
mode = status['mode']
if mode == 'train':
epoch_id = status['epoch_id']
step_id = status['step_id']
steps_per_epoch = status['steps_per_epoch']
                training_status = status['training_status']
batch_time = status['batch_time']
data_time = status['data_time']
epoches = self.model.cfg.epoch
batch_size = self.model.cfg['{}Reader'.format(mode.capitalize(
))]['batch_size']
                logs = training_status.log()
space_fmt = ':' + str(len(str(steps_per_epoch))) + 'd'
if step_id % self.model.cfg.log_iter == 0:
eta_steps = (epoches - epoch_id
) * steps_per_epoch - step_id
eta_sec = eta_steps * batch_time.global_avg
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
ips = float(batch_size) / batch_time.avg
fmt = ' '.join([
'Epoch: [{}]',
'[{' + space_fmt + '}/{}]',
'learning_rate: {lr:.6f}',
'{meters}',
'eta: {eta}',
'batch_cost: {btime}',
'data_cost: {dtime}',
'ips: {ips:.4f} images/s',
])
fmt = fmt.format(
epoch_id,
step_id,
steps_per_epoch,
lr=status['learning_rate'],
meters=logs,
eta=eta_str,
btime=str(batch_time),
dtime=str(data_time),
ips=ips)
logger.info(fmt)
if mode == 'eval':
step_id = status['step_id']
if step_id % 100 == 0:
logger.info("Eval iter: {}".format(step_id))
def on_epoch_end(self, status):
if dist.get_world_size() < 2 or dist.get_rank() == 0:
mode = status['mode']
if mode == 'eval':
sample_num = status['sample_num']
cost_time = status['cost_time']
                logger.info('Total sample number: {}, average FPS: {}'.format(
sample_num, sample_num / cost_time))
class Checkpointer(Callback):
def __init__(self, model):
super(Checkpointer, self).__init__(model)
cfg = self.model.cfg
self.best_ap = 0.
self.save_dir = os.path.join(self.model.cfg.save_dir,
self.model.cfg.filename)
if hasattr(self.model.model, 'student_model'):
self.weight = self.model.model.student_model
else:
self.weight = self.model.model
def on_epoch_end(self, status):
        # Checkpointer is only active during training
mode = status['mode']
epoch_id = status['epoch_id']
weight = None
save_name = None
if dist.get_world_size() < 2 or dist.get_rank() == 0:
if mode == 'train':
end_epoch = self.model.cfg.epoch
if (
epoch_id + 1
) % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
save_name = str(
epoch_id
) if epoch_id != end_epoch - 1 else "model_final"
weight = self.weight
elif mode == 'eval':
if 'save_best_model' in status and status['save_best_model']:
for metric in self.model._metrics:
map_res = metric.get_results()
if 'bbox' in map_res:
key = 'bbox'
elif 'keypoint' in map_res:
key = 'keypoint'
else:
key = 'mask'
if key not in map_res:
logger.warning("Evaluation results empty, this may be due to " \
"training iterations being too few or not " \
"loading the correct weights.")
return
if map_res[key][0] > self.best_ap:
self.best_ap = map_res[key][0]
save_name = 'best_model'
weight = self.weight
logger.info("Best test {} ap is {:0.3f}.".format(
key, self.best_ap))
if weight:
save_model(weight, self.model.optimizer, self.save_dir,
save_name, epoch_id + 1)
class VisualDLWriter(Callback):
"""
Use VisualDL to log data or image
"""
def __init__(self, model):
super(VisualDLWriter, self).__init__(model)
assert six.PY3, "VisualDL requires Python >= 3.5"
try:
from visualdl import LogWriter
except Exception as e:
            logger.error('VisualDL not found, please install visualdl, '
                         'for example: `pip install visualdl`.')
raise e
self.vdl_writer = LogWriter(
model.cfg.get('vdl_log_dir', 'vdl_log_dir/scalar'))
self.vdl_loss_step = 0
self.vdl_mAP_step = 0
self.vdl_image_step = 0
self.vdl_image_frame = 0
def on_step_end(self, status):
mode = status['mode']
if dist.get_world_size() < 2 or dist.get_rank() == 0:
if mode == 'train':
                training_status = status['training_status']
                for loss_name, loss_value in training_status.get().items():
self.vdl_writer.add_scalar(loss_name, loss_value,
self.vdl_loss_step)
self.vdl_loss_step += 1
elif mode == 'test':
ori_image = status['original_image']
result_image = status['result_image']
self.vdl_writer.add_image(
"original/frame_{}".format(self.vdl_image_frame),
ori_image, self.vdl_image_step)
self.vdl_writer.add_image(
"result/frame_{}".format(self.vdl_image_frame),
result_image, self.vdl_image_step)
self.vdl_image_step += 1
# each frame can display ten pictures at most.
if self.vdl_image_step % 10 == 0:
self.vdl_image_step = 0
self.vdl_image_frame += 1
def on_epoch_end(self, status):
mode = status['mode']
if dist.get_world_size() < 2 or dist.get_rank() == 0:
if mode == 'eval':
for metric in self.model._metrics:
for key, map_value in metric.get_results().items():
self.vdl_writer.add_scalar("{}-mAP".format(key),
map_value[0],
self.vdl_mAP_step)
self.vdl_mAP_step += 1
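# A minimal sketch of a user-defined callback (hypothetical; such a callback
# can be attached to a Trainer through its register_callbacks method):
class EpochTimeLogger(Callback):
    """Log the wall-clock time at the end of every epoch."""

    def on_epoch_end(self, status):
        logger.info("epoch {} finished at {}".format(
            status.get('epoch_id'), datetime.datetime.now()))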
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import yaml
from collections import OrderedDict
import paddle
from lib.dataset.category import get_categories
from lib.utils.logger import setup_logger
logger = setup_logger('hrnet')
# Global dictionary
TRT_MIN_SUBGRAPH = {'HRNet': 3, }
def _prune_input_spec(input_spec, program, targets):
# try to prune static program to figure out pruned input spec
# so we perform following operations in static mode
paddle.enable_static()
pruned_input_spec = [{}]
program = program.clone()
program = program._prune(targets=targets)
global_block = program.global_block()
for name, spec in input_spec[0].items():
        try:
            global_block.var(name)  # raises if the var was pruned away
            pruned_input_spec[0][name] = spec
        except Exception:
            pass
paddle.disable_static()
return pruned_input_spec
def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
preprocess_list = []
anno_file = dataset_cfg.get_anno()
clsid2catid, catid2name = get_categories(metric, anno_file, arch)
label_list = [str(cat) for cat in catid2name.values()]
fuse_normalize = reader_cfg.get('fuse_normalize', False)
sample_transforms = reader_cfg['sample_transforms']
for st in sample_transforms[1:]:
for key, value in st.items():
p = {'type': key}
if key == 'Resize':
if int(image_shape[1]) != -1:
value['target_size'] = image_shape[1:]
if fuse_normalize and key == 'NormalizeImage':
continue
p.update(value)
preprocess_list.append(p)
return preprocess_list, label_list
def _parse_tracker(tracker_cfg):
tracker_params = {}
for k, v in tracker_cfg.items():
tracker_params.update({k: v})
return tracker_params
def _dump_infer_config(config, path, image_shape, model):
arch_state = False
from lib.utils.config.yaml_helpers import setup_orderdict
setup_orderdict()
use_dynamic_shape = True if image_shape[2] == -1 else False
infer_cfg = OrderedDict({
'mode': 'fluid',
'draw_threshold': 0.5,
'metric': config['metric'],
'use_dynamic_shape': use_dynamic_shape
})
infer_arch = config['architecture']
for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items():
if arch in infer_arch:
infer_cfg['arch'] = arch
infer_cfg['min_subgraph_size'] = min_subgraph_size
arch_state = True
break
if not arch_state:
logger.error(
'Architecture: {} is not supported for exporting model now.\n'.
format(infer_arch) +
            'Please set TRT_MIN_SUBGRAPH in lib/core/export_utils.py')
os._exit(0)
label_arch = 'keypoint_arch'
reader_cfg = config['TestReader']
dataset_cfg = config['TestDataset']
infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader(
reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:])
    with open(path, 'w') as f:
        yaml.dump(infer_cfg, f)
    logger.info("Export inference config file to {}".format(path))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.nn as nn
import paddle.optimizer as optimizer
import paddle.regularizer as regularizer
from lib.utils.workspace import register, serializable
__all__ = ['LearningRate', 'OptimizerBuilder']
from ..utils.logger import setup_logger
logger = setup_logger(__name__)
@serializable
class PiecewiseDecay(object):
"""
Multi step learning rate decay
Args:
gamma (float | list): decay factor
milestones (list): steps at which to decay learning rate
"""
def __init__(self,
gamma=[0.1, 0.01],
milestones=[8, 11],
values=None,
use_warmup=True):
super(PiecewiseDecay, self).__init__()
        if not isinstance(gamma, list):
self.gamma = []
for i in range(len(milestones)):
self.gamma.append(gamma / 10**i)
else:
self.gamma = gamma
self.milestones = milestones
self.values = values
self.use_warmup = use_warmup
def __call__(self,
base_lr=None,
boundary=None,
value=None,
step_per_epoch=None):
if boundary is not None and self.use_warmup:
boundary.extend([int(step_per_epoch) * i for i in self.milestones])
else:
# do not use LinearWarmup
boundary = [int(step_per_epoch) * i for i in self.milestones]
            value = [base_lr]  # lr stays at base_lr during steps [0, boundary[0])
        # self.values is set directly in the config
if self.values is not None:
assert len(self.milestones) + 1 == len(self.values)
return optimizer.lr.PiecewiseDecay(boundary, self.values)
# value is computed by self.gamma
value = value if value is not None else [base_lr]
for i in self.gamma:
value.append(base_lr * i)
return optimizer.lr.PiecewiseDecay(boundary, value)
@serializable
class LinearWarmup(object):
"""
Warm up learning rate linearly
Args:
steps (int): warm up steps
start_factor (float): initial learning rate factor
"""
def __init__(self, steps=500, start_factor=1. / 3):
super(LinearWarmup, self).__init__()
self.steps = steps
self.start_factor = start_factor
def __call__(self, base_lr, step_per_epoch):
boundary = []
value = []
for i in range(self.steps + 1):
if self.steps > 0:
alpha = i / self.steps
factor = self.start_factor * (1 - alpha) + alpha
lr = base_lr * factor
value.append(lr)
if i > 0:
boundary.append(i)
return boundary, value
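    # A worked example (assuming base_lr=0.01, steps=4, start_factor=0.25):
    # __call__ returns boundary=[1, 2, 3, 4] and
    # value=[0.0025, 0.004375, 0.00625, 0.008125, 0.01],
    # i.e. the lr ramps linearly from base_lr * start_factor up to base_lr.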
@register
class LearningRate(object):
"""
Learning Rate configuration
Args:
base_lr (float): base learning rate
schedulers (list): learning rate schedulers
"""
__category__ = 'optim'
def __init__(self,
base_lr=0.01,
schedulers=[PiecewiseDecay(), LinearWarmup()]):
super(LearningRate, self).__init__()
self.base_lr = base_lr
self.schedulers = schedulers
def __call__(self, step_per_epoch):
assert len(self.schedulers) >= 1
if not self.schedulers[0].use_warmup:
return self.schedulers[0](base_lr=self.base_lr,
step_per_epoch=step_per_epoch)
# TODO: split warmup & decay
# warmup
boundary, value = self.schedulers[1](self.base_lr, step_per_epoch)
# decay
decay_lr = self.schedulers[0](self.base_lr, boundary, value,
step_per_epoch)
return decay_lr
@register
class OptimizerBuilder():
"""
Build optimizer handles
Args:
regularizer (object): an `Regularizer` instance
optimizer (object): an `Optimizer` instance
"""
__category__ = 'optim'
def __init__(self,
clip_grad_by_norm=None,
regularizer={'type': 'L2',
'factor': .0001},
optimizer={'type': 'Momentum',
'momentum': .9}):
self.clip_grad_by_norm = clip_grad_by_norm
self.regularizer = regularizer
self.optimizer = optimizer
def __call__(self, learning_rate, model=None):
if not isinstance(model, (list, tuple)):
model = [model]
if self.clip_grad_by_norm is not None:
grad_clip = nn.ClipGradByGlobalNorm(
clip_norm=self.clip_grad_by_norm)
else:
grad_clip = None
if self.regularizer and self.regularizer != 'None':
reg_type = self.regularizer['type'] + 'Decay'
reg_factor = self.regularizer['factor']
regularization = getattr(regularizer, reg_type)(reg_factor)
else:
regularization = None
optim_args = self.optimizer.copy()
optim_type = optim_args['type']
del optim_args['type']
optim_args['weight_decay'] = regularization
op = getattr(optimizer, optim_type)
params = []
for m in model:
if m is not None:
params.extend(m.parameters())
return op(learning_rate=learning_rate,
parameters=params,
grad_clip=grad_clip,
**optim_args)
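# A minimal usage sketch (hypothetical values; builds the default Momentum
# optimizer with the warmup + piecewise-decay schedule defined above):
if __name__ == '__main__':
    demo_model = nn.Linear(4, 2)
    demo_lr = LearningRate(base_lr=0.01)(step_per_epoch=100)
    demo_opt = OptimizerBuilder()(demo_lr, demo_model)
    print(type(demo_opt))  # paddle.optimizer.Momentum by default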
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import copy
import time
import numpy as np
from PIL import Image, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle import amp
from paddle.static import InputSpec
from lib.utils.workspace import create
from lib.utils.checkpoint import load_weight, load_pretrain_weight
from lib.utils.visualizer import visualize_results, save_result
from lib.metrics.coco_utils import get_infer_results
from lib.metrics import KeyPointTopDownCOCOEval
from lib.dataset.category import get_categories
import lib.utils.stats as stats
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, VisualDLWriter
from .export_utils import _dump_infer_config, _prune_input_spec
from lib.utils.logger import setup_logger
logger = setup_logger('hrnet.pose')
__all__ = ['Trainer']
class Trainer(object):
def __init__(self, cfg, mode='train'):
self.cfg = cfg
assert mode.lower() in ['train', 'eval', 'test'], \
"mode should be 'train', 'eval' or 'test'"
self.mode = mode.lower()
self.optimizer = None
# init distillation config
self.distill_model = None
self.distill_loss = None
# build data loader
self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
if self.mode == 'train':
self.loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, cfg.worker_num)
self.model = create(cfg.architecture)
        # normalize params for deploy
self.model.load_meanstd(cfg['TestReader']['sample_transforms'])
        # EvalDataset is built with a BatchSampler to evaluate on a single device
if self.mode == 'eval':
self._eval_batch_sampler = paddle.io.BatchSampler(
self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
self.loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, cfg.worker_num, self._eval_batch_sampler)
        # TestDataset is built after the user sets images, so skip loader creation here
self._nranks = dist.get_world_size()
self._local_rank = dist.get_rank()
self.status = {}
self.start_epoch = 0
self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch
# initial default callbacks
self._init_callbacks()
# initial default metrics
self._init_metrics()
self._reset_metrics()
def _init_callbacks(self):
if self.mode == 'train':
self._callbacks = [LogPrinter(self), Checkpointer(self)]
if self.cfg.get('use_vdl', False):
self._callbacks.append(VisualDLWriter(self))
self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'eval':
self._callbacks = [LogPrinter(self)]
self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'test' and self.cfg.get('use_vdl', False):
self._callbacks = [VisualDLWriter(self)]
self._compose_callback = ComposeCallback(self._callbacks)
else:
self._callbacks = []
self._compose_callback = None
def _init_metrics(self, validate=False):
if self.mode == 'test' or (self.mode == 'train' and not validate):
self._metrics = []
return
if self.cfg.metric == 'KeyPointTopDownCOCOEval':
eval_dataset = self.cfg['EvalDataset']
eval_dataset.check_or_download_dataset()
anno_file = eval_dataset.get_anno()
save_prediction_only = self.cfg.get('save_prediction_only', False)
self._metrics = [
KeyPointTopDownCOCOEval(
anno_file,
len(eval_dataset),
self.cfg.num_joints,
self.cfg.save_dir,
save_prediction_only=save_prediction_only)
]
else:
logger.warning("Metric not support for metric type {}".format(
self.cfg.metric))
self._metrics = []
def init_optimizer(self, ):
# build optimizer in train mode
if self.mode == 'train':
steps_per_epoch = len(self.loader)
self.lr = create('LearningRate')(steps_per_epoch)
self.optimizer = create('OptimizerBuilder')(
self.lr, [self.model, self.distill_model])
def _reset_metrics(self):
for metric in self._metrics:
metric.reset()
def register_callbacks(self, callbacks):
callbacks = [c for c in list(callbacks) if c is not None]
for c in callbacks:
assert isinstance(c, Callback), \
"metrics shoule be instances of subclass of Metric"
self._callbacks.extend(callbacks)
self._compose_callback = ComposeCallback(self._callbacks)
def register_metrics(self, metrics):
metrics = [m for m in list(metrics) if m is not None]
self._metrics.extend(metrics)
def load_weights(self, weights, model=None):
self.start_epoch = 0
if model is None:
model = self.model
        load_pretrain_weight(model, weights)
logger.debug("Load weights {} to start training".format(weights))
def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode"
Init_mark = False
model = self.model
if self._nranks > 1:
model = paddle.DataParallel(
self.model,
find_unused_parameters=self.cfg.get("find_unused_parameters",
False))
self.status.update({
'epoch_id': self.start_epoch,
'step_id': 0,
'steps_per_epoch': len(self.loader)
})
self.status['batch_time'] = stats.SmoothedValue(
self.cfg.log_iter, fmt='{avg:.4f}')
self.status['data_time'] = stats.SmoothedValue(
self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_status'] = stats.TrainingStats(self.cfg.log_iter)
self._compose_callback.on_train_begin(self.status)
for epoch_id in range(self.start_epoch, self.cfg.epoch):
self.status['mode'] = 'train'
self.status['epoch_id'] = epoch_id
self._compose_callback.on_epoch_begin(self.status)
self.loader.dataset.set_epoch(epoch_id)
model.train()
iter_tic = time.time()
for step_id, data in enumerate(self.loader):
self.status['data_time'].update(time.time() - iter_tic)
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
data['epoch_id'] = epoch_id
# model forward
outputs = model(data)
if self.distill_model is not None:
teacher_outputs = self.distill_model(data)
distill_loss = self.distill_loss(outputs, teacher_outputs,
data)
loss = outputs['loss'] + teacher_outputs[
"loss"] + distill_loss
else:
loss = outputs['loss']
# model backward
loss.backward()
self.optimizer.step()
curr_lr = self.optimizer.get_lr()
self.lr.step()
self.optimizer.clear_grad()
self.status['learning_rate'] = curr_lr
if self._nranks < 2 or self._local_rank == 0:
loss_dict = {"loss": outputs['loss']}
if self.distill_model is not None:
loss_dict.update({
"loss_student": outputs['loss'],
"loss_teacher": teacher_outputs["loss"],
"loss_distill": distill_loss,
"loss": loss
})
                    self.status['training_status'].update(loss_dict)
self.status['batch_time'].update(time.time() - iter_tic)
self._compose_callback.on_step_end(self.status)
iter_tic = time.time()
self._compose_callback.on_epoch_end(self.status)
if validate and self._local_rank == 0 \
and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
or epoch_id == self.end_epoch - 1):
print("begin to eval...")
if not hasattr(self, '_eval_loader'):
# build evaluation dataset and loader
self._eval_dataset = self.cfg.EvalDataset
self._eval_batch_sampler = \
paddle.io.BatchSampler(
self._eval_dataset,
batch_size=self.cfg.EvalReader['batch_size'])
self._eval_loader = create('EvalReader')(
self._eval_dataset,
self.cfg.worker_num,
batch_sampler=self._eval_batch_sampler)
# if validation in training is enabled, metrics should be re-init
# Init_mark makes sure this code will only execute once
                if validate and not Init_mark:
Init_mark = True
self._init_metrics(validate=validate)
self._reset_metrics()
with paddle.no_grad():
self.status['save_best_model'] = True
self._eval_with_loader(self._eval_loader)
self._compose_callback.on_train_end(self.status)
def _eval_with_loader(self, loader):
sample_num = 0
tic = time.time()
self._compose_callback.on_epoch_begin(self.status)
self.status['mode'] = 'eval'
self.model.eval()
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
# forward
outs = self.model(data)
# update metrics
for metric in self._metrics:
metric.update(data, outs)
sample_num += data['im_id'].numpy().shape[0]
self._compose_callback.on_step_end(self.status)
self.status['sample_num'] = sample_num
self.status['cost_time'] = time.time() - tic
# accumulate metric to log out
for metric in self._metrics:
metric.accumulate()
metric.log()
self._compose_callback.on_epoch_end(self.status)
        # reset metric states since metrics may be evaluated multiple times
self._reset_metrics()
def evaluate(self):
with paddle.no_grad():
self._eval_with_loader(self.loader)
def predict(self,
images,
draw_threshold=0.5,
output_dir='output',
save_txt=False):
self.dataset.set_images(images)
loader = create('TestReader')(self.dataset, 0)
imid2path = self.dataset.get_imid2path()
anno_file = self.dataset.get_anno()
clsid2catid, catid2name = get_categories(
self.cfg.metric, anno_file=anno_file)
# Run Infer
self.status['mode'] = 'test'
self.model.eval()
results = []
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
# forward
outs = self.model(data)
for key in ['im_shape', 'scale_factor', 'im_id']:
outs[key] = data[key]
for key, value in outs.items():
if hasattr(value, 'numpy'):
outs[key] = value.numpy()
results.append(outs)
for outs in results:
batch_res = get_infer_results(outs, clsid2catid)
bbox_num = outs['bbox_num']
start = 0
for i, im_id in enumerate(outs['im_id']):
image_path = imid2path[int(im_id)]
image = Image.open(image_path).convert('RGB')
image = ImageOps.exif_transpose(image)
self.status['original_image'] = np.array(image.copy())
end = start + bbox_num[i]
bbox_res = batch_res['bbox'][start:end] \
if 'bbox' in batch_res else None
keypoint_res = batch_res['keypoint'][start:end] \
if 'keypoint' in batch_res else None
image = visualize_results(image, bbox_res, keypoint_res,
int(im_id), catid2name,
draw_threshold)
self.status['result_image'] = np.array(image.copy())
if self._compose_callback:
self._compose_callback.on_step_end(self.status)
# save image with detection
save_name = self._get_save_image_name(output_dir, image_path)
logger.info("Detection bbox results save in {}".format(
save_name))
image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    # use a separate dict so the outer results list is not
                    # shadowed while it is being iterated
                    txt_results = {}
                    txt_results["im_id"] = im_id
                    if bbox_res:
                        txt_results["bbox_res"] = bbox_res
                    if keypoint_res:
                        txt_results["keypoint_res"] = keypoint_res
                    save_result(save_path, txt_results, catid2name,
                                draw_threshold)
start = end
def _get_save_image_name(self, output_dir, image_path):
"""
Get save image name from source image path.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
image_name = os.path.split(image_path)[-1]
name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext
def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
image_shape = [3, -1, -1]
im_shape = [None, 2]
scale_factor = [None, 2]
test_reader_name = 'TestReader'
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            # fall back to the default [3, -1, -1] if image_shape is unset
            image_shape = inputs_def.get('image_shape', image_shape)
        # set image_shape=[None, 3, -1, -1] as default
        image_shape = [None] + image_shape
if hasattr(self.model, 'deploy'):
self.model.deploy = True
# Save infer cfg
_dump_infer_config(self.cfg,
os.path.join(save_dir, 'infer_cfg.yml'),
image_shape, self.model)
input_spec = [{
"image": InputSpec(
shape=image_shape, name='image'),
"im_shape": InputSpec(
shape=im_shape, name='im_shape'),
"scale_factor": InputSpec(
shape=scale_factor, name='scale_factor')
}]
if prune_input:
static_model = paddle.jit.to_static(
self.model, input_spec=input_spec)
            # NOTE: dy2st does not prune the program, but jit.save prunes
            # the input spec, so prune it here and save with the pruned spec
pruned_input_spec = _prune_input_spec(
input_spec, static_model.forward.main_program,
static_model.forward.outputs)
else:
static_model = None
pruned_input_spec = input_spec
return static_model, pruned_input_spec
def export(self, output_dir='output_inference'):
self.model.eval()
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
save_dir = os.path.join(output_dir, model_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
save_dir)
# save model
if 'slim' not in self.cfg:
paddle.jit.save(
static_model,
os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec)
else:
self.cfg.slim.save_quantized_model(
self.model,
os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import category
from . import dataset
from . import keypoint_coco
from . import reader
from . import transform
from .category import *
from .dataset import *
from .keypoint_coco import *
from .reader import *
from .transform import *
__all__ = category.__all__ + dataset.__all__ + keypoint_coco.__all__ \
+ reader.__all__ + transform.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from lib.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['get_categories']
def get_categories(metric_type, anno_file=None, arch=None):
"""
Get class id to category id map and category id
to category name map from annotation file.
Args:
metric_type (str): metric type, currently support 'coco'.
anno_file (str): annotation file path
"""
if arch == 'keypoint_arch':
return (None, {'id': 'keypoint'})
if metric_type.lower() == 'keypointtopdowncocoeval' or metric_type.lower(
) == 'keypointtopdownmpiieval':
return (None, {'id': 'keypoint'})
else:
raise ValueError("unknown metric type {}".format(metric_type))
def _mot_category(category='pedestrian'):
"""
Get class id to category id map and category id
to category name map of mot dataset
"""
label_map = {category: 0}
label_map = sorted(label_map.items(), key=lambda x: x[1])
cats = [l[0] for l in label_map]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
def _coco17_category():
"""
Get class id to category id map and category id
to category name map of COCO2017 dataset
"""
clsid2catid = {
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
11: 11,
12: 13,
13: 14,
14: 15,
15: 16,
16: 17,
17: 18,
18: 19,
19: 20,
20: 21,
21: 22,
22: 23,
23: 24,
24: 25,
25: 27,
26: 28,
27: 31,
28: 32,
29: 33,
30: 34,
31: 35,
32: 36,
33: 37,
34: 38,
35: 39,
36: 40,
37: 41,
38: 42,
39: 43,
40: 44,
41: 46,
42: 47,
43: 48,
44: 49,
45: 50,
46: 51,
47: 52,
48: 53,
49: 54,
50: 55,
51: 56,
52: 57,
53: 58,
54: 59,
55: 60,
56: 61,
57: 62,
58: 63,
59: 64,
60: 65,
61: 67,
62: 70,
63: 72,
64: 73,
65: 74,
66: 75,
67: 76,
68: 77,
69: 78,
70: 79,
71: 80,
72: 81,
73: 82,
74: 84,
75: 85,
76: 86,
77: 87,
78: 88,
79: 89,
80: 90
}
catid2name = {
0: 'background',
1: 'person',
2: 'bicycle',
3: 'car',
4: 'motorcycle',
5: 'airplane',
6: 'bus',
7: 'train',
8: 'truck',
9: 'boat',
10: 'traffic light',
11: 'fire hydrant',
13: 'stop sign',
14: 'parking meter',
15: 'bench',
16: 'bird',
17: 'cat',
18: 'dog',
19: 'horse',
20: 'sheep',
21: 'cow',
22: 'elephant',
23: 'bear',
24: 'zebra',
25: 'giraffe',
27: 'backpack',
28: 'umbrella',
31: 'handbag',
32: 'tie',
33: 'suitcase',
34: 'frisbee',
35: 'skis',
36: 'snowboard',
37: 'sports ball',
38: 'kite',
39: 'baseball bat',
40: 'baseball glove',
41: 'skateboard',
42: 'surfboard',
43: 'tennis racket',
44: 'bottle',
46: 'wine glass',
47: 'cup',
48: 'fork',
49: 'knife',
50: 'spoon',
51: 'bowl',
52: 'banana',
53: 'apple',
54: 'sandwich',
55: 'orange',
56: 'broccoli',
57: 'carrot',
58: 'hot dog',
59: 'pizza',
60: 'donut',
61: 'cake',
62: 'chair',
63: 'couch',
64: 'potted plant',
65: 'bed',
67: 'dining table',
70: 'toilet',
72: 'tv',
73: 'laptop',
74: 'mouse',
75: 'remote',
76: 'keyboard',
77: 'cell phone',
78: 'microwave',
79: 'oven',
80: 'toaster',
81: 'sink',
82: 'refrigerator',
84: 'book',
85: 'clock',
86: 'vase',
87: 'scissors',
88: 'teddy bear',
89: 'hair drier',
90: 'toothbrush'
}
clsid2catid = {k - 1: v for k, v in clsid2catid.items()}
catid2name.pop(0)
return clsid2catid, catid2name
def _dota_category():
"""
Get class id to category id map and category id
to category name map of dota dataset
"""
catid2name = {
0: 'background',
1: 'plane',
2: 'baseball-diamond',
3: 'bridge',
4: 'ground-track-field',
5: 'small-vehicle',
6: 'large-vehicle',
7: 'ship',
8: 'tennis-court',
9: 'basketball-court',
10: 'storage-tank',
11: 'soccer-ball-field',
12: 'roundabout',
13: 'harbor',
14: 'swimming-pool',
15: 'helicopter'
}
catid2name.pop(0)
clsid2catid = {i: i + 1 for i in range(len(catid2name))}
return clsid2catid, catid2name
def _oid19_category():
clsid2catid = {k: k + 1 for k in range(500)}
catid2name = {
0: "background",
1: "Infant bed",
2: "Rose",
3: "Flag",
4: "Flashlight",
5: "Sea turtle",
6: "Camera",
7: "Animal",
8: "Glove",
9: "Crocodile",
10: "Cattle",
11: "House",
12: "Guacamole",
13: "Penguin",
14: "Vehicle registration plate",
15: "Bench",
16: "Ladybug",
17: "Human nose",
18: "Watermelon",
19: "Flute",
20: "Butterfly",
21: "Washing machine",
22: "Raccoon",
23: "Segway",
24: "Taco",
25: "Jellyfish",
26: "Cake",
27: "Pen",
28: "Cannon",
29: "Bread",
30: "Tree",
31: "Shellfish",
32: "Bed",
33: "Hamster",
34: "Hat",
35: "Toaster",
36: "Sombrero",
37: "Tiara",
38: "Bowl",
39: "Dragonfly",
40: "Moths and butterflies",
41: "Antelope",
42: "Vegetable",
43: "Torch",
44: "Building",
45: "Power plugs and sockets",
46: "Blender",
47: "Billiard table",
48: "Cutting board",
49: "Bronze sculpture",
50: "Turtle",
51: "Broccoli",
52: "Tiger",
53: "Mirror",
54: "Bear",
55: "Zucchini",
56: "Dress",
57: "Volleyball",
58: "Guitar",
59: "Reptile",
60: "Golf cart",
61: "Tart",
62: "Fedora",
63: "Carnivore",
64: "Car",
65: "Lighthouse",
66: "Coffeemaker",
67: "Food processor",
68: "Truck",
69: "Bookcase",
70: "Surfboard",
71: "Footwear",
72: "Bench",
73: "Necklace",
74: "Flower",
75: "Radish",
76: "Marine mammal",
77: "Frying pan",
78: "Tap",
79: "Peach",
80: "Knife",
81: "Handbag",
82: "Laptop",
83: "Tent",
84: "Ambulance",
85: "Christmas tree",
86: "Eagle",
87: "Limousine",
88: "Kitchen & dining room table",
89: "Polar bear",
90: "Tower",
91: "Football",
92: "Willow",
93: "Human head",
94: "Stop sign",
95: "Banana",
96: "Mixer",
97: "Binoculars",
98: "Dessert",
99: "Bee",
100: "Chair",
101: "Wood-burning stove",
102: "Flowerpot",
103: "Beaker",
104: "Oyster",
105: "Woodpecker",
106: "Harp",
107: "Bathtub",
108: "Wall clock",
109: "Sports uniform",
110: "Rhinoceros",
111: "Beehive",
112: "Cupboard",
113: "Chicken",
114: "Man",
115: "Blue jay",
116: "Cucumber",
117: "Balloon",
118: "Kite",
119: "Fireplace",
120: "Lantern",
121: "Missile",
122: "Book",
123: "Spoon",
124: "Grapefruit",
125: "Squirrel",
126: "Orange",
127: "Coat",
128: "Punching bag",
129: "Zebra",
130: "Billboard",
131: "Bicycle",
132: "Door handle",
133: "Mechanical fan",
134: "Ring binder",
135: "Table",
136: "Parrot",
137: "Sock",
138: "Vase",
139: "Weapon",
140: "Shotgun",
141: "Glasses",
142: "Seahorse",
143: "Belt",
144: "Watercraft",
145: "Window",
146: "Giraffe",
147: "Lion",
148: "Tire",
149: "Vehicle",
150: "Canoe",
151: "Tie",
152: "Shelf",
153: "Picture frame",
154: "Printer",
155: "Human leg",
156: "Boat",
157: "Slow cooker",
158: "Croissant",
159: "Candle",
160: "Pancake",
161: "Pillow",
162: "Coin",
163: "Stretcher",
164: "Sandal",
165: "Woman",
166: "Stairs",
167: "Harpsichord",
168: "Stool",
169: "Bus",
170: "Suitcase",
171: "Human mouth",
172: "Juice",
173: "Skull",
174: "Door",
175: "Violin",
176: "Chopsticks",
177: "Digital clock",
178: "Sunflower",
179: "Leopard",
180: "Bell pepper",
181: "Harbor seal",
182: "Snake",
183: "Sewing machine",
184: "Goose",
185: "Helicopter",
186: "Seat belt",
187: "Coffee cup",
188: "Microwave oven",
189: "Hot dog",
190: "Countertop",
191: "Serving tray",
192: "Dog bed",
193: "Beer",
194: "Sunglasses",
195: "Golf ball",
196: "Waffle",
197: "Palm tree",
198: "Trumpet",
199: "Ruler",
200: "Helmet",
201: "Ladder",
202: "Office building",
203: "Tablet computer",
204: "Toilet paper",
205: "Pomegranate",
206: "Skirt",
207: "Gas stove",
208: "Cookie",
209: "Cart",
210: "Raven",
211: "Egg",
212: "Burrito",
213: "Goat",
214: "Kitchen knife",
215: "Skateboard",
216: "Salt and pepper shakers",
217: "Lynx",
218: "Boot",
219: "Platter",
220: "Ski",
221: "Swimwear",
222: "Swimming pool",
223: "Drinking straw",
224: "Wrench",
225: "Drum",
226: "Ant",
227: "Human ear",
228: "Headphones",
229: "Fountain",
230: "Bird",
231: "Jeans",
232: "Television",
233: "Crab",
234: "Microphone",
235: "Home appliance",
236: "Snowplow",
237: "Beetle",
238: "Artichoke",
239: "Jet ski",
240: "Stationary bicycle",
241: "Human hair",
242: "Brown bear",
243: "Starfish",
244: "Fork",
245: "Lobster",
246: "Corded phone",
247: "Drink",
248: "Saucer",
249: "Carrot",
250: "Insect",
251: "Clock",
252: "Castle",
253: "Tennis racket",
254: "Ceiling fan",
255: "Asparagus",
256: "Jaguar",
257: "Musical instrument",
258: "Train",
259: "Cat",
260: "Rifle",
261: "Dumbbell",
262: "Mobile phone",
263: "Taxi",
264: "Shower",
265: "Pitcher",
266: "Lemon",
267: "Invertebrate",
268: "Turkey",
269: "High heels",
270: "Bust",
271: "Elephant",
272: "Scarf",
273: "Barrel",
274: "Trombone",
275: "Pumpkin",
276: "Box",
277: "Tomato",
278: "Frog",
279: "Bidet",
280: "Human face",
281: "Houseplant",
282: "Van",
283: "Shark",
284: "Ice cream",
285: "Swim cap",
286: "Falcon",
287: "Ostrich",
288: "Handgun",
289: "Whiteboard",
290: "Lizard",
291: "Pasta",
292: "Snowmobile",
293: "Light bulb",
294: "Window blind",
295: "Muffin",
296: "Pretzel",
297: "Computer monitor",
298: "Horn",
299: "Furniture",
300: "Sandwich",
301: "Fox",
302: "Convenience store",
303: "Fish",
304: "Fruit",
305: "Earrings",
306: "Curtain",
307: "Grape",
308: "Sofa bed",
309: "Horse",
310: "Luggage and bags",
311: "Desk",
312: "Crutch",
313: "Bicycle helmet",
314: "Tick",
315: "Airplane",
316: "Canary",
317: "Spatula",
318: "Watch",
319: "Lily",
320: "Kitchen appliance",
321: "Filing cabinet",
322: "Aircraft",
323: "Cake stand",
324: "Candy",
325: "Sink",
326: "Mouse",
327: "Wine",
328: "Wheelchair",
329: "Goldfish",
330: "Refrigerator",
331: "French fries",
332: "Drawer",
333: "Treadmill",
334: "Picnic basket",
335: "Dice",
336: "Cabbage",
337: "Football helmet",
338: "Pig",
339: "Person",
340: "Shorts",
341: "Gondola",
342: "Honeycomb",
343: "Doughnut",
344: "Chest of drawers",
345: "Land vehicle",
346: "Bat",
347: "Monkey",
348: "Dagger",
349: "Tableware",
350: "Human foot",
351: "Mug",
352: "Alarm clock",
353: "Pressure cooker",
354: "Human hand",
355: "Tortoise",
356: "Baseball glove",
357: "Sword",
358: "Pear",
359: "Miniskirt",
360: "Traffic sign",
361: "Girl",
362: "Roller skates",
363: "Dinosaur",
364: "Porch",
365: "Human beard",
366: "Submarine sandwich",
367: "Screwdriver",
368: "Strawberry",
369: "Wine glass",
370: "Seafood",
371: "Racket",
372: "Wheel",
373: "Sea lion",
374: "Toy",
375: "Tea",
376: "Tennis ball",
377: "Waste container",
378: "Mule",
379: "Cricket ball",
380: "Pineapple",
381: "Coconut",
382: "Doll",
383: "Coffee table",
384: "Snowman",
385: "Lavender",
386: "Shrimp",
387: "Maple",
388: "Cowboy hat",
389: "Goggles",
390: "Rugby ball",
391: "Caterpillar",
392: "Poster",
393: "Rocket",
394: "Organ",
395: "Saxophone",
396: "Traffic light",
397: "Cocktail",
398: "Plastic bag",
399: "Squash",
400: "Mushroom",
401: "Hamburger",
402: "Light switch",
403: "Parachute",
404: "Teddy bear",
405: "Winter melon",
406: "Deer",
407: "Musical keyboard",
408: "Plumbing fixture",
409: "Scoreboard",
410: "Baseball bat",
411: "Envelope",
412: "Adhesive tape",
413: "Briefcase",
414: "Paddle",
415: "Bow and arrow",
416: "Telephone",
417: "Sheep",
418: "Jacket",
419: "Boy",
420: "Pizza",
421: "Otter",
422: "Office supplies",
423: "Couch",
424: "Cello",
425: "Bull",
426: "Camel",
427: "Ball",
428: "Duck",
429: "Whale",
430: "Shirt",
431: "Tank",
432: "Motorcycle",
433: "Accordion",
434: "Owl",
435: "Porcupine",
436: "Sun hat",
437: "Nail",
438: "Scissors",
439: "Swan",
440: "Lamp",
441: "Crown",
442: "Piano",
443: "Sculpture",
444: "Cheetah",
445: "Oboe",
446: "Tin can",
447: "Mango",
448: "Tripod",
449: "Oven",
450: "Mouse",
451: "Barge",
452: "Coffee",
453: "Snowboard",
454: "Common fig",
455: "Salad",
456: "Marine invertebrates",
457: "Umbrella",
458: "Kangaroo",
459: "Human arm",
460: "Measuring cup",
461: "Snail",
462: "Loveseat",
463: "Suit",
464: "Teapot",
465: "Bottle",
466: "Alpaca",
467: "Kettle",
468: "Trousers",
469: "Popcorn",
470: "Centipede",
471: "Spider",
472: "Sparrow",
473: "Plate",
474: "Bagel",
475: "Personal care",
476: "Apple",
477: "Brassiere",
478: "Bathroom cabinet",
479: "studio couch",
480: "Computer keyboard",
481: "Table tennis racket",
482: "Sushi",
483: "Cabinetry",
484: "Street light",
485: "Towel",
486: "Nightstand",
487: "Rabbit",
488: "Dolphin",
489: "Dog",
490: "Jug",
491: "Wok",
492: "Fire hydrant",
493: "Human eye",
494: "Skyscraper",
495: "Backpack",
496: "Potato",
497: "Paper towel",
498: "Lifejacket",
499: "Bicycle wheel",
500: "Toilet",
}
return clsid2catid, catid2name
def _visdrone_category():
clsid2catid = {i: i for i in range(10)}
catid2name = {
0: 'pedestrian',
1: 'people',
2: 'bicycle',
3: 'car',
4: 'van',
5: 'truck',
6: 'tricycle',
7: 'awning-tricycle',
8: 'bus',
9: 'motor'
}
return clsid2catid, catid2name
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from paddle.io import Dataset
import copy
from lib.utils.workspace import register, serializable
from lib.utils.download import get_dataset_path
__all__ = ['DetDataset', 'ImageFolder']
@serializable
class DetDataset(Dataset):
"""
Load detection dataset.
Args:
dataset_dir (str): root directory for dataset.
image_dir (str): directory for images.
anno_path (str): annotation file path.
data_fields (list): key name of data dictionary, at least have 'image'.
sample_num (int): number of samples to load, -1 means all.
use_default_label (bool): whether to load default label list.
"""
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
data_fields=['image'],
sample_num=-1,
use_default_label=None,
**kwargs):
super(DetDataset, self).__init__()
self.dataset_dir = dataset_dir if dataset_dir is not None else ''
self.anno_path = anno_path
self.image_dir = image_dir if image_dir is not None else ''
self.data_fields = data_fields
self.sample_num = sample_num
self.use_default_label = use_default_label
self._epoch = 0
self._curr_iter = 0
def __len__(self, ):
return len(self.roidbs)
def __getitem__(self, idx):
# data batch
roidb = copy.deepcopy(self.roidbs[idx])
if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
n = len(self.roidbs)
idx = np.random.randint(n)
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
n = len(self.roidbs)
idx = np.random.randint(n)
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
n = len(self.roidbs)
roidb = [roidb, ] + [
copy.deepcopy(self.roidbs[np.random.randint(n)])
for _ in range(3)
]
if isinstance(roidb, Sequence):
for r in roidb:
r['curr_iter'] = self._curr_iter
else:
roidb['curr_iter'] = self._curr_iter
self._curr_iter += 1
return self.transform(roidb)
def check_or_download_dataset(self):
self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
self.image_dir)
def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
def set_transform(self, transform):
self.transform = transform
def set_epoch(self, epoch_id):
self._epoch = epoch_id
def parse_dataset(self, ):
raise NotImplementedError(
"Need to implement parse_dataset method of Dataset")
def get_anno(self):
if self.anno_path is None:
return
return os.path.join(self.dataset_dir, self.anno_path)
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
return f.lower().endswith(extensions)
def _make_dataset(dir):
dir = os.path.expanduser(dir)
if not os.path.isdir(dir):
        raise ValueError('{} should be a dir'.format(dir))
images = []
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if _is_valid_file(path):
images.append(path)
return images
@register
@serializable
class ImageFolder(DetDataset):
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
sample_num=-1,
use_default_label=None,
**kwargs):
super(ImageFolder, self).__init__(
dataset_dir,
image_dir,
anno_path,
sample_num=sample_num,
use_default_label=use_default_label)
self._imid2path = {}
self.roidbs = None
self.sample_num = sample_num
def check_or_download_dataset(self):
if self.dataset_dir:
# NOTE: ImageFolder is only used for prediction, in
# infer mode, image_dir is set by set_images
# so we only check anno_path here
self.dataset_dir = get_dataset_path(self.dataset_dir,
self.anno_path, None)
def parse_dataset(self, ):
if not self.roidbs:
self.roidbs = self._load_images()
def _parse(self):
image_dir = self.image_dir
if not isinstance(image_dir, Sequence):
image_dir = [image_dir]
images = []
for im_dir in image_dir:
if os.path.isdir(im_dir):
im_dir = os.path.join(self.dataset_dir, im_dir)
images.extend(_make_dataset(im_dir))
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
images.append(im_dir)
return images
def _load_images(self):
images = self._parse()
ct = 0
records = []
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def get_imid2path(self):
return self._imid2path
def set_images(self, images):
self.image_dir = images
self.roidbs = self._load_images()
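# A minimal usage sketch for prediction inputs ('demo.jpg' is a hypothetical
# image path that must exist on disk):
if __name__ == '__main__':
    folder = ImageFolder()
    folder.set_images(['demo.jpg'])
    print(folder.get_imid2path())  # {0: 'demo.jpg'}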
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import operators
from . import keypoint_operators
from .operators import *
from .keypoint_operators import *
__all__ = []
__all__ += registered_ops
__all__ += keypoint_operators.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import keypoint_metrics
from . import coco_utils
from . import json_results
from . import map_utils
from .keypoint_metrics import *
from .coco_utils import *
from .json_results import *
from .map_utils import *
__all__ = keypoint_metrics.__all__ + coco_utils.__all__ + json_results.__all__ + map_utils.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import hrnet
from . import lite_hrnet
from . import keypoint_hrnet
from . import loss
from .hrnet import *
from .keypoint_hrnet import *
from .loss import *
from .lite_hrnet import *
__all__ = hrnet.__all__ + keypoint_hrnet.__all__ \
+ loss.__all__
from . import quant
from .quant import *
from lib.utils.workspace import create
from lib.utils.checkpoint import load_pretrain_weight
def build_slim_model(cfg, mode='train'):
assert cfg.slim == 'QAT', 'Only QAT is supported now'
model = create(cfg.architecture)
if mode == 'train':
load_pretrain_weight(model, cfg.pretrain_weights)
slim = create(cfg.slim)
cfg['slim_type'] = cfg.slim
# TODO: fix quant export model in framework.
if mode == 'test' and cfg.slim == 'QAT':
slim.quant_config['activation_preprocess_type'] = None
cfg['model'] = slim(model)
cfg['slim'] = slim
if mode != 'train':
load_pretrain_weight(cfg['model'], cfg.weights)
return cfg
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle.utils import try_import
from lib.utils.workspace import register, serializable
from lib.utils.logger import setup_logger
logger = setup_logger(__name__)
@register
@serializable
class QAT(object):
def __init__(self, quant_config, print_model):
super(QAT, self).__init__()
self.quant_config = quant_config
self.print_model = print_model
def __call__(self, model):
paddleslim = try_import('paddleslim')
self.quanter = paddleslim.dygraph.quant.QAT(config=self.quant_config)
if self.print_model:
logger.info("Model before quant:")
logger.info(model)
self.quanter.quantize(model)
if self.print_model:
logger.info("Quantized model:")
logger.info(model)
return model
def save_quantized_model(self, layer, path, input_spec=None, **config):
self.quanter.save_quantized_model(
model=layer, path=path, input_spec=input_spec, **config)
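# A minimal usage sketch (the config below follows PaddleSlim's dygraph QAT
# conventions; keys and values are illustrative, not exhaustive):
if __name__ == '__main__':
    import paddle.nn as nn
    demo_quant_config = {
        'weight_quantize_type': 'channel_wise_abs_max',
        'activation_quantize_type': 'moving_average_abs_max',
    }
    qat = QAT(quant_config=demo_quant_config, print_model=False)
    demo_model = nn.Sequential(nn.Conv2D(3, 8, 3), nn.ReLU())
    demo_model = qat(demo_model)  # layers are wrapped with fake-quant ops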