未验证 提交 f64a4ccb 编写于 作者: G greatlog 提交者: GitHub

fix(keypoints) fix bug in test (#36)

上级 d8eaf9ad
...@@ -88,7 +88,19 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH ...@@ -88,7 +88,19 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH
| :--: |:--: |:--: |:--: | | :--: |:--: |:--: |:--: |
| Deeplabv3plus | Resnet101 | 79.0 | 79.8 | | Deeplabv3plus | Resnet101 | 79.0 | 79.8 |
<<<<<<< HEAD
<<<<<<< HEAD
<<<<<<< HEAD
### 人体关节点检测 ### 人体关节点检测
=======
### 人体关节点
>>>>>>> update readme
=======
### 人体关节点
>>>>>>> update readme
=======
### 人体关节点检测
>>>>>>> 3fdaf98eee3169f70ace463d54cd177ee1fcf68e
我们提供了人体关节点检测的经典模型[SimpleBaseline](https://arxiv.org/pdf/1804.06208.pdf)和高精度模型[MSPN](https://arxiv.org/pdf/1901.00148.pdf),使用在COCO val2017上人体检测AP为56的检测结果,提供的模型在COCO val2017上的关节点检测结果为: 我们提供了人体关节点检测的经典模型[SimpleBaseline](https://arxiv.org/pdf/1804.06208.pdf)和高精度模型[MSPN](https://arxiv.org/pdf/1901.00148.pdf),使用在COCO val2017上人体检测AP为56的检测结果,提供的模型在COCO val2017上的关节点检测结果为:
......
...@@ -38,12 +38,12 @@ ${COCO_DATA_ROOT} ...@@ -38,12 +38,12 @@ ${COCO_DATA_ROOT}
| |-- person_keypoints_val2017.json | |-- person_keypoints_val2017.json
|-- person_detection_results |-- person_detection_results
| |-- COCO_val2017_detections_AP_H_56_person.json | |-- COCO_val2017_detections_AP_H_56_person.json
|-- |-- train2017 |-- train2017
| |-- 000000000009.jpg | | |-- 000000000009.jpg
| |-- 000000000025.jpg | | |-- 000000000025.jpg
| |-- 000000000030.jpg | | |-- 000000000030.jpg
| |-- ... | | |-- ...
|-- val2017 |-- val2017
|-- 000000000139.jpg |-- 000000000139.jpg
|-- 000000000285.jpg |-- 000000000285.jpg
|-- 000000000632.jpg |-- 000000000632.jpg
...@@ -79,19 +79,27 @@ python3 train.py --arch mspn_4stage \ ...@@ -79,19 +79,27 @@ python3 train.py --arch mspn_4stage \
## 如何测试 ## 如何测试
模型训练好之后,可以通过如下命令测试模型在COCOval2017验证集的性能: 模型训练好之后,可以通过如下命令测试指定模型在COCOval2017验证集的性能:
```bash ```bash
python3 test.py --arch name/of/network \ python3 test.py --arch name/of/network \
--model /path/to/model.pkl \ --model /path/to/model.pkl \
--dt_file /name/human/detection/results --dt_file /name/human/detection/results
``` ```
`test.py`的命令行参数如下: `test.py`的命令行参数如下:
- `--arch`, 网络的名字; - `--arch`, 网络的名字;
- `--model`, 待检测的模; - `--model`, 待检测的模;
- `--dt_path`,人体检测结果. - `--dt_path`,人体检测结果.
也可以连续验证多个模型的性能:
```bash
python3 test.py --arch name/of/network \
--model_dir path/of/saved/models \
--start_epoch num/of/start/epoch \
--end_epoch num/of/end/epoch \
--test_freq test/frequence
```
## 如何使用 ## 如何使用
模型训练好之后,可以通过如下命令测试单张图片(先使用预训练的RetainNet检测出人的框),得到人体姿态可视化结果: 模型训练好之后,可以通过如下命令测试单张图片(先使用预训练的RetainNet检测出人的框),得到人体姿态可视化结果:
...@@ -111,5 +119,5 @@ python3 inference.py --arch /name/of/tested/network \ ...@@ -111,5 +119,5 @@ python3 inference.py --arch /name/of/tested/network \
## 参考文献 ## 参考文献
- [Simple Baselines for Human Pose Estimation and Tracking](https://arxiv.org/pdf/1804.06208.pdf), Bin Xiao, Haiping Wu, and Yichen Wei - [Simple Baselines for Human Pose Estimation and Tracking](https://arxiv.org/pdf/1804.06208.pdf) Bin Xiao, Haiping Wu, and Yichen Wei
- [Rethinking on Multi-Stage Networks for Human Pose Estimation](https://arxiv.org/pdf/1901.00148.pdf) Wenbo Li1, Zhicheng Wang, Binyi Yin, Qixiang Peng, Yuming Du, Tianzi Xiao, Gang Yu, Hongtao Lu, Yichen Wei and Jian Sun - [Rethinking on Multi-Stage Networks for Human Pose Estimation](https://arxiv.org/pdf/1901.00148.pdf) Wenbo Li1, Zhicheng Wang, Binyi Yin, Qixiang Peng, Yuming Du, Tianzi Xiao, Gang Yu, Hongtao Lu, Yichen Wei and Jian Sun
\ No newline at end of file
...@@ -11,5 +11,5 @@ from .simplebaseline import ( ...@@ -11,5 +11,5 @@ from .simplebaseline import (
simplebaseline_res101, simplebaseline_res101,
simplebaseline_res152, simplebaseline_res152,
) )
from .mspn import mspn_4stage from .mspn import mspn_4stage
...@@ -243,7 +243,7 @@ class MSPN(M.Module): ...@@ -243,7 +243,7 @@ class MSPN(M.Module):
@hub.pretrained( @hub.pretrained(
"https://data.megengine.org.cn/models/weights/mspn_4stage_256x192_0_255_75_2.pkl" "https://data.megengine.org.cn/models/weights/keypoint_models/mspn_4stage_0_255_75_2.pkl"
) )
def mspn_4stage(**kwargs): def mspn_4stage(**kwargs):
model = MSPN( model = MSPN(
......
...@@ -110,7 +110,7 @@ cfg = SimpleBaseline_Config() ...@@ -110,7 +110,7 @@ cfg = SimpleBaseline_Config()
@hub.pretrained( @hub.pretrained(
"https://data.megengine.org.cn/models/weights/simplebaseline50_256x192_0_255_71_2.pkl" "https://data.megengine.org.cn/models/weights/keypoint_models/simplebaseline50_256x192_0_255_71_2.pkl"
) )
def simplebaseline_res50(**kwargs): def simplebaseline_res50(**kwargs):
...@@ -119,7 +119,7 @@ def simplebaseline_res50(**kwargs): ...@@ -119,7 +119,7 @@ def simplebaseline_res50(**kwargs):
@hub.pretrained( @hub.pretrained(
"https://data.megengine.org.cn/models/weights/simplebaseline101_256x192_0_255_72_2.pkl" "https://data.megengine.org.cn/models/weights/keypoint_models/simplebaseline101_256x192_0_255_72_2.pkl"
) )
def simplebaseline_res101(**kwargs): def simplebaseline_res101(**kwargs):
...@@ -128,7 +128,7 @@ def simplebaseline_res101(**kwargs): ...@@ -128,7 +128,7 @@ def simplebaseline_res101(**kwargs):
@hub.pretrained( @hub.pretrained(
"https://data.megengine.org.cn/models/weights/simplebaseline152_256x192_0_255_72_4.pkl" "https://data.megengine.org.cn/models/weights/keypoint_models/simplebaseline152_256x192_0_255_72_4.pkl"
) )
def simplebaseline_res152(**kwargs): def simplebaseline_res152(**kwargs):
......
...@@ -221,6 +221,9 @@ def make_parser(): ...@@ -221,6 +221,9 @@ def make_parser():
) )
parser.add_argument("-se", "--start_epoch", default=-1, type=int) parser.add_argument("-se", "--start_epoch", default=-1, type=int)
parser.add_argument("-ee", "--end_epoch", default=-1, type=int) parser.add_argument("-ee", "--end_epoch", default=-1, type=int)
parser.add_argument("-md", "--model_dir", default="/data/models/simplebaseline_res50_256x192/", type=str)
parser.add_argument("-tf", "--test_freq", default=1, type=int)
parser.add_argument( parser.add_argument(
"-a", "-a",
"--arch", "--arch",
...@@ -266,12 +269,12 @@ def main(): ...@@ -266,12 +269,12 @@ def main():
if args.end_epoch == -1: if args.end_epoch == -1:
args.end_epoch = args.start_epoch args.end_epoch = args.start_epoch
for epoch_num in range(args.start_epoch, args.end_epoch + 1): for epoch_num in range(args.start_epoch, args.end_epoch + 1, args.test_freq):
if args.model: if args.model:
model_file = args.model model_file = args.model
else: else:
model_file = "log-of-{}/epoch_{}.pkl".format( model_file = "{}/epoch_{}.pkl".format(
os.path.basename(args.file).split(".")[0], epoch_num args.model_dir, epoch_num
) )
logger.info("Load Model : %s completed", model_file) logger.info("Load Model : %s completed", model_file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册