Unverified commit ba3d10fa authored by jm_12138, committed by GitHub

add paddleseg in matting (#2138)

* add paddleseg

* update codex

Parent: 7aa70fd5
# dim_vgg16_matting
|Model Name|dim_vgg16_matting|
| :--- | :---: |
|Category|Image Matting|
|Network|dim_vgg16|
|Dataset|Baidu self-built dataset|
- Sample results (left: original image, right: result):
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/144574288-28671577-8d5d-4b20-adb9-fe737015c841.jpg" width = "337" height = "505" hspace='10'/>
<img src="https://user-images.githubusercontent.com/35907364/144779164-47146d3a-58c9-4a38-b968-3530aa9a0137.png" width = "337" height = "505" hspace='10'/>
</p>
- ### Module Introduction
- Matting is the technique of extracting the foreground from an image by estimating its color and transparency. It can be used for background replacement, image composition, and visual effects, and is widely used in the film industry. Every pixel in an image carries a value representing its foreground transparency, called the alpha value; the set of all alpha values in an image is called the alpha matte. Extracting the part of the image covered by the matte separates the foreground (a minimal compositing sketch is given below). dim_vgg16_matting is a matting model that requires a trimap as input.
- For more information, please refer to: [dim_vgg16_matting](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/Matting)
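- The relation behind the alpha matte is the compositing equation I = alpha * F + (1 - alpha) * B. As a minimal, illustrative sketch (not part of this module), a predicted matte can blend a foreground onto a new background with plain numpy:

- ```python
import numpy as np

def composite(foreground, background, alpha):
    """Blend foreground over background: I = alpha * F + (1 - alpha) * B."""
    # Assumes alpha is a single-channel uint8 matte of shape (H, W) and both
    # images are (H, W, 3) arrays of the same size.
    a = alpha.astype(np.float32)[..., np.newaxis] / 255.0
    blended = a * foreground.astype(np.float32) + (1.0 - a) * background.astype(np.float32)
    return blended.astype(np.uint8)
```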
## II. Installation
- ```shell
$ hub install dim_vgg16_matting
```
- In case of any problems during installation, please refer to: [Windows Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1、Command Line Prediction
- ```shell
$ hub run dim_vgg16_matting --input_path "/PATH/TO/IMAGE" --trimap_path "/PATH/TO/TRIMAP"
```
- This command calls the module from the command line; for more usage, see [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、Prediction Code Example
- ```python
import paddlehub as hub

model = hub.Module(name="dim_vgg16_matting")
result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["/PATH/TO/TRIMAP"])
print(result)
```
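- dim_vgg16_matting expects the trimap to mark known foreground (255), known background (0) and an unknown boundary band (128). If only a binary mask is available, a rough trimap can be derived with erosion and dilation — a hedged sketch using OpenCV (the band width is an arbitrary choice, not something this module prescribes):

- ```python
import cv2
import numpy as np

def mask_to_trimap(mask, band=10):
    """Binary mask (0/255) -> rough trimap (0 = bg, 128 = unknown, 255 = fg)."""
    kernel = np.ones((band, band), np.uint8)
    fg = cv2.erode(mask, kernel)             # confident foreground, shrunk inward
    unknown = cv2.dilate(mask, kernel) - fg  # band straddling the boundary
    trimap = fg.copy()
    trimap[unknown > 0] = 128
    return trimap
```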
- ### 3、API
- ```python
def predict(self,
image_list,
trimap_list,
visualization,
save_path):
```
- result (list(numpy.ndarray)): The list of model matting results.
## IV. Server Deployment
- PaddleHub Serving can deploy an online service of human matting.
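- As with other PaddleHub modules, the service is usually started with `hub serving start -m dim_vgg16_matting` and queried over HTTP. A hedged client sketch follows — the exact request schema is defined by the module's serving code, which is elided from this diff, so the field names below follow the usual PaddleHub convention of base64-encoded images and are assumptions:

- ```python
import base64
import json

import cv2
import requests

def cv2_to_base64(image):
    return base64.b64encode(cv2.imencode('.jpg', image)[1].tobytes()).decode('utf8')

org_im = cv2.imread("/PATH/TO/IMAGE")
trimap = cv2.imread("/PATH/TO/TRIMAP", cv2.IMREAD_GRAYSCALE)
data = {'images': [cv2_to_base64(org_im)], 'trimaps': [cv2_to_base64(trimap)]}
r = requests.post(url="http://127.0.0.1:8866/predict/dim_vgg16_matting",
                  headers={"Content-type": "application/json"},
                  data=json.dumps(data))
print(r.json())
```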
# dim_vgg16_matting
|Module Name|dim_vgg16_matting|
| :--- | :---: |
|Category|Matting|
|Network|dim_vgg16|
|Dataset|Baidu self-built dataset|
- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/144574288-28671577-8d5d-4b20-adb9-fe737015c841.jpg" width = "337" height = "505" hspace='10'/>
<img src="https://user-images.githubusercontent.com/35907364/144779164-47146d3a-58c9-4a38-b968-3530aa9a0137.png" width = "337" height = "505" hspace='10'/>
</p>
- ### Module Introduction
- Matting is the technique of extracting the foreground from an image by estimating its color and transparency. It is widely used in the film industry for background replacement, image composition, and visual effects. Each pixel in the image carries a value representing its foreground transparency, called the alpha value; the set of all alpha values in an image is called the alpha matte. Extracting the part of the image covered by the matte separates the foreground. dim_vgg16_matting is a matting model that requires a trimap as input.
- For more information, please refer to: [dim_vgg16_matting](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/Matting)
## II. Installation
- ```shell
$ hub install dim_vgg16_matting
```
- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md)
| [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1、Command Line Prediction
- ```shell
$ hub run dim_vgg16_matting --input_path "/PATH/TO/IMAGE" --trimap_path "/PATH/TO/TRIMAP"
```
- If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_en/tutorial/cmd_usage.rst)
- ### 2、Prediction Code Example

- ```python
import paddlehub as hub

model = hub.Module(name="dim_vgg16_matting")
result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["/PATH/TO/TRIMAP"])
print(result)
```
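- The entries of the returned list are numpy arrays. A short sketch of persisting the first result (assuming, per the API description below, that each entry is a single-channel matte):

- ```python
import cv2
import paddlehub as hub

model = hub.Module(name="dim_vgg16_matting")
result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["/PATH/TO/TRIMAP"])
# Save the predicted matte as a grayscale PNG.
cv2.imwrite("matte.png", result[0])
```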
- ### 3、API
- ```python
def predict(self,
image_list,
trimap_list,
visualization,
save_path):
```
- **Parameter**
- image_list (list(str | numpy.ndarray)): Image path or image data; ndarray.shape is in the format \[H, W, C\], BGR.
- trimap_list (list(str | numpy.ndarray)): Trimap path or trimap data; ndarray.shape is in the format \[H, W\], grayscale.
- visualization (bool): Whether to save the recognition results as picture files, default is False.
- save_path (str): Save path of images, "dim_vgg16_matting_output" by default.
- result (list(numpy.ndarray)):The list of model results.
## IV. Server Deployment
- PaddleHub Serving can deploy an online service of matting.
# gfm_resnet34_matting
|Model Name|gfm_resnet34_matting|
| :--- | :---: |
|Category|Image Matting|
|Network|gfm_resnet34|
|Dataset|AM-2k|
- Sample results (left: original image, right: result):
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/145993777-9b69a85d-d31c-4743-8620-82b2a56ca1e7.jpg" width = "480" height = "350" hspace='10'/>
<img src="https://user-images.githubusercontent.com/35907364/145993809-b0fb4bae-2c64-4868-99fc-500f19343442.png" width = "480" height = "350" hspace='10'/>
</p>
- ### Module Introduction
- Matting is the technique of extracting the foreground from an image by estimating its color and transparency. It can be used for background replacement, image composition, and visual effects, and is widely used in the film industry. Every pixel in an image carries a value representing its foreground transparency, called the alpha value; the set of all alpha values in an image is called the alpha matte. Extracting the part of the image covered by the matte separates the foreground. gfm_resnet34_matting produces matting results directly from an image, without a trimap.
- For more information, please refer to: [gfm_resnet34_matting](https://github.com/JizhiziLi/GFM)
## II. Installation
- ```shell
$ hub install gfm_resnet34_matting
```
- In case of any problems during installation, please refer to: [Windows Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1、Command Line Prediction
- ```shell
$ hub run gfm_resnet34_matting --input_path "/PATH/TO/IMAGE"
```
- This command calls the module from the command line; for more usage, see [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、Prediction Code Example
- ```python
import paddlehub as hub

import cv2
model = hub.Module(name="gfm_resnet34_matting")
result = model.predict([cv2.imread("/PATH/TO/IMAGE")])
print(result)
```
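- Since `predict` takes a list, several images can be processed in one call. A small sketch iterating over a folder (the glob pattern is illustrative):

- ```python
import glob

import paddlehub as hub

model = hub.Module(name="gfm_resnet34_matting")
image_paths = sorted(glob.glob("/PATH/TO/IMAGES/*.jpg"))
results = model.predict(image_paths)  # one matte per input image
print(len(results))
```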
- ### 3、API
- ```python
def predict(self,
image_list,
visualization,
save_path):
```
- result (list(numpy.ndarray)): The list of model matting results.
## IV. Server Deployment
- PaddleHub Serving can deploy an online service of animal matting.
* 1.0.0
Initial release
# gfm_resnet34_matting
|Module Name|gfm_resnet34_matting|
| :--- | :---: |
|Category|Image Matting|
|Network|gfm_resnet34|
|Dataset|AM-2k|
- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/145993777-9b69a85d-d31c-4743-8620-82b2a56ca1e7.jpg" width = "480" height = "350" hspace='10'/>
<img src="https://user-images.githubusercontent.com/35907364/145993809-b0fb4bae-2c64-4868-99fc-500f19343442.png" width = "480" height = "350" hspace='10'/>
</p>
- ### Module Introduction
- Matting is the technique of extracting the foreground from an image by estimating its color and transparency. It is widely used in the film industry for background replacement, image composition, and visual effects. Each pixel in the image carries a value representing its foreground transparency, called the alpha value; the set of all alpha values in an image is called the alpha matte. Extracting the part of the image covered by the matte separates the foreground.
- For more information, please refer to: [gfm_resnet34_matting](https://github.com/JizhiziLi/GFM)
## II. Installation
......@@ -46,11 +46,11 @@
- ```shell
$ hub install gfm_resnet34_matting
```
- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md)
| [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1、Command Line Prediction
- ```shell
$ hub run gfm_resnet34_matting --input_path "/PATH/TO/IMAGE"
```
- If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_en/tutorial/cmd_usage.rst)
- ### 2、Prediction Code Example

- ```python
import paddlehub as hub

import cv2
model = hub.Module(name="gfm_resnet34_matting")
result = model.predict([cv2.imread("/PATH/TO/IMAGE")])
print(result)
```
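- The `visualization` and `save_path` parameters described below control whether and where result images are written, for example:

- ```python
import paddlehub as hub

model = hub.Module(name="gfm_resnet34_matting")
# Also write the predicted mattes to ./gfm_resnet34_matting_output on disk.
results = model.predict(["/PATH/TO/IMAGE"], visualization=True, save_path="gfm_resnet34_matting_output")
```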
- ### 3、API
- ```python
def predict(self,
image_list,
visualization,
save_path):
```
- result (list(numpy.ndarray)):The list of model results.
## IV. Server Deployment
- PaddleHub Serving can deploy an online service of matting.
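- A hedged client sketch, mirroring the usual PaddleHub Serving pattern after running `hub serving start -m gfm_resnet34_matting` (the field name is an assumption; the serving schema itself is elided from this diff):

- ```python
import base64
import json

import cv2
import requests

def cv2_to_base64(image):
    return base64.b64encode(cv2.imencode('.jpg', image)[1].tobytes()).decode('utf8')

data = {'images': [cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
r = requests.post(url="http://127.0.0.1:8866/predict/gfm_resnet34_matting",
                  headers={"Content-type": "application/json"},
                  data=json.dumps(data))
print(r.json())
```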
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from typing import List
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from gfm_resnet34_matting.resnet import resnet34
def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> Callable:
"""3x3 convolution with padding"""
return nn.Conv2D(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias_attr=False)
def conv_up_psp(in_channels: int, out_channels: int, up_sample: float) -> Callable:
return nn.Sequential(nn.Conv2D(in_channels, out_channels, 3, padding=1), nn.BatchNorm2D(out_channels), nn.ReLU(),
nn.Upsample(scale_factor=up_sample, mode='bilinear', align_corners=False))
def build_bb(in_channels: int, mid_channels: int, out_channels: int) -> Callable:
return nn.Sequential(nn.Conv2D(in_channels, mid_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(mid_channels),
nn.ReLU(), nn.Conv2D(mid_channels, out_channels, 3, dilation=2, padding=2),
nn.BatchNorm2D(out_channels), nn.ReLU(),
nn.Conv2D(out_channels, out_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(out_channels),
nn.ReLU())
def build_decoder(in_channels: int, mid_channels_1: int, mid_channels_2: int, out_channels: int, last_bnrelu: bool,
upsample_flag: bool) -> Callable:
layers = []
layers += [
nn.Conv2D(in_channels, mid_channels_1, 3, padding=1),
nn.BatchNorm2D(mid_channels_1),
nn.ReLU(),
nn.Conv2D(mid_channels_1, mid_channels_2, 3, padding=1),
nn.BatchNorm2D(mid_channels_2),
nn.ReLU(),
nn.Conv2D(mid_channels_2, out_channels, 3, padding=1)
]
if last_bnrelu:
layers += [nn.BatchNorm2D(out_channels), nn.ReLU()]
if upsample_flag:
layers += [nn.Upsample(scale_factor=2, mode='bilinear')]
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
    def __init__(self, features: int, out_features: int = 1024, sizes: List[int] = (1, 2, 3, 6)):
super().__init__()
self.stages = nn.LayerList([self._make_stage(features, size) for size in sizes])
self.bottleneck = nn.Conv2D(features * (len(sizes) + 1), out_features, kernel_size=1)
self.relu = nn.ReLU()
    def _make_stage(self, features: int, size: int) -> Callable:
def forward(self, feats: paddle.Tensor) -> paddle.Tensor:
h, w = feats.shape[2], feats.shape[3]
priors = [F.upsample(stage(feats), size=(h, w), mode='bilinear', align_corners=True)
for stage in self.stages] + [feats]
bottle = self.bottleneck(paddle.concat(priors, 1))
return self.relu(bottle)
def __init__(self, channel: int, reduction: int = 4):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias_attr=False), nn.ReLU(),
nn.Linear(channel // reduction, channel, bias_attr=False), nn.Sigmoid())
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        b, c, _, _ = x.shape
class GFM(nn.Layer):
"""
The GFM implementation based on PaddlePaddle.
The original article refers to:
Bridging Composite and Real: Towards End-to-end Deep Image Matting [IJCV-2021]
Main network file (GFM).
Copyright (c) 2021, Jizhizi Li (jili8515@uni.sydney.edu.au)
Licensed under the MIT License (see LICENSE for details)
Github repo: https://github.com/JizhiziLi/GFM
self.gd_channel = 2
if self.backbone == 'r34_2b':
self.resnet = resnet34()
self.encoder0 = nn.Sequential(nn.Conv2D(3, 64, 3, padding=1), nn.BatchNorm2D(64), nn.ReLU())
self.encoder1 = self.resnet.layer1
self.encoder2 = self.resnet.layer2
self.encoder3 = self.resnet.layer3
self.encoder4 = self.resnet.layer4
self.encoder5 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True), BasicBlock(512, 512),
BasicBlock(512, 512), BasicBlock(512, 512))
self.encoder6 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True), BasicBlock(512, 512),
BasicBlock(512, 512), BasicBlock(512, 512))
self.psp_module = PSPModule(512, 512, (1, 3, 5))
self.psp6 = conv_up_psp(512, 512, 2)
self.psp5 = conv_up_psp(512, 512, 4)
self.decoder2_f = build_decoder(256, 128, 128, 64, True, True)
self.decoder1_f = build_decoder(128, 64, 64, 64, True, False)
if self.rosta == 'RIM':
self.decoder0_g_tt = nn.Sequential(nn.Conv2D(64, 3, 3, padding=1))
self.decoder0_g_ft = nn.Sequential(nn.Conv2D(64, 2, 3, padding=1))
self.decoder0_g_bt = nn.Sequential(nn.Conv2D(64, 2, 3, padding=1))
self.decoder0_f_tt = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
self.decoder0_f_ft = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
self.decoder0_f_bt = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
else:
self.decoder0_g = nn.Sequential(nn.Conv2D(64, self.gd_channel, 3, padding=1))
self.decoder0_f = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
if self.backbone == 'r34':
self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu)
self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.layer1)
self.encoder2 = self.resnet.layer2
self.encoder3 = self.resnet.layer3
self.encoder4 = self.resnet.layer4
self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
else:
self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
elif self.backbone == 'r101':
self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu)
self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.layer1)
self.encoder2 = self.resnet.layer2
self.encoder3 = self.resnet.layer3
self.encoder4 = self.resnet.layer4
self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
else:
self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
elif self.backbone == 'd121':
self.encoder0 = nn.Sequential(self.densenet.features.conv0, self.densenet.features.norm0,
self.densenet.features.relu0)
self.encoder1 = nn.Sequential(self.densenet.features.denseblock1, self.densenet.features.transition1)
self.encoder2 = nn.Sequential(self.densenet.features.denseblock2, self.densenet.features.transition2)
self.encoder3 = nn.Sequential(self.densenet.features.denseblock3, self.densenet.features.transition3)
self.encoder4 = nn.Sequential(self.densenet.features.denseblock4, nn.Conv2D(1024, 512, 3, padding=1),
nn.BatchNorm2D(512), nn.ReLU(), nn.MaxPool2D(2, 2, ceil_mode=True))
self.psp_module = PSPModule(512, 512, (1, 3, 5))
self.psp4 = conv_up_psp(512, 256, 2)
self.psp3 = conv_up_psp(512, 128, 4)
self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
else:
self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
if self.rosta == 'RIM':
self.rim = nn.Sequential(nn.Conv2D(3, 16, 1), SELayer(16), nn.Conv2D(16, 1, 1))
def forward(self, input: paddle.Tensor) -> List[paddle.Tensor]:
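        # Predict a semantic "glance" map and a boundary-detail "focus" map,
        # then fuse them (collaborative matting) into the final alpha.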
glance_sigmoid = paddle.zeros(input.shape)
e6 = self.encoder6(e5)
psp = self.psp_module(e6)
d6_g = self.decoder6_g(paddle.concat((psp, e6), 1))
d5_g = self.decoder5_g(paddle.concat((self.psp6(psp), d6_g), 1))
d4_g = self.decoder4_g(paddle.concat((self.psp5(psp), d5_g), 1))
else:
psp = self.psp_module(e4)
d4_g = self.decoder4_g(paddle.concat((psp, e4), 1))
else:
d0_g = self.decoder0_g(d1_g)
elif self.rosta == 'RIM':
d0_g_tt = self.decoder0_g_tt(paddle.concat((self.psp1(psp), d1_g), 1))
d0_g_ft = self.decoder0_g_ft(paddle.concat((self.psp1(psp), d1_g), 1))
d0_g_bt = self.decoder0_g_bt(paddle.concat((self.psp1(psp), d1_g), 1))
else:
d0_g = self.decoder0_g(paddle.concat((self.psp1(psp), d1_g), 1))
if self.rosta == 'RIM':
glance_sigmoid_tt = F.sigmoid(d0_g_tt)
glance_sigmoid_ft = F.sigmoid(d0_g_ft)
else:
focus_sigmoid = F.sigmoid(d0_f)
if self.rosta == 'RIM':
fusion_sigmoid_tt = collaborative_matting('TT', glance_sigmoid_tt, focus_sigmoid_tt)
fusion_sigmoid_ft = collaborative_matting('FT', glance_sigmoid_ft, focus_sigmoid_ft)
fusion_sigmoid_bt = collaborative_matting('BT', glance_sigmoid_bt, focus_sigmoid_bt)
fusion_sigmoid = paddle.concat((fusion_sigmoid_tt, fusion_sigmoid_ft, fusion_sigmoid_bt), 1)
fusion_sigmoid = self.rim(fusion_sigmoid)
return [[glance_sigmoid_tt, focus_sigmoid_tt, fusion_sigmoid_tt],
[glance_sigmoid_ft, focus_sigmoid_ft, fusion_sigmoid_ft],
[glance_sigmoid_bt, focus_sigmoid_bt, fusion_sigmoid_bt], fusion_sigmoid]
else:
fusion_sigmoid = collaborative_matting(self.rosta, glance_sigmoid, focus_sigmoid)
return glance_sigmoid, focus_sigmoid, fusion_sigmoid
def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
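    # Fuse the glance (semantic) and focus (detail) predictions into one alpha.
    # `rosta` selects the semantic representation (per the GFM paper): 'TT' a
    # 3-class trimap, 'FT' a foreground map, 'BT' a background map.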
if rosta == 'TT':
values = paddle.max(glance_sigmoid, axis=1)
index = paddle.argmax(glance_sigmoid, axis=1)
index = index[:, None, :, :].cast(paddle.float32)
bg_mask = index.clone()
bg_mask[bg_mask == 2] = 1
bg_mask = 1 - bg_mask
......@@ -428,20 +392,20 @@ def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
elif rosta == 'BT':
values = paddle.max(glance_sigmoid, axis=1)
index = paddle.argmax(glance_sigmoid, axis=1)
index = index[:, None, :, :].cast(paddle.float32)
fusion_sigmoid = index - focus_sigmoid
fusion_sigmoid[fusion_sigmoid < 0] = 0
else:
values = paddle.max(glance_sigmoid, axis=1)
index = paddle.argmax(glance_sigmoid, axis=1)
index = index[:, None, :, :].cast(paddle.float32)
fusion_sigmoid = index + focus_sigmoid
fusion_sigmoid[fusion_sigmoid > 1] = 1
return fusion_sigmoid
if __name__ == "__main__":
    model = GFM()
    x = paddle.ones([1, 3, 256, 256])
    result = model(x)
    print(result)
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
from typing import List
from typing import Union

import cv2
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import scipy
from PIL import Image
from skimage.transform import resize

import gfm_resnet34_matting.processor as P
import paddlehub.vision.transforms as T
from gfm_resnet34_matting.gfm import GFM
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
@moduleinfo(name="gfm_resnet34_matting",
type="CV/matting",
author="paddlepaddle",
author_email="",
summary="gfm_resnet34_matting is an animal matting model.",
version="1.0.0")
class GFMResNet34(nn.Layer):
"""
The GFM implementation based on PaddlePaddle.
The original article refers to:
Bridging Composite and Real: Towards End-to-end Deep Image Matting [IJCV-2021]
Main network file (GFM).
Github repo: https://github.com/JizhiziLi/GFM
Paper link (Arxiv): https://arxiv.org/abs/2010.16188
"""
def __init__(self, pretrained: str = None):
super(GFMResNet34, self).__init__()
self.model = GFM()
self.model.set_dict(model_dict)
print("load pretrained parameters success")
    def preprocess(self, img: Union[str, np.ndarray], h: int, w: int) -> paddle.Tensor:
if min(h, w) > 1080:
img = self.resize_by_short(img)
tensor_img = self.scale_image(img, h, w)
return tensor_img
    def scale_image(self, img: np.ndarray, h: int, w: int, ratio: float = 1 / 3):
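        # Downscale by `ratio`, then snap each side to a multiple of 32 (the
        # network stride), capped at 1600 px to bound memory use.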
new_h = min(1600, h - (h % 32))
new_w = min(1600, w - (w % 32))
resize_h = int(h * ratio)
resize_w = int(w * ratio)
new_h = min(1600, resize_h - (resize_h % 32))
new_w = min(1600, resize_w - (resize_w % 32))
scale_img = resize(img, (new_h, new_w)) * 255
tensor_img = paddle.to_tensor(scale_img.astype(np.float32)[np.newaxis, :, :, :])
tensor_img = tensor_img.transpose([0, 3, 1, 2])
return tensor_img
def inference_img_scale(self, input: paddle.Tensor) -> List[paddle.Tensor]:
pred_global, pred_local, pred_fusion = self.model(input)
pred_global = P.gen_trimap_from_segmap_e2e(pred_global)
pred_local = pred_local.numpy()[0, 0, :, :]
pred_fusion = pred_fusion.numpy()[0, 0, :, :]
return pred_global, pred_local, pred_fusion
def predict(self, image_list: list, visualization: bool = True, save_path: str = "gfm_resnet34_matting_output"):
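        # Two-pass inference: a 1/3-scale pass provides the semantic "glance"
        # prediction and a 1/2-scale pass the boundary "focus"; the two are
        # fused into the final matte.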
self.model.eval()
result = []
with paddle.no_grad():
for i, img in enumerate(image_list):
if isinstance(img, str):
img = np.array(Image.open(img))[:, :, :3]
else:
img = img[:, :, ::-1]
h, w, _ = img.shape
tensor_img = self.preprocess(img, h, w)
pred_glance_1, pred_focus_1, pred_fusion_1 = self.inference_img_scale(tensor_img)
pred_glance_1 = resize(pred_glance_1, (h, w)) * 255.0
tensor_img = self.scale_image(img, h, w, 1 / 2)
pred_glance_2, pred_focus_2, pred_fusion_2 = self.inference_img_scale(tensor_img)
pred_focus_2 = resize(pred_focus_2, (h, w))
pred_fusion = P.get_masked_local_from_global_test(pred_glance_1, pred_focus_2)
pred_fusion = (pred_fusion * 255).astype(np.uint8)
if visualization:
if not os.path.exists(save_path):
os.makedirs(save_path)
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.predict(image_list=[args.input_path],
save_path=args.output_dir,
visualization=args.visualization)
return results
Add the command config options.
"""
self.arg_config_group.add_argument('--output_dir',
type=str,
default="gfm_resnet34_matting_output",
help="The directory to save output images.")
self.arg_config_group.add_argument('--visualization',
type=bool,
default=True,
help="whether to save output as images.")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")