postprocess.py 2.1 KB
Newer Older
M
MRXLT 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
import cv2
import numpy as np


def generate_colormap(num_classes):
    color_map = num_classes * [0, 0, 0]
    for i in range(0, num_classes):
        j = 0
        lab = i
        while lab:
            color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
            color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
            color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
            j += 1
            lab >>= 3
    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    return color_map


class SegPostprocess(object):
    def __init__(self, class_num):
        self.class_num = class_num

    def __call__(self, image_with_result):
        if "filename" not in image_with_result:
            raise ("filename should be specified in postprocess")
        img_name = image_with_result["filename"]
        ori_img = cv2.imread(img_name, -1)
        ori_shape = ori_img.shape
        mask = None
        for key in image_with_result:
            if ".lod" in key or "filename" in key:
                continue
            mask = image_with_result[key][0]
        if mask is None:
            raise ("segment mask should be specified in postprocess")
        mask = np.argmax(mask, axis=0)
        mask = mask.astype("uint8")
        mask_png = mask
        score_png = mask_png[:, :, np.newaxis]
        score_png = np.concatenate([score_png] * 3, axis=2)

        color_map = generate_colormap(self.class_num)
        for i in range(score_png.shape[0]):
            for j in range(score_png.shape[1]):
                score_png[i, j] = color_map[score_png[i, j, 0]]
        ext_pos = img_name.rfind(".")
        img_name_fix = img_name[:ext_pos] + "_" + img_name[ext_pos + 1:]
        mask_save_name = img_name_fix + "_mask.png"
        cv2.imwrite(mask_save_name, mask_png, [cv2.CV_8UC1])
        vis_result_name = img_name_fix + "_result.png"
        result_png = score_png

        result_png = cv2.resize(result_png, (ori_shape[1], ori_shape[0]),
                                fx=0,
                                fy=0,
                                interpolation=cv2.INTER_CUBIC)
        cv2.imwrite(vis_result_name, result_png, [cv2.CV_8UC1])