未验证 提交 ba3d10fa 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

add paddleseg in matting (#2138)

* add paddleseg

* update codex
上级 7aa70fd5
......@@ -69,7 +69,7 @@
model = hub.Module(name="dim_vgg16_matting")
result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["PATH/TO/TRIMAP"])
result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["/PATH/TO/IMAGE"])
print(result)
```
- ### 3、API
......
......@@ -70,7 +70,7 @@
model = hub.Module(name="dim_vgg16_matting")
result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["PATH/TO/TRIMAP"])
result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["/PATH/TO/IMAGE"])
print(result)
```
- ### 3、API
......
......@@ -68,7 +68,7 @@
import cv2
model = hub.Module(name="gfm_resnet34_matting")
result = model.predict(["/PATH/TO/IMAGE"])
result = model.predict([cv2.imread("/PATH/TO/IMAGE")])
print(result)
```
- ### 3、API
......@@ -150,4 +150,3 @@
* 1.0.0
初始发布
......@@ -69,7 +69,7 @@
import cv2
model = hub.Module(name="gfm_resnet34_matting")
result = model.predict(["/PATH/TO/IMAGE"])
result = model.predict([cv2.imread("/PATH/TO/IMAGE")])
print(result)
```
......
......@@ -11,44 +11,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Union, List, Tuple
from typing import Callable
from typing import List
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from gfm_resnet34_matting.resnet import resnet34
def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> Callable:
"""3x3 convolution with padding"""
return nn.Conv2D(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias_attr=False)
return nn.Conv2D(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias_attr=False)
def conv_up_psp(in_channels: int, out_channels: int, up_sample: float) -> Callable:
return nn.Sequential(nn.Conv2D(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2D(out_channels),
nn.ReLU(),
nn.Upsample(scale_factor=up_sample, mode='bilinear',align_corners = False))
return nn.Sequential(nn.Conv2D(in_channels, out_channels, 3, padding=1), nn.BatchNorm2D(out_channels), nn.ReLU(),
nn.Upsample(scale_factor=up_sample, mode='bilinear', align_corners=False))
def build_bb(in_channels: int, mid_channels: int, out_channels: int) -> Callable:
return nn.Sequential(nn.Conv2D(in_channels, mid_channels, 3, dilation=2,
padding=2), nn.BatchNorm2D(mid_channels), nn.
ReLU(), nn.Conv2D(mid_channels, out_channels, 3,
dilation=2, padding=2), nn.BatchNorm2D(out_channels), nn.ReLU(), nn.Conv2D(out_channels,
out_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(
out_channels), nn.ReLU())
return nn.Sequential(nn.Conv2D(in_channels, mid_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(mid_channels),
nn.ReLU(), nn.Conv2D(mid_channels, out_channels, 3, dilation=2, padding=2),
nn.BatchNorm2D(out_channels), nn.ReLU(),
nn.Conv2D(out_channels, out_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(out_channels),
nn.ReLU())
def build_decoder(in_channels: int, mid_channels_1: int, mid_channels_2: int, out_channels: int,
last_bnrelu: bool, upsample_flag: bool) -> Callable:
def build_decoder(in_channels: int, mid_channels_1: int, mid_channels_2: int, out_channels: int, last_bnrelu: bool,
upsample_flag: bool) -> Callable:
layers = []
layers += [nn.Conv2D(in_channels, mid_channels_1, 3, padding=1), nn.
BatchNorm2D(mid_channels_1), nn.ReLU(), nn.Conv2D(mid_channels_1, mid_channels_2, 3, padding=1), nn.
BatchNorm2D(mid_channels_2), nn.ReLU(), nn.Conv2D(mid_channels_2, out_channels, 3, padding=1)]
layers += [
nn.Conv2D(in_channels, mid_channels_1, 3, padding=1),
nn.BatchNorm2D(mid_channels_1),
nn.ReLU(),
nn.Conv2D(mid_channels_1, mid_channels_2, 3, padding=1),
nn.BatchNorm2D(mid_channels_2),
nn.ReLU(),
nn.Conv2D(mid_channels_2, out_channels, 3, padding=1)
]
if last_bnrelu:
layers += [nn.BatchNorm2D(out_channels), nn.ReLU()]
......@@ -61,6 +62,7 @@ def build_decoder(in_channels: int, mid_channels_1: int, mid_channels_2: int, ou
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
......@@ -90,10 +92,8 @@ class PSPModule(nn.Layer):
def __init__(self, features: paddle.Tensor, out_features: int = 1024, sizes: List[int] = (1, 2, 3, 6)):
super().__init__()
#self.stages = []
self.stages = nn.LayerList([self._make_stage(features, size) for
size in sizes])
self.bottleneck = nn.Conv2D(features * (len(sizes) + 1),
out_features, kernel_size=1)
self.stages = nn.LayerList([self._make_stage(features, size) for size in sizes])
self.bottleneck = nn.Conv2D(features * (len(sizes) + 1), out_features, kernel_size=1)
self.relu = nn.ReLU()
def _make_stage(self, features: paddle.Tensor, size: int) -> Callable:
......@@ -103,7 +103,8 @@ class PSPModule(nn.Layer):
def forward(self, feats: paddle.Tensor) -> paddle.Tensor:
h, w = feats.shape[2], feats.shape[3]
priors = [F.upsample(stage(feats), size=(h, w), mode='bilinear',align_corners = True) for stage in self.stages] + [feats]
priors = [F.upsample(stage(feats), size=(h, w), mode='bilinear', align_corners=True)
for stage in self.stages] + [feats]
bottle = self.bottleneck(paddle.concat(priors, 1))
return self.relu(bottle)
......@@ -113,10 +114,8 @@ class SELayer(nn.Layer):
def __init__(self, channel: int, reduction: int = 4):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.fc = nn.Sequential(nn.Linear(channel, channel // reduction,
bias_attr=False), nn.ReLU(), nn.
Linear(channel // reduction, channel, bias_attr=False), nn.
Sigmoid())
self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias_attr=False), nn.ReLU(),
nn.Linear(channel // reduction, channel, bias_attr=False), nn.Sigmoid())
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
b, c, _, _ = x.size()
......@@ -150,18 +149,15 @@ class GFM(nn.Layer):
self.gd_channel = 2
if self.backbone == 'r34_2b':
self.resnet = resnet34()
self.encoder0 = nn.Sequential(nn.Conv2D(3, 64, 3, padding=1),
nn.BatchNorm2D(64), nn.ReLU())
self.encoder0 = nn.Sequential(nn.Conv2D(3, 64, 3, padding=1), nn.BatchNorm2D(64), nn.ReLU())
self.encoder1 = self.resnet.layer1
self.encoder2 = self.resnet.layer2
self.encoder3 = self.resnet.layer3
self.encoder4 = self.resnet.layer4
self.encoder5 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True
), BasicBlock(512, 512), BasicBlock(512, 512), BasicBlock(
512, 512))
self.encoder6 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True
), BasicBlock(512, 512), BasicBlock(512, 512), BasicBlock(
512, 512))
self.encoder5 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True), BasicBlock(512, 512),
BasicBlock(512, 512), BasicBlock(512, 512))
self.encoder6 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True), BasicBlock(512, 512),
BasicBlock(512, 512), BasicBlock(512, 512))
self.psp_module = PSPModule(512, 512, (1, 3, 5))
self.psp6 = conv_up_psp(512, 512, 2)
self.psp5 = conv_up_psp(512, 512, 4)
......@@ -183,28 +179,19 @@ class GFM(nn.Layer):
self.decoder2_f = build_decoder(256, 128, 128, 64, True, True)
self.decoder1_f = build_decoder(128, 64, 64, 64, True, False)
if self.rosta == 'RIM':
self.decoder0_g_tt = nn.Sequential(nn.Conv2D(64, 3, 3,
padding=1))
self.decoder0_g_ft = nn.Sequential(nn.Conv2D(64, 2, 3,
padding=1))
self.decoder0_g_bt = nn.Sequential(nn.Conv2D(64, 2, 3,
padding=1))
self.decoder0_f_tt = nn.Sequential(nn.Conv2D(64, 1, 3,
padding=1))
self.decoder0_f_ft = nn.Sequential(nn.Conv2D(64, 1, 3,
padding=1))
self.decoder0_f_bt = nn.Sequential(nn.Conv2D(64, 1, 3,
padding=1))
self.decoder0_g_tt = nn.Sequential(nn.Conv2D(64, 3, 3, padding=1))
self.decoder0_g_ft = nn.Sequential(nn.Conv2D(64, 2, 3, padding=1))
self.decoder0_g_bt = nn.Sequential(nn.Conv2D(64, 2, 3, padding=1))
self.decoder0_f_tt = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
self.decoder0_f_ft = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
self.decoder0_f_bt = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
else:
self.decoder0_g = nn.Sequential(nn.Conv2D(64, self.
gd_channel, 3, padding=1))
self.decoder0_g = nn.Sequential(nn.Conv2D(64, self.gd_channel, 3, padding=1))
self.decoder0_f = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
if self.backbone == 'r34':
self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.
bn1, self.resnet.relu)
self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu)
self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.
layer1)
self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.layer1)
self.encoder2 = self.resnet.layer2
self.encoder3 = self.resnet.layer3
self.encoder4 = self.resnet.layer4
......@@ -230,14 +217,11 @@ class GFM(nn.Layer):
self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
else:
self.decoder0_g = build_decoder(128, 64, 64, self.
gd_channel, False, True)
self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
elif self.backbone == 'r101':
self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.
bn1, self.resnet.relu)
self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.
layer1)
self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu)
self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.layer1)
self.encoder2 = self.resnet.layer2
self.encoder3 = self.resnet.layer3
self.encoder4 = self.resnet.layer4
......@@ -263,22 +247,16 @@ class GFM(nn.Layer):
self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
else:
self.decoder0_g = build_decoder(128, 64, 64, self.
gd_channel, False, True)
self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
elif self.backbone == 'd121':
self.encoder0 = nn.Sequential(self.densenet.features.conv0,
self.densenet.features.norm0, self.densenet.features.relu0)
self.encoder1 = nn.Sequential(self.densenet.features.
denseblock1, self.densenet.features.transition1)
self.encoder2 = nn.Sequential(self.densenet.features.
denseblock2, self.densenet.features.transition2)
self.encoder3 = nn.Sequential(self.densenet.features.
denseblock3, self.densenet.features.transition3)
self.encoder4 = nn.Sequential(self.densenet.features.
denseblock4, nn.Conv2D(1024, 512, 3, padding=1), nn.
BatchNorm2D(512), nn.ReLU(),
nn.MaxPool2D(2, 2, ceil_mode=True))
self.encoder0 = nn.Sequential(self.densenet.features.conv0, self.densenet.features.norm0,
self.densenet.features.relu0)
self.encoder1 = nn.Sequential(self.densenet.features.denseblock1, self.densenet.features.transition1)
self.encoder2 = nn.Sequential(self.densenet.features.denseblock2, self.densenet.features.transition2)
self.encoder3 = nn.Sequential(self.densenet.features.denseblock3, self.densenet.features.transition3)
self.encoder4 = nn.Sequential(self.densenet.features.denseblock4, nn.Conv2D(1024, 512, 3, padding=1),
nn.BatchNorm2D(512), nn.ReLU(), nn.MaxPool2D(2, 2, ceil_mode=True))
self.psp_module = PSPModule(512, 512, (1, 3, 5))
self.psp4 = conv_up_psp(512, 256, 2)
self.psp3 = conv_up_psp(512, 128, 4)
......@@ -301,12 +279,10 @@ class GFM(nn.Layer):
self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
else:
self.decoder0_g = build_decoder(128, 64, 64, self.
gd_channel, False, True)
self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
if self.rosta == 'RIM':
self.rim = nn.Sequential(nn.Conv2D(3, 16, 1), SELayer(16), nn.
Conv2D(16, 1, 1))
self.rim = nn.Sequential(nn.Conv2D(3, 16, 1), SELayer(16), nn.Conv2D(16, 1, 1))
def forward(self, input: paddle.Tensor) -> List[paddle.Tensor]:
glance_sigmoid = paddle.zeros(input.shape)
......@@ -325,10 +301,8 @@ class GFM(nn.Layer):
e6 = self.encoder6(e5)
psp = self.psp_module(e6)
d6_g = self.decoder6_g(paddle.concat((psp, e6), 1))
d5_g = self.decoder5_g(paddle.concat((self.psp6(psp),
d6_g), 1))
d4_g = self.decoder4_g(paddle.concat((self.psp5(psp),
d5_g), 1))
d5_g = self.decoder5_g(paddle.concat((self.psp6(psp), d6_g), 1))
d4_g = self.decoder4_g(paddle.concat((self.psp5(psp), d5_g), 1))
else:
psp = self.psp_module(e4)
d4_g = self.decoder4_g(paddle.concat((psp, e4), 1))
......@@ -343,15 +317,11 @@ class GFM(nn.Layer):
else:
d0_g = self.decoder0_g(d1_g)
elif self.rosta == 'RIM':
d0_g_tt = self.decoder0_g_tt(paddle.concat((self.psp1(psp
), d1_g), 1))
d0_g_ft = self.decoder0_g_ft(paddle.concat((self.psp1(psp
), d1_g), 1))
d0_g_bt = self.decoder0_g_bt(paddle.concat((self.psp1(psp
), d1_g), 1))
d0_g_tt = self.decoder0_g_tt(paddle.concat((self.psp1(psp), d1_g), 1))
d0_g_ft = self.decoder0_g_ft(paddle.concat((self.psp1(psp), d1_g), 1))
d0_g_bt = self.decoder0_g_bt(paddle.concat((self.psp1(psp), d1_g), 1))
else:
d0_g = self.decoder0_g(paddle.concat((self.psp1(psp),
d1_g), 1))
d0_g = self.decoder0_g(paddle.concat((self.psp1(psp), d1_g), 1))
if self.rosta == 'RIM':
glance_sigmoid_tt = F.sigmoid(d0_g_tt)
glance_sigmoid_ft = F.sigmoid(d0_g_ft)
......@@ -389,22 +359,16 @@ class GFM(nn.Layer):
else:
focus_sigmoid = F.sigmoid(d0_f)
if self.rosta == 'RIM':
fusion_sigmoid_tt = collaborative_matting('TT',
glance_sigmoid_tt, focus_sigmoid_tt)
fusion_sigmoid_ft = collaborative_matting('FT',
glance_sigmoid_ft, focus_sigmoid_ft)
fusion_sigmoid_bt = collaborative_matting('BT',
glance_sigmoid_bt, focus_sigmoid_bt)
fusion_sigmoid = paddle.concat((fusion_sigmoid_tt,
fusion_sigmoid_ft, fusion_sigmoid_bt), 1)
fusion_sigmoid_tt = collaborative_matting('TT', glance_sigmoid_tt, focus_sigmoid_tt)
fusion_sigmoid_ft = collaborative_matting('FT', glance_sigmoid_ft, focus_sigmoid_ft)
fusion_sigmoid_bt = collaborative_matting('BT', glance_sigmoid_bt, focus_sigmoid_bt)
fusion_sigmoid = paddle.concat((fusion_sigmoid_tt, fusion_sigmoid_ft, fusion_sigmoid_bt), 1)
fusion_sigmoid = self.rim(fusion_sigmoid)
return [[glance_sigmoid_tt, focus_sigmoid_tt, fusion_sigmoid_tt
], [glance_sigmoid_ft, focus_sigmoid_ft, fusion_sigmoid_ft],
[glance_sigmoid_bt, focus_sigmoid_bt, fusion_sigmoid_bt],
fusion_sigmoid]
return [[glance_sigmoid_tt, focus_sigmoid_tt, fusion_sigmoid_tt],
[glance_sigmoid_ft, focus_sigmoid_ft, fusion_sigmoid_ft],
[glance_sigmoid_bt, focus_sigmoid_bt, fusion_sigmoid_bt], fusion_sigmoid]
else:
fusion_sigmoid = collaborative_matting(self.rosta,
glance_sigmoid, focus_sigmoid)
fusion_sigmoid = collaborative_matting(self.rosta, glance_sigmoid, focus_sigmoid)
return glance_sigmoid, focus_sigmoid, fusion_sigmoid
......@@ -412,7 +376,7 @@ def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
if rosta == 'TT':
values = paddle.max(glance_sigmoid, axis=1)
index = paddle.argmax(glance_sigmoid, axis=1)
index = index[:, None, :, :].float()
index = index[:, None, :, :].cast(paddle.float32)
bg_mask = index.clone()
bg_mask[bg_mask == 2] = 1
bg_mask = 1 - bg_mask
......@@ -428,13 +392,13 @@ def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
elif rosta == 'BT':
values = paddle.max(glance_sigmoid, axis=1)
index = paddle.argmax(glance_sigmoid, axis=1)
index = index[:, None, :, :].float()
index = index[:, None, :, :].cast(paddle.float32)
fusion_sigmoid = index - focus_sigmoid
fusion_sigmoid[fusion_sigmoid < 0] = 0
else:
values = paddle.max(glance_sigmoid, axis=1)
index = paddle.argmax(glance_sigmoid, axis=1)
index = index[:, None, :, :].float()
index = index[:, None, :, :].cast(paddle.float32)
fusion_sigmoid = index + focus_sigmoid
fusion_sigmoid[fusion_sigmoid > 1] = 1
return fusion_sigmoid
......@@ -442,6 +406,6 @@ def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
if __name__ == "__main__":
model = GFM()
x = paddle.ones([1,3, 256,256])
x = paddle.ones([1, 3, 256, 256])
result = model(x)
print(x)
......@@ -11,31 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import argparse
from typing import Callable, Union, List, Tuple
from typing import List
from typing import Union
from PIL import Image
import numpy as np
import cv2
import scipy
import gfm_resnet34_matting.processor as P
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlehub.module.module import moduleinfo
import paddlehub.vision.transforms as T
from paddlehub.module.module import moduleinfo, runnable, serving
from gfm_resnet34_matting.gfm import GFM
from PIL import Image
from skimage.transform import resize
from gfm_resnet34_matting.gfm import GFM
import gfm_resnet34_matting.processor as P
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
@moduleinfo(
name="gfm_resnet34_matting",
@moduleinfo(name="gfm_resnet34_matting",
type="CV/matting",
author="paddlepaddle",
author_email="",
......@@ -53,7 +49,7 @@ class GFMResNet34(nn.Layer):
Paper link (Arxiv): https://arxiv.org/abs/2010.16188
"""
def __init__(self, pretrained: str=None):
def __init__(self, pretrained: str = None):
super(GFMResNet34, self).__init__()
self.model = GFM()
......@@ -76,44 +72,42 @@ class GFMResNet34(nn.Layer):
tensor_img = self.scale_image(img, h, w)
return tensor_img
def scale_image(self, img: np.ndarray, h: int, w: int, ratio: float = 1/3):
def scale_image(self, img: np.ndarray, h: int, w: int, ratio: float = 1 / 3):
new_h = min(1600, h - (h % 32))
new_w = min(1600, w - (w % 32))
resize_h = int(h*ratio)
resize_w = int(w*ratio)
resize_h = int(h * ratio)
resize_w = int(w * ratio)
new_h = min(1600, resize_h - (resize_h % 32))
new_w = min(1600, resize_w - (resize_w % 32))
scale_img = resize(img,(new_h,new_w)) * 255
scale_img = resize(img, (new_h, new_w)) * 255
tensor_img = paddle.to_tensor(scale_img.astype(np.float32)[np.newaxis, :, :, :])
tensor_img = tensor_img.transpose([0,3,1,2])
tensor_img = tensor_img.transpose([0, 3, 1, 2])
return tensor_img
def inference_img_scale(self, input: paddle.Tensor) -> List[paddle.Tensor]:
pred_global, pred_local, pred_fusion = self.model(input)
pred_global = P.gen_trimap_from_segmap_e2e(pred_global)
pred_local = pred_local.numpy()[0,0,:,:]
pred_fusion = pred_fusion.numpy()[0,0,:,:]
pred_local = pred_local.numpy()[0, 0, :, :]
pred_fusion = pred_fusion.numpy()[0, 0, :, :]
return pred_global, pred_local, pred_fusion
def predict(self, image_list: list, visualization: bool =True, save_path: str = "gfm_resnet34_matting_output"):
def predict(self, image_list: list, visualization: bool = True, save_path: str = "gfm_resnet34_matting_output"):
self.model.eval()
result = []
with paddle.no_grad():
for i, img in enumerate(image_list):
if isinstance(img, str):
img = np.array(Image.open(img))[:,:,:3]
img = np.array(Image.open(img))[:, :, :3]
else:
img = img[:,:,::-1]
img = img[:, :, ::-1]
h, w, _ = img.shape
tensor_img = self.preprocess(img, h, w)
pred_glance_1, pred_focus_1, pred_fusion_1 = self.inference_img_scale(tensor_img)
pred_glance_1 = resize(pred_glance_1,(h,w)) * 255.0
tensor_img = self.scale_image(img, h, w, 1/2)
pred_glance_1 = resize(pred_glance_1, (h, w)) * 255.0
tensor_img = self.scale_image(img, h, w, 1 / 2)
pred_glance_2, pred_focus_2, pred_fusion_2 = self.inference_img_scale(tensor_img)
pred_focus_2 = resize(pred_focus_2,(h,w))
pred_focus_2 = resize(pred_focus_2, (h, w))
pred_fusion = P.get_masked_local_from_global_test(pred_glance_1, pred_focus_2)
pred_fusion = (pred_fusion * 255).astype(np.uint8)
if visualization:
......@@ -142,8 +136,7 @@ class GFMResNet34(nn.Layer):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
......@@ -154,7 +147,9 @@ class GFMResNet34(nn.Layer):
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.predict(image_list=[args.input_path], save_path=args.output_dir, visualization=args.visualization)
results = self.predict(image_list=[args.input_path],
save_path=args.output_dir,
visualization=args.visualization)
return results
......@@ -163,14 +158,17 @@ class GFMResNet34(nn.Layer):
Add the command config options.
"""
self.arg_config_group.add_argument(
'--output_dir', type=str, default="gfm_resnet34_matting_output", help="The directory to save output images.")
self.arg_config_group.add_argument(
'--visualization', type=bool, default=True, help="whether to save output as images.")
self.arg_config_group.add_argument('--output_dir',
type=str,
default="gfm_resnet34_matting_output",
help="The directory to save output images.")
self.arg_config_group.add_argument('--visualization',
type=bool,
default=True,
help="whether to save output as images.")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册