提交 2eb9948b 编写于 作者: D dongdaxiang

add image segmentation example with unet

上级 e2759c78
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, File2Image, Resize, Transpose, BGR2RGB, SegPostprocess
import sys
import cv2
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9494"])
preprocess = Sequential(
[File2Image(), Resize(
(512, 512), interpolation=cv2.INTER_LINEAR)])
postprocess = SegPostprocess(2)
for i in range(100):
filename = sys.argv[2]
im = preprocess(filename)
fetch_map = client.predict(feed={"image": im}, fetch=["output"])
fetch_map["filename"] = filename
postprocess(fetch_map)
...@@ -11,4 +11,4 @@ ...@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, RCNNPostprocess from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, RCNNPostprocess, SegPostprocess
...@@ -23,6 +23,63 @@ import json ...@@ -23,6 +23,63 @@ import json
_cv2_interpolation_to_str = {cv2.INTER_LINEAR: "cv2.INTER_LINEAR", None: "None"} _cv2_interpolation_to_str = {cv2.INTER_LINEAR: "cv2.INTER_LINEAR", None: "None"}
def generate_colormap(num_classes):
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
class SegPostprocess(object):
def __init__(self, class_num):
self.class_num = class_num
def __call__(self, image_with_result):
if "filename" not in image_with_result:
raise ("filename should be specified in postprocess")
img_name = image_with_result["filename"]
ori_img = cv2.imread(img_name, -1)
ori_shape = ori_img.shape
mask = None
for key in image_with_result:
if ".lod" in key or "filename" in key:
continue
mask = image_with_result[key]
if mask is None:
raise ("segment mask should be specified in postprocess")
mask = mask.astype("uint8")
mask_png = mask.reshape((512, 512, 1))
#score_png = mask_png[:, :, np.newaxis]
score_png = mask_png
score_png = np.concatenate([score_png] * 3, axis=2)
color_map = generate_colormap(self.class_num)
for i in range(score_png.shape[0]):
for j in range(score_png.shape[1]):
score_png[i, j] = color_map[score_png[i, j, 0]]
ext_pos = img_name.rfind(".")
img_name_fix = img_name[:ext_pos] + "_" + img_name[ext_pos + 1:]
mask_save_name = img_name_fix + "_mask.png"
cv2.imwrite(mask_save_name, mask_png, [cv2.CV_8UC1])
vis_result_name = img_name_fix + "_result.png"
result_png = score_png
result_png = cv2.resize(
result_png,
ori_shape[:2],
fx=0,
fy=0,
interpolation=cv2.INTER_CUBIC)
cv2.imwrite(vis_result_name, result_png, [cv2.CV_8UC1])
class RCNNPostprocess(object): class RCNNPostprocess(object):
def __init__(self, label_file, output_dir): def __init__(self, label_file, output_dir):
self.output_dir = output_dir self.output_dir = output_dir
...@@ -238,6 +295,12 @@ class File2Image(object): ...@@ -238,6 +295,12 @@ class File2Image(object):
sample = fin.read() sample = fin.read()
data = np.fromstring(sample, np.uint8) data = np.fromstring(sample, np.uint8)
img = cv2.imdecode(data, cv2.IMREAD_COLOR) img = cv2.imdecode(data, cv2.IMREAD_COLOR)
'''
img = cv2.imread(img_path, -1)
channels = img.shape[2]
ori_h = img.shape[0]
ori_w = img.shape[1]
'''
return img return img
def __repr__(self): def __repr__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册