Unverified    Commit ba3d10fa authored by jm_12138, committed by GitHub

add paddleseg in matting (#2138)

* add paddleseg

* update codex
Parent 7aa70fd5
@@ -69,7 +69,7 @@
    model = hub.Module(name="dim_vgg16_matting")
-   result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["PATH/TO/TRIMAP"])
+   result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["/PATH/TO/IMAGE"])
    print(result)
    ```
  - ### 3、API
......
@@ -70,7 +70,7 @@
    model = hub.Module(name="dim_vgg16_matting")
-   result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["PATH/TO/TRIMAP"])
+   result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["/PATH/TO/IMAGE"])
    print(result)
    ```
  - ### 3、API
......
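For context, a minimal sketch of the dim_vgg16_matting call documented in the two README hunks above (both paths are placeholders; paddlehub and the module are assumed to be installed):

```python
# Minimal sketch of the documented dim_vgg16_matting usage; unlike the
# GFM module below, this model also expects a trimap for each input image.
# Both paths are placeholders, not real files.
import paddlehub as hub

model = hub.Module(name="dim_vgg16_matting")
result = model.predict(image_list=["/PATH/TO/IMAGE"],
                       trimap_list=["/PATH/TO/TRIMAP"])
print(result)
```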
@@ -68,7 +68,7 @@
    import cv2
    model = hub.Module(name="gfm_resnet34_matting")
-   result = model.predict(["/PATH/TO/IMAGE"])
+   result = model.predict([cv2.imread("/PATH/TO/IMAGE")])
    print(result)
    ```
  - ### 3、API
@@ -150,4 +150,3 @@
  * 1.0.0
    初始发布
@@ -69,7 +69,7 @@
    import cv2
    model = hub.Module(name="gfm_resnet34_matting")
-   result = model.predict(["/PATH/TO/IMAGE"])
+   result = model.predict([cv2.imread("/PATH/TO/IMAGE")])
    print(result)
    ```
......
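The gfm_resnet34_matting hunks above switch the documented input from a path string to a decoded image; a minimal end-to-end sketch of that updated usage follows (the path is a placeholder, and the visualization/save_path keywords come from the predict signature in module.py further down):

```python
# Minimal sketch of the updated gfm_resnet34_matting usage: the README now
# passes decoded BGR ndarrays (cv2.imread) instead of path strings.
# The image path is a placeholder.
import cv2
import paddlehub as hub

model = hub.Module(name="gfm_resnet34_matting")
result = model.predict([cv2.imread("/PATH/TO/IMAGE")],
                       visualization=True,
                       save_path="gfm_resnet34_matting_output")
print(result)
```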
@@ -11,44 +11,45 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- from typing import Callable, Union, List, Tuple
+ from typing import Callable
+ from typing import List
  import paddle
  import paddle.nn as nn
  import paddle.nn.functional as F
  from gfm_resnet34_matting.resnet import resnet34
  def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> Callable:
      """3x3 convolution with padding"""
-     return nn.Conv2D(in_planes, out_planes, kernel_size=3, stride=stride,
-         padding=1, bias_attr=False)
+     return nn.Conv2D(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias_attr=False)
  def conv_up_psp(in_channels: int, out_channels: int, up_sample: float) -> Callable:
-     return nn.Sequential(nn.Conv2D(in_channels, out_channels, 3, padding=1),
-         nn.BatchNorm2D(out_channels),
-         nn.ReLU(),
-         nn.Upsample(scale_factor=up_sample, mode='bilinear',align_corners = False))
+     return nn.Sequential(nn.Conv2D(in_channels, out_channels, 3, padding=1), nn.BatchNorm2D(out_channels), nn.ReLU(),
+                          nn.Upsample(scale_factor=up_sample, mode='bilinear', align_corners=False))
  def build_bb(in_channels: int, mid_channels: int, out_channels: int) -> Callable:
-     return nn.Sequential(nn.Conv2D(in_channels, mid_channels, 3, dilation=2,
-         padding=2), nn.BatchNorm2D(mid_channels), nn.
-         ReLU(), nn.Conv2D(mid_channels, out_channels, 3,
-         dilation=2, padding=2), nn.BatchNorm2D(out_channels), nn.ReLU(), nn.Conv2D(out_channels,
-         out_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(
-         out_channels), nn.ReLU())
+     return nn.Sequential(nn.Conv2D(in_channels, mid_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(mid_channels),
+                          nn.ReLU(), nn.Conv2D(mid_channels, out_channels, 3, dilation=2, padding=2),
+                          nn.BatchNorm2D(out_channels), nn.ReLU(),
+                          nn.Conv2D(out_channels, out_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(out_channels),
+                          nn.ReLU())
- def build_decoder(in_channels: int, mid_channels_1: int, mid_channels_2: int, out_channels: int,
-     last_bnrelu: bool, upsample_flag: bool) -> Callable:
+ def build_decoder(in_channels: int, mid_channels_1: int, mid_channels_2: int, out_channels: int, last_bnrelu: bool,
+                   upsample_flag: bool) -> Callable:
      layers = []
-     layers += [nn.Conv2D(in_channels, mid_channels_1, 3, padding=1), nn.
-         BatchNorm2D(mid_channels_1), nn.ReLU(), nn.Conv2D(mid_channels_1, mid_channels_2, 3, padding=1), nn.
-         BatchNorm2D(mid_channels_2), nn.ReLU(), nn.Conv2D(mid_channels_2, out_channels, 3, padding=1)]
+     layers += [
+         nn.Conv2D(in_channels, mid_channels_1, 3, padding=1),
+         nn.BatchNorm2D(mid_channels_1),
+         nn.ReLU(),
+         nn.Conv2D(mid_channels_1, mid_channels_2, 3, padding=1),
+         nn.BatchNorm2D(mid_channels_2),
+         nn.ReLU(),
+         nn.Conv2D(mid_channels_2, out_channels, 3, padding=1)
+     ]
      if last_bnrelu:
          layers += [nn.BatchNorm2D(out_channels), nn.ReLU()]
@@ -61,6 +62,7 @@ def build_decoder(in_channels: int, mid_channels_1: int, mid_channels_2: int, ou
  class BasicBlock(nn.Layer):
      expansion = 1
      def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample=None):
          super(BasicBlock, self).__init__()
          self.conv1 = conv3x3(inplanes, planes, stride)
@@ -90,10 +92,8 @@ class PSPModule(nn.Layer):
      def __init__(self, features: paddle.Tensor, out_features: int = 1024, sizes: List[int] = (1, 2, 3, 6)):
          super().__init__()
          #self.stages = []
-         self.stages = nn.LayerList([self._make_stage(features, size) for
-             size in sizes])
-         self.bottleneck = nn.Conv2D(features * (len(sizes) + 1),
-             out_features, kernel_size=1)
+         self.stages = nn.LayerList([self._make_stage(features, size) for size in sizes])
+         self.bottleneck = nn.Conv2D(features * (len(sizes) + 1), out_features, kernel_size=1)
          self.relu = nn.ReLU()
      def _make_stage(self, features: paddle.Tensor, size: int) -> Callable:
@@ -103,7 +103,8 @@ class PSPModule(nn.Layer):
      def forward(self, feats: paddle.Tensor) -> paddle.Tensor:
          h, w = feats.shape[2], feats.shape[3]
-         priors = [F.upsample(stage(feats), size=(h, w), mode='bilinear',align_corners = True) for stage in self.stages] + [feats]
+         priors = [F.upsample(stage(feats), size=(h, w), mode='bilinear', align_corners=True)
+                   for stage in self.stages] + [feats]
          bottle = self.bottleneck(paddle.concat(priors, 1))
          return self.relu(bottle)
@@ -113,10 +114,8 @@ class SELayer(nn.Layer):
      def __init__(self, channel: int, reduction: int = 4):
          super(SELayer, self).__init__()
          self.avg_pool = nn.AdaptiveAvgPool2D(1)
-         self.fc = nn.Sequential(nn.Linear(channel, channel // reduction,
-             bias_attr=False), nn.ReLU(), nn.
-             Linear(channel // reduction, channel, bias_attr=False), nn.
-             Sigmoid())
+         self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias_attr=False), nn.ReLU(),
+                                 nn.Linear(channel // reduction, channel, bias_attr=False), nn.Sigmoid())
      def forward(self, x: paddle.Tensor) -> paddle.Tensor:
          b, c, _, _ = x.size()
@@ -150,18 +149,15 @@ class GFM(nn.Layer):
          self.gd_channel = 2
          if self.backbone == 'r34_2b':
              self.resnet = resnet34()
-             self.encoder0 = nn.Sequential(nn.Conv2D(3, 64, 3, padding=1),
-                 nn.BatchNorm2D(64), nn.ReLU())
+             self.encoder0 = nn.Sequential(nn.Conv2D(3, 64, 3, padding=1), nn.BatchNorm2D(64), nn.ReLU())
              self.encoder1 = self.resnet.layer1
              self.encoder2 = self.resnet.layer2
              self.encoder3 = self.resnet.layer3
              self.encoder4 = self.resnet.layer4
-             self.encoder5 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True
-                 ), BasicBlock(512, 512), BasicBlock(512, 512), BasicBlock(
-                 512, 512))
-             self.encoder6 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True
-                 ), BasicBlock(512, 512), BasicBlock(512, 512), BasicBlock(
-                 512, 512))
+             self.encoder5 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True), BasicBlock(512, 512),
+                                           BasicBlock(512, 512), BasicBlock(512, 512))
+             self.encoder6 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True), BasicBlock(512, 512),
+                                           BasicBlock(512, 512), BasicBlock(512, 512))
              self.psp_module = PSPModule(512, 512, (1, 3, 5))
              self.psp6 = conv_up_psp(512, 512, 2)
              self.psp5 = conv_up_psp(512, 512, 4)
@@ -183,28 +179,19 @@ class GFM(nn.Layer):
              self.decoder2_f = build_decoder(256, 128, 128, 64, True, True)
              self.decoder1_f = build_decoder(128, 64, 64, 64, True, False)
              if self.rosta == 'RIM':
-                 self.decoder0_g_tt = nn.Sequential(nn.Conv2D(64, 3, 3,
-                     padding=1))
-                 self.decoder0_g_ft = nn.Sequential(nn.Conv2D(64, 2, 3,
-                     padding=1))
-                 self.decoder0_g_bt = nn.Sequential(nn.Conv2D(64, 2, 3,
-                     padding=1))
-                 self.decoder0_f_tt = nn.Sequential(nn.Conv2D(64, 1, 3,
-                     padding=1))
-                 self.decoder0_f_ft = nn.Sequential(nn.Conv2D(64, 1, 3,
-                     padding=1))
-                 self.decoder0_f_bt = nn.Sequential(nn.Conv2D(64, 1, 3,
-                     padding=1))
+                 self.decoder0_g_tt = nn.Sequential(nn.Conv2D(64, 3, 3, padding=1))
+                 self.decoder0_g_ft = nn.Sequential(nn.Conv2D(64, 2, 3, padding=1))
+                 self.decoder0_g_bt = nn.Sequential(nn.Conv2D(64, 2, 3, padding=1))
+                 self.decoder0_f_tt = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
+                 self.decoder0_f_ft = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
+                 self.decoder0_f_bt = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
              else:
-                 self.decoder0_g = nn.Sequential(nn.Conv2D(64, self.
-                     gd_channel, 3, padding=1))
+                 self.decoder0_g = nn.Sequential(nn.Conv2D(64, self.gd_channel, 3, padding=1))
                  self.decoder0_f = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
          if self.backbone == 'r34':
-             self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.
-                 bn1, self.resnet.relu)
-             self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.
-                 layer1)
+             self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu)
+             self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.layer1)
              self.encoder2 = self.resnet.layer2
              self.encoder3 = self.resnet.layer3
              self.encoder4 = self.resnet.layer4
@@ -230,14 +217,11 @@ class GFM(nn.Layer):
                  self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
                  self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
              else:
-                 self.decoder0_g = build_decoder(128, 64, 64, self.
-                     gd_channel, False, True)
+                 self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
                  self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
          elif self.backbone == 'r101':
-             self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.
-                 bn1, self.resnet.relu)
-             self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.
-                 layer1)
+             self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu)
+             self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.layer1)
              self.encoder2 = self.resnet.layer2
              self.encoder3 = self.resnet.layer3
              self.encoder4 = self.resnet.layer4
@@ -263,22 +247,16 @@ class GFM(nn.Layer):
                  self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
                  self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
              else:
-                 self.decoder0_g = build_decoder(128, 64, 64, self.
-                     gd_channel, False, True)
+                 self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
                  self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
          elif self.backbone == 'd121':
-             self.encoder0 = nn.Sequential(self.densenet.features.conv0,
-                 self.densenet.features.norm0, self.densenet.features.relu0)
-             self.encoder1 = nn.Sequential(self.densenet.features.
-                 denseblock1, self.densenet.features.transition1)
-             self.encoder2 = nn.Sequential(self.densenet.features.
-                 denseblock2, self.densenet.features.transition2)
-             self.encoder3 = nn.Sequential(self.densenet.features.
-                 denseblock3, self.densenet.features.transition3)
-             self.encoder4 = nn.Sequential(self.densenet.features.
-                 denseblock4, nn.Conv2D(1024, 512, 3, padding=1), nn.
-                 BatchNorm2D(512), nn.ReLU(),
-                 nn.MaxPool2D(2, 2, ceil_mode=True))
+             self.encoder0 = nn.Sequential(self.densenet.features.conv0, self.densenet.features.norm0,
+                                           self.densenet.features.relu0)
+             self.encoder1 = nn.Sequential(self.densenet.features.denseblock1, self.densenet.features.transition1)
+             self.encoder2 = nn.Sequential(self.densenet.features.denseblock2, self.densenet.features.transition2)
+             self.encoder3 = nn.Sequential(self.densenet.features.denseblock3, self.densenet.features.transition3)
+             self.encoder4 = nn.Sequential(self.densenet.features.denseblock4, nn.Conv2D(1024, 512, 3, padding=1),
+                                           nn.BatchNorm2D(512), nn.ReLU(), nn.MaxPool2D(2, 2, ceil_mode=True))
              self.psp_module = PSPModule(512, 512, (1, 3, 5))
              self.psp4 = conv_up_psp(512, 256, 2)
              self.psp3 = conv_up_psp(512, 128, 4)
@@ -301,12 +279,10 @@ class GFM(nn.Layer):
                  self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
                  self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
              else:
-                 self.decoder0_g = build_decoder(128, 64, 64, self.
-                     gd_channel, False, True)
+                 self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
                  self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
          if self.rosta == 'RIM':
-             self.rim = nn.Sequential(nn.Conv2D(3, 16, 1), SELayer(16), nn.
-                 Conv2D(16, 1, 1))
+             self.rim = nn.Sequential(nn.Conv2D(3, 16, 1), SELayer(16), nn.Conv2D(16, 1, 1))
      def forward(self, input: paddle.Tensor) -> List[paddle.Tensor]:
          glance_sigmoid = paddle.zeros(input.shape)
@@ -325,10 +301,8 @@ class GFM(nn.Layer):
              e6 = self.encoder6(e5)
              psp = self.psp_module(e6)
              d6_g = self.decoder6_g(paddle.concat((psp, e6), 1))
-             d5_g = self.decoder5_g(paddle.concat((self.psp6(psp),
-                 d6_g), 1))
-             d4_g = self.decoder4_g(paddle.concat((self.psp5(psp),
-                 d5_g), 1))
+             d5_g = self.decoder5_g(paddle.concat((self.psp6(psp), d6_g), 1))
+             d4_g = self.decoder4_g(paddle.concat((self.psp5(psp), d5_g), 1))
          else:
              psp = self.psp_module(e4)
              d4_g = self.decoder4_g(paddle.concat((psp, e4), 1))
@@ -343,15 +317,11 @@ class GFM(nn.Layer):
              else:
                  d0_g = self.decoder0_g(d1_g)
          elif self.rosta == 'RIM':
-             d0_g_tt = self.decoder0_g_tt(paddle.concat((self.psp1(psp
-                 ), d1_g), 1))
-             d0_g_ft = self.decoder0_g_ft(paddle.concat((self.psp1(psp
-                 ), d1_g), 1))
-             d0_g_bt = self.decoder0_g_bt(paddle.concat((self.psp1(psp
-                 ), d1_g), 1))
+             d0_g_tt = self.decoder0_g_tt(paddle.concat((self.psp1(psp), d1_g), 1))
+             d0_g_ft = self.decoder0_g_ft(paddle.concat((self.psp1(psp), d1_g), 1))
+             d0_g_bt = self.decoder0_g_bt(paddle.concat((self.psp1(psp), d1_g), 1))
          else:
-             d0_g = self.decoder0_g(paddle.concat((self.psp1(psp),
-                 d1_g), 1))
+             d0_g = self.decoder0_g(paddle.concat((self.psp1(psp), d1_g), 1))
          if self.rosta == 'RIM':
              glance_sigmoid_tt = F.sigmoid(d0_g_tt)
              glance_sigmoid_ft = F.sigmoid(d0_g_ft)
@@ -389,22 +359,16 @@ class GFM(nn.Layer):
          else:
              focus_sigmoid = F.sigmoid(d0_f)
          if self.rosta == 'RIM':
-             fusion_sigmoid_tt = collaborative_matting('TT',
-                 glance_sigmoid_tt, focus_sigmoid_tt)
-             fusion_sigmoid_ft = collaborative_matting('FT',
-                 glance_sigmoid_ft, focus_sigmoid_ft)
-             fusion_sigmoid_bt = collaborative_matting('BT',
-                 glance_sigmoid_bt, focus_sigmoid_bt)
-             fusion_sigmoid = paddle.concat((fusion_sigmoid_tt,
-                 fusion_sigmoid_ft, fusion_sigmoid_bt), 1)
+             fusion_sigmoid_tt = collaborative_matting('TT', glance_sigmoid_tt, focus_sigmoid_tt)
+             fusion_sigmoid_ft = collaborative_matting('FT', glance_sigmoid_ft, focus_sigmoid_ft)
+             fusion_sigmoid_bt = collaborative_matting('BT', glance_sigmoid_bt, focus_sigmoid_bt)
+             fusion_sigmoid = paddle.concat((fusion_sigmoid_tt, fusion_sigmoid_ft, fusion_sigmoid_bt), 1)
              fusion_sigmoid = self.rim(fusion_sigmoid)
-             return [[glance_sigmoid_tt, focus_sigmoid_tt, fusion_sigmoid_tt
-                 ], [glance_sigmoid_ft, focus_sigmoid_ft, fusion_sigmoid_ft],
-                 [glance_sigmoid_bt, focus_sigmoid_bt, fusion_sigmoid_bt],
-                 fusion_sigmoid]
+             return [[glance_sigmoid_tt, focus_sigmoid_tt, fusion_sigmoid_tt],
+                     [glance_sigmoid_ft, focus_sigmoid_ft, fusion_sigmoid_ft],
+                     [glance_sigmoid_bt, focus_sigmoid_bt, fusion_sigmoid_bt], fusion_sigmoid]
          else:
-             fusion_sigmoid = collaborative_matting(self.rosta,
-                 glance_sigmoid, focus_sigmoid)
+             fusion_sigmoid = collaborative_matting(self.rosta, glance_sigmoid, focus_sigmoid)
              return glance_sigmoid, focus_sigmoid, fusion_sigmoid
@@ -412,7 +376,7 @@ def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
      if rosta == 'TT':
          values = paddle.max(glance_sigmoid, axis=1)
          index = paddle.argmax(glance_sigmoid, axis=1)
-         index = index[:, None, :, :].float()
+         index = index[:, None, :, :].cast(paddle.float32)
          bg_mask = index.clone()
          bg_mask[bg_mask == 2] = 1
          bg_mask = 1 - bg_mask
@@ -428,13 +392,13 @@ def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
      elif rosta == 'BT':
          values = paddle.max(glance_sigmoid, axis=1)
          index = paddle.argmax(glance_sigmoid, axis=1)
-         index = index[:, None, :, :].float()
+         index = index[:, None, :, :].cast(paddle.float32)
          fusion_sigmoid = index - focus_sigmoid
          fusion_sigmoid[fusion_sigmoid < 0] = 0
      else:
          values = paddle.max(glance_sigmoid, axis=1)
          index = paddle.argmax(glance_sigmoid, axis=1)
-         index = index[:, None, :, :].float()
+         index = index[:, None, :, :].cast(paddle.float32)
          fusion_sigmoid = index + focus_sigmoid
          fusion_sigmoid[fusion_sigmoid > 1] = 1
      return fusion_sigmoid
@@ -442,6 +406,6 @@ def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
  if __name__ == "__main__":
      model = GFM()
-     x = paddle.ones([1,3, 256,256])
+     x = paddle.ones([1, 3, 256, 256])
      result = model(x)
      print(x)
@@ -11,31 +11,27 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import argparse
  import os
  import time
- import argparse
- from typing import Callable, Union, List, Tuple
- from PIL import Image
- import numpy as np
+ from typing import List
+ from typing import Union
  import cv2
- import scipy
+ import gfm_resnet34_matting.processor as P
+ import numpy as np
  import paddle
  import paddle.nn as nn
- import paddle.nn.functional as F
- from paddlehub.module.module import moduleinfo
- import paddlehub.vision.transforms as T
- from paddlehub.module.module import moduleinfo, runnable, serving
+ from gfm_resnet34_matting.gfm import GFM
+ from PIL import Image
  from skimage.transform import resize
- from gfm_resnet34_matting.gfm import GFM
- import gfm_resnet34_matting.processor as P
+ from paddlehub.module.module import moduleinfo
+ from paddlehub.module.module import runnable
+ from paddlehub.module.module import serving
- @moduleinfo(
-     name="gfm_resnet34_matting",
+ @moduleinfo(name="gfm_resnet34_matting",
      type="CV/matting",
      author="paddlepaddle",
      author_email="",
@@ -53,7 +49,7 @@ class GFMResNet34(nn.Layer):
      Paper link (Arxiv): https://arxiv.org/abs/2010.16188
      """
-     def __init__(self, pretrained: str=None):
+     def __init__(self, pretrained: str = None):
          super(GFMResNet34, self).__init__()
          self.model = GFM()
@@ -76,44 +72,42 @@ class GFMResNet34(nn.Layer):
          tensor_img = self.scale_image(img, h, w)
          return tensor_img
-     def scale_image(self, img: np.ndarray, h: int, w: int, ratio: float = 1/3):
+     def scale_image(self, img: np.ndarray, h: int, w: int, ratio: float = 1 / 3):
          new_h = min(1600, h - (h % 32))
          new_w = min(1600, w - (w % 32))
-         resize_h = int(h*ratio)
-         resize_w = int(w*ratio)
+         resize_h = int(h * ratio)
+         resize_w = int(w * ratio)
          new_h = min(1600, resize_h - (resize_h % 32))
          new_w = min(1600, resize_w - (resize_w % 32))
-         scale_img = resize(img,(new_h,new_w)) * 255
+         scale_img = resize(img, (new_h, new_w)) * 255
          tensor_img = paddle.to_tensor(scale_img.astype(np.float32)[np.newaxis, :, :, :])
-         tensor_img = tensor_img.transpose([0,3,1,2])
+         tensor_img = tensor_img.transpose([0, 3, 1, 2])
          return tensor_img
      def inference_img_scale(self, input: paddle.Tensor) -> List[paddle.Tensor]:
          pred_global, pred_local, pred_fusion = self.model(input)
          pred_global = P.gen_trimap_from_segmap_e2e(pred_global)
-         pred_local = pred_local.numpy()[0,0,:,:]
-         pred_fusion = pred_fusion.numpy()[0,0,:,:]
+         pred_local = pred_local.numpy()[0, 0, :, :]
+         pred_fusion = pred_fusion.numpy()[0, 0, :, :]
          return pred_global, pred_local, pred_fusion
-     def predict(self, image_list: list, visualization: bool =True, save_path: str = "gfm_resnet34_matting_output"):
+     def predict(self, image_list: list, visualization: bool = True, save_path: str = "gfm_resnet34_matting_output"):
          self.model.eval()
          result = []
          with paddle.no_grad():
              for i, img in enumerate(image_list):
                  if isinstance(img, str):
-                     img = np.array(Image.open(img))[:,:,:3]
+                     img = np.array(Image.open(img))[:, :, :3]
                  else:
-                     img = img[:,:,::-1]
+                     img = img[:, :, ::-1]
                  h, w, _ = img.shape
                  tensor_img = self.preprocess(img, h, w)
                  pred_glance_1, pred_focus_1, pred_fusion_1 = self.inference_img_scale(tensor_img)
-                 pred_glance_1 = resize(pred_glance_1,(h,w)) * 255.0
+                 pred_glance_1 = resize(pred_glance_1, (h, w)) * 255.0
-                 tensor_img = self.scale_image(img, h, w, 1/2)
+                 tensor_img = self.scale_image(img, h, w, 1 / 2)
                  pred_glance_2, pred_focus_2, pred_fusion_2 = self.inference_img_scale(tensor_img)
-                 pred_focus_2 = resize(pred_focus_2,(h,w))
+                 pred_focus_2 = resize(pred_focus_2, (h, w))
                  pred_fusion = P.get_masked_local_from_global_test(pred_glance_1, pred_focus_2)
                  pred_fusion = (pred_fusion * 255).astype(np.uint8)
                  if visualization:
@@ -142,8 +136,7 @@ class GFMResNet34(nn.Layer):
          """
          Run as a command.
          """
-         self.parser = argparse.ArgumentParser(
-             description="Run the {} module.".format(self.name),
+         self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
              prog='hub run {}'.format(self.name),
              usage='%(prog)s',
              add_help=True)
@@ -154,7 +147,9 @@ class GFMResNet34(nn.Layer):
          self.add_module_input_arg()
          args = self.parser.parse_args(argvs)
-         results = self.predict(image_list=[args.input_path], save_path=args.output_dir, visualization=args.visualization)
+         results = self.predict(image_list=[args.input_path],
+                                save_path=args.output_dir,
+                                visualization=args.visualization)
          return results
@@ -163,14 +158,17 @@ class GFMResNet34(nn.Layer):
          Add the command config options.
          """
-         self.arg_config_group.add_argument(
-             '--output_dir', type=str, default="gfm_resnet34_matting_output", help="The directory to save output images.")
-         self.arg_config_group.add_argument(
-             '--visualization', type=bool, default=True, help="whether to save output as images.")
+         self.arg_config_group.add_argument('--output_dir',
+                                            type=str,
+                                            default="gfm_resnet34_matting_output",
+                                            help="The directory to save output images.")
+         self.arg_config_group.add_argument('--visualization',
+                                            type=bool,
+                                            default=True,
+                                            help="whether to save output as images.")
      def add_module_input_arg(self):
          """
          Add the command input options.
          """
          self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
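For reference, a quick numeric check of the resizing rule in scale_image above: the input is first shrunk by `ratio`, then each side is snapped down to a multiple of 32 and capped at 1600 before inference (the 1080x1920 input size here is just an illustrative example):

```python
# Numeric sketch of the effective sizing rule in scale_image();
# the 1080x1920 input size is illustrative only.
h, w, ratio = 1080, 1920, 1 / 3
resize_h, resize_w = int(h * ratio), int(w * ratio)   # 360, 640
new_h = min(1600, resize_h - (resize_h % 32))         # 352
new_w = min(1600, resize_w - (resize_w % 32))         # 640
print(new_h, new_w)
```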