提交 34c613c6 编写于 作者: D dengkaipeng

add stack_outputs

上级 6cb262e7
...@@ -1083,7 +1083,7 @@ class Model(fluid.dygraph.Layer): ...@@ -1083,7 +1083,7 @@ class Model(fluid.dygraph.Layer):
return eval_result return eval_result
def predict(self, test_data, batch_size=1, num_workers=0): def predict(self, test_data, batch_size=1, num_workers=0, stack_outputs=True):
""" """
FIXME: add more comments and usage FIXME: add more comments and usage
Args: Args:
...@@ -1096,6 +1096,12 @@ class Model(fluid.dygraph.Layer): ...@@ -1096,6 +1096,12 @@ class Model(fluid.dygraph.Layer):
num_workers (int): the number of subprocess to load data, 0 for no subprocess num_workers (int): the number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this parameter will be ignored. both the instance of Dataloader, this parameter will be ignored.
stack_output (bool): whether stack output field like a batch, as for an output
filed of a sample is in shape [X, Y], test_data contains N samples, predict
output field will be in shape [N, X, Y] if stack_output is True, and will
be a length N list in shape [[X, Y], [X, Y], ....[X, Y]] if stack_outputs
is False. stack_outputs as False is used for LoDTensor output situation,
it is recommended set as True if outputs contains no LoDTensor. Default False
""" """
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
...@@ -1127,10 +1133,11 @@ class Model(fluid.dygraph.Layer): ...@@ -1127,10 +1133,11 @@ class Model(fluid.dygraph.Layer):
data = flatten(data) data = flatten(data)
outputs.append(self.test(data[:len(self._inputs)])) outputs.append(self.test(data[:len(self._inputs)]))
# NOTE: we do not stack or concanate here for output # NOTE: for lod tensor output, we should not stack outputs
# lod tensor may loss its detail info, just pack sample # for stacking may loss its detail info
# list data to batch data
outputs = list(zip(*outputs)) outputs = list(zip(*outputs))
if stack_outputs:
outputs = [np.stack(outs, axis=0) for outs in outputs]
self._test_dataloader = None self._test_dataloader = None
if test_loader is not None and self._adapter._nranks > 1 \ if test_loader is not None and self._adapter._nranks > 1 \
......
...@@ -78,7 +78,7 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层 ...@@ -78,7 +78,7 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层
模型目前支持COCO数据集格式的数据读入和精度评估,我们同时提供了将转换为COCO数据集的格式的Pascal VOC数据集下载,可通过如下命令下载。 模型目前支持COCO数据集格式的数据读入和精度评估,我们同时提供了将转换为COCO数据集的格式的Pascal VOC数据集下载,可通过如下命令下载。
```bash ```bash
python dataset/voc/download.py python dataset/download_voc.py
``` ```
数据目录结构如下: 数据目录结构如下:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. #you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,9 +13,14 @@ ...@@ -13,9 +13,14 @@
# limitations under the License. # limitations under the License.
import os import os
import os.path as osp
import sys
import tarfile import tarfile
from paddle.dataset.common import download from download import _download
import logging
logger = logging.getLogger(__name__)
DATASETS = { DATASETS = {
'voc': [ 'voc': [
...@@ -26,7 +31,7 @@ DATASETS = { ...@@ -26,7 +31,7 @@ DATASETS = {
def download_decompress_file(data_dir, url, md5): def download_decompress_file(data_dir, url, md5):
logger.info("Downloading from {}".format(url)) logger.info("Downloading from {}".format(url))
tar_file = download(url, data_dir, md5) tar_file = _download(url, data_dir, md5)
logger.info("Decompressing {}".format(tar_file)) logger.info("Decompressing {}".format(tar_file))
with tarfile.open(tar_file) as tf: with tarfile.open(tar_file) as tf:
tf.extractall(path=data_dir) tf.extractall(path=data_dir)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -147,7 +147,7 @@ def main(): ...@@ -147,7 +147,7 @@ def main():
if FLAGS.eval_only: if FLAGS.eval_only:
if FLAGS.weights is not None: if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True) model.load(FLAGS.weights, reset_optimizer=True)
preds = model.predict(loader) preds = model.predict(loader, stack_outputs=False)
_, _, _, img_ids, bboxes = preds _, _, _, img_ids, bboxes = preds
anno_path = os.path.join(FLAGS.data, 'annotations/instances_val2017.json') anno_path = os.path.join(FLAGS.data, 'annotations/instances_val2017.json')
......
...@@ -18,14 +18,27 @@ from __future__ import print_function ...@@ -18,14 +18,27 @@ from __future__ import print_function
import numpy as np import numpy as np
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from colormap import colormap
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ['draw_bbox'] __all__ = ['draw_bbox']
def color_map(num_classes):
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = np.array(color_map).reshape(-1, 3)
return color_map
def draw_bbox(image, catid2name, bboxes, threshold): def draw_bbox(image, catid2name, bboxes, threshold):
""" """
Draw bbox on image Draw bbox on image
...@@ -38,7 +51,7 @@ def draw_bbox(image, catid2name, bboxes, threshold): ...@@ -38,7 +51,7 @@ def draw_bbox(image, catid2name, bboxes, threshold):
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
catid2color = {} catid2color = {}
color_list = colormap(rgb=True)[:40] color_list = color_map(len(catid2name))
for bbox in bboxes: for bbox in bboxes:
catid, score, xmin, ymin, xmax, ymax = bbox catid, score, xmin, ymin, xmax, ymax = bbox
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册