Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • PaddleDetection
  • Issue
  • #98

P
PaddleDetection
  • 项目概览

PaddlePaddle / PaddleDetection
大约 2 年 前同步成功

通知 708
Star 11112
Fork 2696
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 184
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 40
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
PaddleDetection
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 184
    • Issue 184
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 40
    • 合并请求 40
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板
已关闭
开放中
Opened 12月 10, 2019 by saxon_zh@saxon_zhGuest

【供大家参考】YOLOv3导出模型后使用Python预测及可视化示例

Created by: qingqing01

cd PaddleDetection
export PYTHONPATH=`pwd`:$PYTHONPATH

导出模型:

python tools/export_model.py -c configs/dcn/yolov3_r50vd_dcn.yml \
        -o weights=https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn_imagenet.tar \
        --output_dir=inference_model \

导出的模型存储在 inference_model/yolov3_r50vd_dcn 下面.

文件 infer.py 如下。在 PaddleDetection 根目录下运行 python infer.py，即可得到可视化结果 test_demo.jpg（脚本中的模型目录和测试图片均使用相对路径）。

import paddle.fluid as fluid
import numpy as np
import cv2
from PIL import Image, ImageDraw
from ppdet.utils.coco_eval import get_category_info

def Permute(im, channel_first=True, to_bgr=False):
    """Optionally move the channel axis to the front and/or reverse channel order.

    Args:
        im: 3-D image array; expected in HWC layout (as returned by
            ``DecodeImage`` / ``ResizeImage`` in this script).
        channel_first: if True, transpose HWC -> CHW.
        to_bgr: if True, reverse the order of the three colour channels.

    Returns:
        The permuted array.
    """
    if channel_first:
        # swapaxes(1, 2) then swapaxes(1, 0) turns HWC into CHW.
        im = np.swapaxes(im, 1, 2)
        im = np.swapaxes(im, 1, 0)
    if to_bgr:
        # The channel axis is first only after the transpose above; in HWC
        # layout it is the last axis. Indexing axis 0 unconditionally would
        # reorder image rows instead of channels when channel_first=False.
        if channel_first:
            im = im[[2, 1, 0], :, :]
        else:
            im = im[:, :, [2, 1, 0]]
    return im


def DecodeImage(im_path):
    """Read an image file from disk and return it as a 3-channel RGB array.

    Args:
        im_path: path of the image file to read.

    Returns:
        The decoded image in HWC layout with channels in RGB order.

    Raises:
        ValueError: if the file contents cannot be decoded as an image.
    """
    with open(im_path, 'rb') as f:
        data = np.frombuffer(f.read(), dtype='uint8')
    # OpenCV decodes into BGR channel order.
    im = cv2.imdecode(data, cv2.IMREAD_COLOR)
    # cv2.imdecode returns None (rather than raising) on undecodable data;
    # without this check the failure shows up later as an obscure
    # cv2.cvtColor error.
    if im is None:
        raise ValueError('failed to decode image: {}'.format(im_path))
    # The rest of the pipeline expects RGB order.
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)


def ResizeImage(im, target_size=608, max_size=0):
    """Resize an HWC image for the detector.

    With ``max_size == 0`` (the default) the height and width are scaled
    independently to ``target_size``, so the result is (approximately)
    ``target_size x target_size`` and the aspect ratio is not preserved.
    With ``max_size != 0`` a single scale is used for both axes: the shorter
    side is scaled to ``target_size``, unless that would make the longer side
    exceed ``max_size``. In that case the longer side is scaled to
    ``max_size`` instead.

    Args:
        im: 3-D image array (height, width, channels).
        target_size: target length in pixels (see above).
        max_size: optional cap on the longer side; 0 disables the
            aspect-preserving mode.

    Returns:
        The resized image.

    Raises:
        ValueError: if ``im`` is not 3-dimensional.
        ZeroDivisionError: if the height or width is 0.
    """
    if len(im.shape) != 3:
        # ``ImageError`` is not defined anywhere in this script, so raising it
        # produced a NameError; raise a real exception type instead.
        raise ValueError('image is not 3-dimensional.')
    im_shape = im.shape
    im_size_min = np.min(im_shape[0:2])
    im_size_max = np.max(im_shape[0:2])
    if float(im_size_min) == 0:
        raise ZeroDivisionError('min size of image is 0')
    if max_size != 0:
        im_scale = float(target_size) / float(im_size_min)
        # Prevent the biggest axis from being more than max_size.
        if np.round(im_scale * im_size_max) > max_size:
            im_scale = float(max_size) / float(im_size_max)
        im_scale_x = im_scale
        im_scale_y = im_scale
    else:
        im_scale_x = float(target_size) / float(im_shape[1])
        im_scale_y = float(target_size) / float(im_shape[0])

    return cv2.resize(
        im,
        None,
        None,
        fx=im_scale_x,
        fy=im_scale_y,
        interpolation=cv2.INTER_CUBIC)  # == 2, the value used previously


def NormalizeImage(im, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), is_scale=True):
    """Normalize the image channel-wise and return a new float32 array.

    Operators:
        1. (optional) Scale the image to [0, 1] by dividing by 255.
        2. Subtract ``mean`` from each pixel and divide by ``std``.

    Args:
        im: HWC image array; the last axis is broadcast against ``mean``/``std``.
        mean: per-channel mean (sequence of 3 floats).
        std: per-channel standard deviation (sequence of 3 floats).
        is_scale: whether to divide by 255 before normalizing.

    Returns:
        A float32 array of the same shape as ``im``. The input array is
        never modified.
    """
    # Always copy. With ``copy=False`` a float32 input was returned as-is,
    # and the in-place operations below then overwrote the caller's array
    # whenever is_scale=False.
    im = im.astype(np.float32)
    mean = np.array(mean)[np.newaxis, np.newaxis, :]
    std = np.array(std)[np.newaxis, np.newaxis, :]
    if is_scale:
        im /= 255.0
    # In-place ops keep the float32 dtype even though mean/std are float64.
    im -= mean
    im /= std
    return im


def Prepocess(img_path):
    """Load ``img_path`` and turn it into a batched network input.

    Returns:
        A ``(tensor, shape)`` pair. ``tensor`` is the decoded, resized,
        normalized, channel-first image with a leading batch axis of size 1.
        ``shape`` is the original ``(height, width)`` before resizing.
    """
    decoded = DecodeImage(img_path)
    # Capture the size before resizing; the caller feeds it to the model
    # as the second input.
    orig_hw = decoded.shape[:2]
    chw = Permute(NormalizeImage(ResizeImage(decoded)))
    return np.expand_dims(chw, axis=0), orig_hw

def colormap(rgb=False):
    """Return a fixed palette of 79 colours scaled to the 0-255 range.

    Args:
        rgb: if True, rows are in RGB order; otherwise the channel order of
            each row is reversed (BGR).

    Returns:
        A float32 array of shape (79, 3).
    """
    flat = np.array([
        0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494,
        0.184, 0.556, 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078,
        0.184, 0.300, 0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000,
        1.000, 0.500, 0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000,
        0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, 0.333, 0.667,
        0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000,
        0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000,
        1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000,
        0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500,
        0.333, 1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500, 0.667,
        0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333,
        0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000,
        0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333,
        0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000,
        1.000, 0.667, 0.333, 1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000,
        1.000, 0.000, 1.000, 1.000, 0.333, 1.000, 1.000, 0.667, 1.000, 0.167,
        0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000,
        0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000,
        0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000,
        0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000,
        0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833,
        0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.286,
        0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, 0.714, 0.714,
        0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000
    ], dtype=np.float32)
    # One row per colour, scaled from [0, 1] to [0, 255].
    palette = flat.reshape((-1, 3)) * 255
    return palette if rgb else palette[:, ::-1]

def draw_bbox(image, catid2name, bboxes, threshold):
    """Draw labelled detection boxes onto a PIL image.

    Args:
        image: PIL image to draw on. It is modified in place and returned.
        catid2name: mapping from category id to display name.
        bboxes: iterable of dicts with ``category_id``, ``bbox``
            (``[xmin, ymin, w, h]``) and ``score`` keys, as built by
            ``bbox2out``.
        threshold: detections with ``score < threshold`` are skipped.

    Returns:
        The same ``image`` object.
    """
    draw = ImageDraw.Draw(image)

    catid2color = {}
    color_list = colormap(rgb=True)[:40]
    for dt in bboxes:
        catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
        if score < threshold:
            continue

        xmin, ymin, w, h = bbox
        xmax = xmin + w
        ymax = ymin + h

        if catid not in catid2color:
            # Each category gets a random palette entry, so colours differ
            # between runs.
            idx = np.random.randint(len(color_list))
            catid2color[catid] = color_list[idx]
        # colormap() yields float32 components; PIL expects integer colour
        # values, so convert explicitly instead of passing numpy floats.
        color = tuple(int(c) for c in catid2color[catid])

        # draw bbox
        draw.line(
            [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
             (xmin, ymin)],
            width=2,
            fill=color)

        # draw label on a filled background just above the box
        text = "{} {:.2f}".format(catid2name[catid], score)
        # ImageDraw.textsize was removed in Pillow 10; use textbbox
        # (Pillow >= 8.0) when available and fall back for older versions.
        if hasattr(draw, 'textbbox'):
            left, top, right, bottom = draw.textbbox((0, 0), text)
            tw, th = right - left, bottom - top
        else:
            tw, th = draw.textsize(text)
        draw.rectangle(
            [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
        draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))

    return image

def clip_bbox(bbox):
    """Clamp the first four entries of ``bbox`` into the interval [0, 1].

    Returns:
        ``(xmin, ymin, xmax, ymax)`` as a tuple; extra entries are ignored.
    """
    def _clamp(value):
        return max(min(value, 1.), 0.)

    return tuple(_clamp(bbox[pos]) for pos in range(4))

def bbox2out(results, clsid2catid, is_bbox_normalized=False):
    """Convert raw detector output into COCO-style ``[x, y, w, h]`` records.

    Args:
        results: list of dicts. Each has a ``bbox`` entry holding a
                 ``(boxes, lengths)`` pair, where every row of ``boxes`` is
                 ``[clsid, score, xmin, ymin, xmax, ymax]`` and
                 ``lengths[0][i]`` is the number of rows belonging to image
                 ``i``. If is_bbox_normalized=True, an ``im_shape`` entry
                 (per-image ``(height, width)``) is also required.
        clsid2catid: class id to category id map of COCO2017 dataset.
        is_bbox_normalized: whether or not bbox is normalized to [0, 1].

    Returns:
        A list of dicts with ``category_id``, ``bbox`` (``[xmin, ymin, w, h]``)
        and ``score`` keys.
    """
    xywh_res = []
    for t in results:
        bboxes = t['bbox'][0]
        # Test for None before touching ``.shape``; in the reverse order the
        # None check could never fire and an AttributeError was raised.
        # A (1, 1) array is also skipped (presumably the model's "no
        # detections" placeholder -- confirm against the exporter).
        if bboxes is None or bboxes.shape == (1, 1):
            continue
        lengths = t['bbox'][1][0]

        k = 0  # running row index into ``bboxes`` across all images
        for i, num in enumerate(lengths):
            for _ in range(num):
                clsid, score, xmin, ymin, xmax, ymax = bboxes[k].tolist()
                catid = clsid2catid[int(clsid)]

                if is_bbox_normalized:
                    xmin, ymin, xmax, ymax = \
                            clip_bbox([xmin, ymin, xmax, ymax])
                    w = xmax - xmin
                    h = ymax - ymin
                    im_height, im_width = t['im_shape'][i].tolist()
                    xmin *= im_width
                    ymin *= im_height
                    w *= im_width
                    h *= im_height
                else:
                    # Pixel coordinates are treated as inclusive, hence +1.
                    w = xmax - xmin + 1
                    h = ymax - ymin + 1

                xywh_res.append({
                    'category_id': catid,
                    'bbox': [xmin, ymin, w, h],
                    'score': score
                })
                k += 1
    return xywh_res

def test(model_dir="inference_model/yolov3_r50vd_dcn",
         img_path="demo/000000014439.jpg",
         output_path="test_demo.jpg",
         threshold=0.5,
         use_gpu=True):
    """Run the exported YOLOv3 model on one image and save a visualization.

    Calling ``test()`` with no arguments reproduces the original demo:
    GPU 0, the default model/image paths, a 0.5 score threshold, and output
    written to ``test_demo.jpg``.

    Args:
        model_dir: directory holding the exported ``__model__``/``__params__``.
        img_path: image to run detection on.
        output_path: where the annotated JPEG is written.
        threshold: minimum score for a box to be drawn.
        use_gpu: run on ``CUDAPlace(0)`` if True, otherwise on the CPU.

    Returns:
        The list of detection dicts produced by ``bbox2out``.
    """
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    test_img, img_shape = Prepocess(img_path)
    # Second model input: the original (height, width) as a (1, 2) int32 array.
    img_shape = np.array(img_shape).reshape(1, 2).astype('int32')

    inference_program, feed_target_names, fetch_targets = (
        fluid.io.load_inference_model(
            dirname=model_dir, executor=exe,
            model_filename='__model__', params_filename='__params__'))
    # return_numpy=False keeps outs[0] as a tensor object, so
    # recursive_sequence_lengths() (per-image box counts) is available.
    outs = exe.run(inference_program,
                   feed={feed_target_names[0]: test_img,
                         feed_target_names[1]: img_shape},
                   fetch_list=fetch_targets,
                   return_numpy=False)
    res = {
        'bbox': (np.array(outs[0]), outs[0].recursive_sequence_lengths()),
        'im_shape': img_shape
    }

    clsid2catid, catid2name = get_category_info(None, False, True)
    bbox_results = bbox2out([res], clsid2catid, False)
    print(bbox_results)

    image = Image.open(img_path).convert('RGB')
    image = draw_bbox(image, catid2name, bbox_results, threshold)
    image.save(output_path, quality=95)
    return bbox_results

if __name__ == '__main__':
    # Script entry point: run the end-to-end demo.
    test()
指派人
分配到
无
里程碑
无
分配里程碑
工时统计
无
截止日期
无
标识: paddlepaddle/PaddleDetection#98
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7