未验证 提交 46158530 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #56 from heavengate/add_metric_test

add test_metrics
......@@ -39,8 +39,8 @@ TSM模型是将Temporal Shift Module插入到ResNet网络中构建的视频分
```bash
git clone https://github.com/PaddlePaddle/hapi
cd hapi
export PYTHONPATH=$PYTHONPATH:`pwd`
cd tsm
export PYTHONPATH=`pwd`:$PYTHONPATH
cd examples/tsm
```
### 数据准备
......@@ -141,6 +141,8 @@ python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --inf
2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6
```
**注意:** 推断时`--infer_file`需要指定到pickle文件路径。
## 参考论文
- [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han
......
......@@ -26,6 +26,7 @@ from check import check_gpu, check_version
from modeling import tsm_resnet50
from kinetics_dataset import KineticsDataset
from transforms import *
from utils import print_arguments
import logging
logger = logging.getLogger(__name__)
......@@ -56,7 +57,7 @@ def main():
model.load(FLAGS.weights, reset_optimizer=True)
imgs, label = dataset[0]
pred = model.test([imgs[np.newaxis, :]])
pred = model.test_batch([imgs[np.newaxis, :]])
pred = labels[np.argmax(pred)]
logger.info("Sample {} predict label: {}, ground truth label: {}" \
.format(FLAGS.infer_file, pred, labels[int(label)]))
......@@ -86,6 +87,7 @@ if __name__ == '__main__':
type=str,
help="weights path for evaluation")
FLAGS = parser.parse_args()
print_arguments(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
......
......@@ -113,7 +113,7 @@ class KineticsDataset(Dataset):
if self.transform:
imgs, label = self.transform(imgs, label)
return imgs, np.array([label])
return imgs, np.array([label]).astype('int64')
@property
def num_classes(self):
......
......@@ -31,6 +31,7 @@ from modeling import tsm_resnet50
from check import check_gpu, check_version
from kinetics_dataset import KineticsDataset
from transforms import *
from utils import print_arguments
def make_optimizer(step_per_epoch, parameter_list=None):
......@@ -106,7 +107,7 @@ def main():
eval_data=val_dataset,
epochs=FLAGS.epoch,
batch_size=FLAGS.batch_size,
save_dir='tsm_checkpoint',
save_dir=FLAGS.save_dir or 'tsm_checkpoint',
num_workers=FLAGS.num_workers,
drop_last=True,
shuffle=True)
......@@ -150,7 +151,14 @@ if __name__ == '__main__':
default=None,
type=str,
help="weights path for evaluation")
parser.add_argument(
"-s",
"--save_dir",
default=None,
type=str,
help="directory path for checkpoint saving, default ./yolo_checkpoint")
FLAGS = parser.parse_args()
print_arguments(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import logging
logger = logging.getLogger(__name__)
__all__ = ['print_ar']
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
logger.info("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
logger.info("%s: %s" % (arg, value))
logger.info("------------------------------------------------")
......@@ -53,8 +53,8 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层
```bash
git clone https://github.com/PaddlePaddle/hapi
cd hapi
export PYTHONPATH=$PYTHONPATH:`pwd`
cd tsm
export PYTHONPATH=`pwd`:$PYTHONPATH
cd examples/yolov3
```
#### 安装COCO-API
......@@ -126,13 +126,13 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=
使用如下方式进行多卡训练:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py -m paddle.distributed.launch --data=<path/to/dataset> --batch_size=16 -d
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=16 -d
```
### 模型评估
YOLOv3模型输出为LoDTensor,只支持使用batch_size为1进行评估,可通过如下两种方式进行模型评估。
YOLOv3模型输出为LoDTensor,只支持使用单卡且batch_size为1进行评估,可通过如下两种方式进行模型评估。
1. 自动下载Paddle发布的[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams)权重评估
......@@ -180,7 +180,7 @@ python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.
2. 加载checkpoint进行精度评估
```bash
python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg --weights=yolo_checkpoint/mo_mixup/final
python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg --weights=yolo_checkpoint/no_mixup/final
```
推断结果可视化图像会保存于`--output`指定的文件夹下,默认保存于`./output`目录。
......
......@@ -28,7 +28,7 @@ from hapi.model import Model, Input, set_device
from modeling import yolov3_darknet53, YoloLoss
from transforms import *
from utils import print_arguments
from visualizer import draw_bbox
import logging
......@@ -91,7 +91,7 @@ def main():
img_id = np.array([0]).astype('int64')[np.newaxis, :]
img_shape = np.array([h, w]).astype('int32')[np.newaxis, :]
_, bboxes = model.test([img_id, img_shape, img])
_, bboxes = model.test_batch([img_id, img_shape, img])
vis_img = draw_bbox(orig_img, cat2name, bboxes, FLAGS.draw_threshold)
save_name = get_save_image_name(FLAGS.output_dir, FLAGS.infer_image)
......@@ -121,6 +121,7 @@ if __name__ == '__main__':
"-w", "--weights", default=None, type=str,
help="path to weights for inference")
FLAGS = parser.parse_args()
print_arguments(FLAGS)
assert os.path.isfile(FLAGS.infer_image), \
"infer_image {} not a file".format(FLAGS.infer_image)
assert os.path.isfile(FLAGS.label_list), \
......
......@@ -33,6 +33,7 @@ from modeling import yolov3_darknet53, YoloLoss
from coco import COCODataset
from coco_metric import COCOMetric
from transforms import *
from utils import print_arguments
NUM_MAX_BOXES = 50
......@@ -171,16 +172,18 @@ def main():
if FLAGS.resume is not None:
model.load(FLAGS.resume)
save_dir = FLAGS.save_dir or 'yolo_checkpoint'
model.fit(train_data=loader,
epochs=FLAGS.epoch - FLAGS.no_mixup_epoch,
save_dir="yolo_checkpoint/mixup",
save_dir=os.path.join(save_dir, "mixup"),
save_freq=10)
# do not use image mixup transfrom in the last FLAGS.no_mixup_epoch epoches
dataset.mixup = False
model.fit(train_data=loader,
epochs=FLAGS.no_mixup_epoch,
save_dir="yolo_checkpoint/no_mixup",
save_dir=os.path.join(save_dir, "no_mixup"),
save_freq=5)
......@@ -233,6 +236,13 @@ if __name__ == '__main__':
default=None,
type=str,
help="path to weights for evaluation")
parser.add_argument(
"-s",
"--save_dir",
default=None,
type=str,
help="directory path for checkpoint saving, default ./yolo_checkpoint")
FLAGS = parser.parse_args()
print_arguments(FLAGS)
assert FLAGS.data, "error: must provide data path"
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import logging
logger = logging.getLogger(__name__)
__all__ = ['print_ar']
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
logger.info("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
logger.info("%s: %s" % (arg, value))
logger.info("------------------------------------------------")
......@@ -123,7 +123,7 @@ class Flowers(Dataset):
if self.transform is not None:
image = self.transform(image)
return image, label
return image, label.astype('int64')
def __len__(self):
return len(self.indexes)
......@@ -144,7 +144,7 @@ class MNIST(Dataset):
for i in range(buffer_size):
self.images.append(images[i, :])
self.labels.append(np.array([labels[i]]))
self.labels.append(np.array([labels[i]]).astype('int64'))
def __getitem__(self, idx):
image, label = self.images[idx], self.labels[idx]
......
......@@ -116,7 +116,7 @@ class Accuracy(Metric):
def add_metric_op(self, pred, label, *args):
pred = fluid.layers.argsort(pred, descending=True)[1][:, :self.maxk]
correct = pred == label
return correct
return fluid.layers.cast(correct, dtype='float32')
def update(self, correct, *args):
accs = []
......@@ -143,7 +143,7 @@ class Accuracy(Metric):
if self.maxk != 1:
self._name = ['{}_top{}'.format(name, k) for k in self.topk]
else:
self._name = ['acc']
self._name = [name]
def name(self):
return self._name
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from hapi.metrics import *
from hapi.utils import to_list
def accuracy(pred, label, topk=(1, )):
maxk = max(topk)
pred = np.argsort(pred)[:, ::-1][:, :maxk]
correct = (pred == np.repeat(label, maxk, 1))
batch_size = label.shape[0]
res = []
for k in topk:
correct_k = correct[:, :k].sum()
res.append(correct_k / batch_size)
return res
def convert_to_one_hot(y, C):
oh = np.random.random((y.shape[0], C)).astype('float32') * .5
for i in range(y.shape[0]):
oh[i, int(y[i])] = 1.
return oh
class TestAccuracyDynamic(unittest.TestCase):
def setUp(self):
self.topk = (1, )
self.class_num = 5
self.sample_num = 1000
self.name = None
def random_pred_label(self):
label = np.random.randint(0, self.class_num, (self.sample_num, 1)).astype('int64')
pred = np.random.randint(0, self.class_num, (self.sample_num, 1)).astype('int32')
pred_one_hot = convert_to_one_hot(pred, self.class_num)
pred_one_hot = pred_one_hot.astype('float32')
return label, pred_one_hot
def test_main(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
acc = Accuracy(topk=self.topk, name=self.name)
for i in range(10):
label, pred = self.random_pred_label()
label_var = to_variable(label)
pred_var = to_variable(pred)
state = to_list(acc.add_metric_op(pred_var, label_var))
acc.update(*[s.numpy() for s in state])
res_m = acc.accumulate()
res_f = accuracy(pred, label, self.topk)
assert np.all(np.isclose(np.array(res_m), np.array(res_f), rtol=1e-3)), \
"Accuracy precision error: {} != {}".format(res_m, res_f)
acc.reset()
assert np.sum(acc.total) == 0
assert np.sum(acc.count) == 0
class TestAccuracyDynamicMultiTopk(TestAccuracyDynamic):
def setUp(self):
self.topk = (1, 5)
self.class_num = 10
self.sample_num = 1000
self.name = "accuracy"
class TestAccuracyStatic(TestAccuracyDynamic):
def test_main(self):
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
pred = fluid.data(name='pred', shape=[None, self.class_num], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
acc = Accuracy(topk=self.topk, name=self.name)
state = acc.add_metric_op(pred, label)
exe = fluid.Executor(fluid.CPUPlace())
compiled_main_prog = fluid.CompiledProgram(main_prog)
for i in range(10):
label, pred = self.random_pred_label()
state_ret = exe.run(compiled_main_prog,
feed={'pred': pred, 'label': label},
fetch_list=[s.name for s in to_list(state)],
return_numpy=True)
acc.update(*state_ret)
res_m = acc.accumulate()
res_f = accuracy(pred, label, self.topk)
assert np.all(np.isclose(np.array(res_m), np.array(res_f), rtol=1e-3)), \
"Accuracy precision error: {} != {}".format(res_m, res_f)
acc.reset()
assert np.sum(acc.total) == 0
assert np.sum(acc.count) == 0
class TestAccuracyStaticMultiTopk(TestAccuracyStatic):
def setUp(self):
self.topk = (1, 5)
self.class_num = 10
self.sample_num = 1000
self.name = "accuracy"
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册