Unverified commit 165f055d authored by littletomatodonkey, committed by GitHub

add inference time info and more models link (#5504)

* add inference time info and more models link

* rm unused images

* fix infer for image dir

* fix readme
Parent 86d24e9f
@@ -129,6 +129,10 @@ python tools/train.py -c configs/lite_hrnet_30_256x192_coco_pact.yml
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/lite_hrnet_30_256x192_coco_pact.yml
# training with PACT quantization on multi-GPU, using the distilled model as the pretrained weights
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/lite_hrnet_30_256x192_coco_dist_pact.yml
# GPU evaluation with PACT quantization
export CUDA_VISIBLE_DEVICES=0
python tools/eval.py -c configs/lite_hrnet_30_256x192_coco_pact.yml -o weights=https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco_pact.pdparams
@@ -163,23 +167,37 @@ python deploy/infer.py --model_dir=output_inference/lite_hrnet_30_256x192_coco_
```
## 3 Results
COCO Dataset
| Model | Input Size | AP(%, coco val) | Model Download | Config File | Inference model size |
| :----------: | -------- | :----------: | :------------: | :---: | :---: |
| HRNet-w32 | 256x192 | 76.9 | [hrnet_w32_256x192.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/hrnet_w32_256x192.pdparams) | [config](./configs/hrnet_w32_256x192.yml) | 118M |
| LiteHRNet-30 | 256x192 | 69.4 | [lite_hrnet_30_256x192_coco.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco.yml) | 26M |
| LiteHRNet-30-distillation | 256x192 | 69.9 |[lite_hrnet_30_256x192_coco_dist.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco_dist.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco.yml) | 26M |
| LiteHRNet-30-PACT | 256x192 | 68.9 | [lite_hrnet_30_256x192_coco_pact.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco_pact.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco_pact.yml) | 8.0M |
* COCO Dataset benchmark
| Model | Input Size | AP(%, coco val) | Model Download | Config File | Inference model size (M) | Inference time (ms/image) |
| :----------: | -------- | :----------: | :------------: | :---: | :---: | :---: |
| HRNet-w32 | 256x192 | 76.9 | [hrnet_w32_256x192.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/hrnet_w32_256x192.pdparams) | [config](./configs/hrnet_w32_256x192.yml) | 118 | 357.1 |
| LiteHRNet-30 | 256x192 | 69.4 | [lite_hrnet_30_256x192_coco.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco.yml) | 26 | 160.6 |
| LiteHRNet-30-distillation | 256x192 | 69.9 | [lite_hrnet_30_256x192_coco_dist.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco_dist.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco.yml) | 26 | 160.6 |
| LiteHRNet-30-PACT | 256x192 | 68.9 | [lite_hrnet_30_256x192_coco_pact.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco_pact.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco_pact.yml) | 8 | 156.7 |
| LiteHRNet-30-distillation-PACT | 256x192 | 70.2 | [lite_hrnet_30_256x192_coco_dist_pact.pdparams](https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco_dist_pact.pdparams) | [config](./configs/lite_hrnet_30_256x192_coco_dist_pact.yml) | 8 | 156.7 |
**NOTE:**
* Inference model size is obtained by summing the `pdiparams` and `pdmodel` file sizes.
* Inference time is tested on a CPU (`Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz`) without MKLDNN, using paddlepaddle-develop.
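The size figure in the note above can be reproduced directly; below is a minimal Python sketch, assuming the exported inference files sit in an `output_inference/...` directory as produced by the export step (the exact path is an assumption, not fixed by this change):

```python
import os

# Hypothetical export directory; point this at wherever the inference
# model was exported.
model_dir = "output_inference/lite_hrnet_30_256x192_coco_pact"

# Inference model size = sum of the pdiparams and pdmodel file sizes.
total_bytes = sum(
    os.path.getsize(os.path.join(model_dir, name))
    for name in os.listdir(model_dir)
    if name.endswith((".pdiparams", ".pdmodel"))
)
print("Inference model size: {:.1f} MB".format(total_bytes / (1024 * 1024)))
```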
**Visualization**
Input:
![](./dataset/test_image/hrnet_demo.jpg)
Output:
![](./deploy/output/hrnet_demo_vis.jpg)
![](/dataset/test_image/hrnet_demo.jpg)
![](/deploy/output/hrnet_demo_vis.jpg)
## Citation
````
@inproceedings{cheng2020bottom,
title={Deep High-Resolution Representation Learning for Human Pose Estimation},
......
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/lite_hrnet_30_256x192_coco_dist_pact/model_final
epoch: 50
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 192
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
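# flip_perm pairs the left/right COCO keypoint indices (eyes, ears, shoulders,
# elbows, wrists, hips, knees, ankles) that must be swapped when an image is
# horizontally flipped; index 0 (the nose) has no mirror partner.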
pretrain_weights: https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco_dist.pdparams
slim: QAT
QAT:
  quant_config: {
    'activation_preprocess_type': 'PACT',
    'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max',
    'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9,
    'quantizable_layer_type': ['Conv2D', 'Linear']}
  print_model: True
architecture: TopDownHRNet
TopDownHRNet:
  backbone: LiteHRNet
  post_process: HRNetPostProcess
  flip_perm: *flip_perm
  num_joints: *num_joints
  width: &width 40
  loss: KeyPointMSELoss
  use_dark: false

LiteHRNet:
  network_type: lite_30
  freeze_at: -1
  freeze_norm: false
  return_idx: [0]

KeyPointMSELoss:
  use_target_weight: true
  loss_scale: 1.0
# optimizer
LearningRate:
  base_lr: 0.001
  schedulers:
    - !PiecewiseDecay
      milestones: [40, 45]
      gamma: 0.1
    - !LinearWarmup
      start_factor: 0.001
      steps: 500
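# Net effect of the schedulers above: a 500-step linear warmup from
# base_lr * start_factor = 1e-6 up to 1e-3, then lr = 1e-3 until epoch 40,
# 1e-4 for epochs 40-44, and 1e-5 for epochs 45-49.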
OptimizerBuilder:
  optimizer:
    type: Adam
  regularizer:
    factor: 0.0
    type: L2
# data
TrainDataset:
  !KeypointTopDownCocoDataset
    image_dir: train2017
    anno_path: annotations/person_keypoints_train2017.json
    dataset_dir: dataset/coco
    num_joints: *num_joints
    trainsize: *trainsize
    pixel_std: *pixel_std
    use_gt_bbox: True

EvalDataset:
  !KeypointTopDownCocoDataset
    image_dir: val2017
    anno_path: annotations/person_keypoints_val2017.json
    dataset_dir: dataset/coco
    num_joints: *num_joints
    trainsize: *trainsize
    pixel_std: *pixel_std
    use_gt_bbox: True
    image_thre: 0.0

TestDataset:
  !ImageFolder
    anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 4
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
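# ImageNet channel statistics, used by the readers below to normalize inputs.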
TrainReader:
  sample_transforms:
    - RandomFlipHalfBodyTransform:
        scale: 0.25
        rot: 30
        num_joints_half_body: 8
        prob_half_body: 0.3
        pixel_std: *pixel_std
        trainsize: *trainsize
        upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        flip_pairs: *flip_perm
    - TopDownAffine:
        trainsize: *trainsize
    - ToHeatmapsTopDown:
        hmsize: *hmsize
        sigma: 2
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 64
  shuffle: true
  drop_last: false

EvalReader:
  sample_transforms:
    - TopDownAffine:
        trainsize: *trainsize
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 16

TestReader:
  inputs_def:
    image_shape: [3, *train_height, *train_width]
  sample_transforms:
    - Decode: {}
    - TopDownEvalAffine:
        trainsize: *trainsize
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 1
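For reference, here is a minimal sketch of how a `slim: QAT` section like the one above is typically consumed through PaddleSlim's dygraph quantization API; `build_model()` is a hypothetical stand-in for constructing the TopDownHRNet network, and the exact wiring inside the framework may differ:

```python
import paddle
from paddleslim import QAT  # PaddleSlim dygraph quant-aware training wrapper

# Same dictionary as the quant_config block in the YAML above.
quant_config = {
    'activation_preprocess_type': 'PACT',
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'weight_bits': 8,
    'activation_bits': 8,
    'dtype': 'int8',
    'window_size': 10000,
    'moving_rate': 0.9,
    'quantizable_layer_type': ['Conv2D', 'Linear'],
}

model = build_model()    # hypothetical: the TopDownHRNet defined by this config
quanter = QAT(config=quant_config)
quanter.quantize(model)  # insert fake-quant ops into Conv2D/Linear layers

# ... run the usual training loop on the quantized model ...

# Export an int8-ready inference model; the input shape follows trainsize (192x256).
quanter.save_quantized_model(
    model,
    'output_inference/lite_hrnet_30_256x192_coco_dist_pact/model',
    input_spec=[paddle.static.InputSpec(shape=[None, 3, 256, 192], dtype='float32')],
)
```

PACT learns a clipping threshold for each activation during training, which is why it appears as `activation_preprocess_type` rather than as a separate quantization pass.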
@@ -2,7 +2,7 @@ use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/lite_hrnet_30_256x192_coco/model_final
weights: output/lite_hrnet_30_256x192_coco_pact/model_final
epoch: 50
num_joints: &num_joints 17
pixel_std: &pixel_std 200
@@ -15,7 +15,7 @@ hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/keypoint/lite_hrnet_30_256x192_coco.pdparams
pretrain_weights: https://paddle-model-ecology.bj.bcebos.com/model/hrnet_pose/lite_hrnet_30_256x192_coco.pdparams
slim: QAT
QAT:
  quant_config: {
......
@@ -335,10 +335,10 @@ def predict_image(detector, image_list, batch_size=1):
        if FLAGS.run_benchmark:
            # warmup
            detector.predict(
                image_list, FLAGS.threshold, repeats=3, add_timer=False)
                [img_file], FLAGS.threshold, repeats=3, add_timer=False)
            # run benchmark
            detector.predict(
                image_list, FLAGS.threshold, repeats=3, add_timer=True)
                [img_file], FLAGS.threshold, repeats=3, add_timer=True)
            cm, gm, gu = get_current_memory_mb()
            detector.cpu_mem += cm
@@ -346,7 +346,7 @@ def predict_image(detector, image_list, batch_size=1):
            detector.gpu_util += gu
            print('Test iter {}'.format(i))
        else:
            results = detector.predict(image_list, FLAGS.threshold)
            results = detector.predict([img_file], FLAGS.threshold)
            draw_pose(
                img_file,
                results,
......
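The edits above are the "fix infer for image dir" part of this commit: inside `predict_image`, the loop variable `img_file` is now wrapped as `[img_file]` so that each iteration runs prediction on the current image only, instead of re-running the whole `image_list` on every pass. A simplified sketch of the fixed loop, with the elided bookkeeping omitted and a hypothetical `draw_pose` call signature (the diff truncates its arguments):

```python
# Simplified view of the fixed predict_image loop; FLAGS, detector and
# draw_pose come from the surrounding deploy/infer.py script.
def predict_image(detector, image_list, batch_size=1):
    for i, img_file in enumerate(image_list):
        if FLAGS.run_benchmark:
            # Warm up, then take timed runs -- both on the current image only.
            detector.predict([img_file], FLAGS.threshold, repeats=3, add_timer=False)
            detector.predict([img_file], FLAGS.threshold, repeats=3, add_timer=True)
            print('Test iter {}'.format(i))
        else:
            results = detector.predict([img_file], FLAGS.threshold)
            draw_pose(img_file, results)  # hypothetical signature; arguments truncated in the diff
```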