提交 34c613c6 编写于 作者: D dengkaipeng

add stack_outputs

上级 6cb262e7
......@@ -1083,7 +1083,7 @@ class Model(fluid.dygraph.Layer):
return eval_result
def predict(self, test_data, batch_size=1, num_workers=0):
def predict(self, test_data, batch_size=1, num_workers=0, stack_outputs=True):
"""
FIXME: add more comments and usage
Args:
......@@ -1096,6 +1096,12 @@ class Model(fluid.dygraph.Layer):
num_workers (int): the number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this parameter will be ignored.
stack_output (bool): whether stack output field like a batch, as for an output
filed of a sample is in shape [X, Y], test_data contains N samples, predict
output field will be in shape [N, X, Y] if stack_output is True, and will
be a length N list in shape [[X, Y], [X, Y], ....[X, Y]] if stack_outputs
is False. stack_outputs as False is used for LoDTensor output situation,
it is recommended set as True if outputs contains no LoDTensor. Default False
"""
if fluid.in_dygraph_mode():
......@@ -1127,10 +1133,11 @@ class Model(fluid.dygraph.Layer):
data = flatten(data)
outputs.append(self.test(data[:len(self._inputs)]))
# NOTE: we do not stack or concanate here for output
# lod tensor may loss its detail info, just pack sample
# list data to batch data
# NOTE: for lod tensor output, we should not stack outputs
# for stacking may loss its detail info
outputs = list(zip(*outputs))
if stack_outputs:
outputs = [np.stack(outs, axis=0) for outs in outputs]
self._test_dataloader = None
if test_loader is not None and self._adapter._nranks > 1 \
......
......@@ -78,7 +78,7 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层
模型目前支持COCO数据集格式的数据读入和精度评估,我们同时提供了将转换为COCO数据集的格式的Pascal VOC数据集下载,可通过如下命令下载。
```bash
python dataset/voc/download.py
python dataset/download_voc.py
```
数据目录结构如下:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,9 +13,14 @@
# limitations under the License.
import os
import os.path as osp
import sys
import tarfile
from paddle.dataset.common import download
from download import _download
import logging
logger = logging.getLogger(__name__)
DATASETS = {
'voc': [
......@@ -26,7 +31,7 @@ DATASETS = {
def download_decompress_file(data_dir, url, md5):
logger.info("Downloading from {}".format(url))
tar_file = download(url, data_dir, md5)
tar_file = _download(url, data_dir, md5)
logger.info("Decompressing {}".format(tar_file))
with tarfile.open(tar_file) as tf:
tf.extractall(path=data_dir)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -147,7 +147,7 @@ def main():
if FLAGS.eval_only:
if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True)
preds = model.predict(loader)
preds = model.predict(loader, stack_outputs=False)
_, _, _, img_ids, bboxes = preds
anno_path = os.path.join(FLAGS.data, 'annotations/instances_val2017.json')
......
......@@ -18,14 +18,27 @@ from __future__ import print_function
import numpy as np
from PIL import Image, ImageDraw
from colormap import colormap
import logging
logger = logging.getLogger(__name__)
__all__ = ['draw_bbox']
def color_map(num_classes):
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = np.array(color_map).reshape(-1, 3)
return color_map
def draw_bbox(image, catid2name, bboxes, threshold):
"""
Draw bbox on image
......@@ -38,7 +51,7 @@ def draw_bbox(image, catid2name, bboxes, threshold):
draw = ImageDraw.Draw(image)
catid2color = {}
color_list = colormap(rgb=True)[:40]
color_list = color_map(len(catid2name))
for bbox in bboxes:
catid, score, xmin, ymin, xmax, ymax = bbox
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册