# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle.fluid as fluid
from config import cfg
import cv2
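
# Configuration fields this module reads from cfg (imported from config):
#   cfg.multi_scales -- list of (height, width) inference scales
#   cfg.MEAN, cfg.STD -- per-channel normalization statistics applied to inputs scaled to [0, 1]
#   cfg.flip -- whether to stack a horizontally flipped copy for test-time augmentation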


def get_affine_points(src_shape, dst_shape, rot_grad=0):
    # Compute three pairs of corresponding points between the source image and
    # the affine-transformed image. In the transformed image the three points are
    # its center, [w/2, 0] and [0, 0]; the pairs are later fed to
    # cv2.getAffineTransform to build the warp in either direction.
    if dst_shape[0] == 0 or dst_shape[1] == 0:
        raise Exception('scale shape should not be 0')

    # Rotation angle, converted from degrees to radians
    rotation = rot_grad * np.pi / 180.0
    sin_v = np.sin(rotation)
    cos_v = np.cos(rotation)

    dst_ratio = float(dst_shape[0]) / dst_shape[1]
    h, w = src_shape
    src_ratio = float(h) / w if w != 0 else 0
    affine_shape = [h, h * dst_ratio] if src_ratio > dst_ratio \
                    else [w / dst_ratio, w]
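    # affine_shape is the (height, width) of the source-image region that the
    # three point pairs below map onto dst_shape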

    # Three points in the original image
    points = [[0, 0]] * 3
    points[0] = (np.array([w, h]) - 1) * 0.5
    points[1] = points[0] + 0.5 * affine_shape[0] * np.array([sin_v, -cos_v])
    points[2] = points[1] - 0.5 * affine_shape[1] * np.array([cos_v, sin_v])

    # The corresponding three points in the affine-transformed image
    # (points_trans[2] stays at [0, 0])
    points_trans = [[0, 0]] * 3
    points_trans[0] = (np.array(dst_shape[::-1]) - 1) * 0.5
    points_trans[1] = [points_trans[0][0], 0]

    return points, points_trans


def preprocess(im):
    # Data preprocessing for the ACE2P model: builds one normalized NCHW input
    # per scale in cfg.multi_scales
    im_shape = im.shape[:2]
    input_images = []
    for i, scale in enumerate(cfg.multi_scales):
        # Corresponding points between the image and its affine-transformed version at this scale
        points, points_trans = get_affine_points(im_shape, scale)
        # Build the affine matrix from the point correspondences
        trans = cv2.getAffineTransform(
            np.float32(points), np.float32(points_trans))
        # Warp the image with the affine matrix (warpAffine expects dsize as (width, height), hence scale[::-1])
        input = cv2.warpAffine(im, trans, scale[::-1], flags=cv2.INTER_LINEAR)

        # Subtract the mean, divide by the standard deviation, and convert to NCHW
        input = input.astype(np.float32)
        input = (input / 255. - np.array(cfg.MEAN)) / np.array(cfg.STD)
        input = input.transpose(2, 0, 1).astype(np.float32)
        input = np.expand_dims(input, 0)
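        # input now has shape (1, 3, scale_height, scale_width)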

        # Optionally stack a horizontally flipped copy for test-time augmentation
        if cfg.flip:
            flip_input = input[:, :, :, ::-1]
            input_images.append(np.vstack((input, flip_input)))
        else:
            input_images.append(input)

    return input_images


def multi_scale_test(exe, test_prog, feed_name, fetch_list, input_ims,
                     im_shape):

    # Some categories come in left/right pairs; flipped_idx gives, for channels
    # 14-19, the label each one corresponds to after a horizontal flip
    flipped_idx = (15, 14, 17, 16, 19, 18)
    ms_outputs = []

    # Multi-scale inference: run the model once per scale
    for idx, scale in enumerate(cfg.multi_scales):
        input_im = input_ims[idx]
        parsing_output = exe.run(
            program=test_prog,
            feed={feed_name[0]: input_im},
            fetch_list=fetch_list)
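        # parsing_output holds one array per fetch target; parsing_output[0] has
        # shape (N, C, h, w), where N is 2 when cfg.flip stacked a mirrored copy
        # in preprocess and 1 otherwise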
        output = parsing_output[0][0]
        if cfg.flip:
            # When the input was flipped, swap the left/right channels of the
            # mirrored prediction, flip it back horizontally and average it with
            # the original output
            flipped_output = parsing_output[0][1]
            flipped_output[14:20, :, :] = flipped_output[flipped_idx, :, :]
            flipped_output = flipped_output[:, :, ::-1]
            output += flipped_output
            output *= 0.5

        output = np.transpose(output, [1, 2, 0])
        # Affine-transform the logits back to the original image size
        points, points_trans = get_affine_points(im_shape, scale)
        M = cv2.getAffineTransform(np.float32(points_trans), np.float32(points))
        logits_result = cv2.warpAffine(
            output, M, im_shape[::-1], flags=cv2.INTER_LINEAR)
        ms_outputs.append(logits_result)

    # Average the multi-scale predictions and take the class with the highest score per pixel
    ms_fused_parsing_output = np.stack(ms_outputs)
    ms_fused_parsing_output = np.mean(ms_fused_parsing_output, axis=0)
    parsing = np.argmax(ms_fused_parsing_output, axis=2)
    return parsing, ms_fused_parsing_output
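

# A minimal usage sketch, not part of the original module: it assumes a Paddle
# inference model exported to a directory named 'ace2p_model' and an input image
# 'demo.jpg' (both placeholder names), and wires preprocess() and
# multi_scale_test() together the way an inference script would.
if __name__ == '__main__':
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    # load_inference_model returns the test program plus its feed/fetch targets
    test_prog, feed_name, fetch_list = fluid.io.load_inference_model(
        dirname='ace2p_model', executor=exe)

    im = cv2.imread('demo.jpg')
    # Depending on how cfg.MEAN/cfg.STD were computed, a BGR->RGB conversion
    # (cv2.cvtColor) may be needed here.
    input_images = preprocess(im)
    parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list,
                                       input_images, im.shape[:2])
    cv2.imwrite('parsing_result.png', parsing.astype(np.uint8))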