Unverified commit ba3d10fa authored by jm_12138, committed by GitHub

add paddleseg in matting (#2138)

* add paddleseg

* update codex
Parent 7aa70fd5
# dim_vgg16_matting

|Module Name|dim_vgg16_matting|
| :--- | :---: |
|Category|Image Matting|
|Network|dim_vgg16|
|Dataset|Baidu self-built dataset|
@@ -17,8 +17,8 @@

- Sample results (original image on the left, matting result on the right):
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/144574288-28671577-8d5d-4b20-adb9-fe737015c841.jpg" width = "337" height = "505" hspace='10' />
<img src="https://user-images.githubusercontent.com/35907364/144779164-47146d3a-58c9-4a38-b968-3530aa9a0137.png" width = "337" height = "505" hspace='10'/>
</p>

- ### Module Introduction
@@ -26,9 +26,9 @@

- Matting is the technique of extracting the foreground from an image by computing its color and transparency. It can be used for background replacement, image composition, and visual effects, and is widely used in the film industry. Each pixel in an image carries a value representing its foreground transparency, called the alpha value; the set of all alpha values in an image is called the alpha matte, and taking the part of the image covered by the matte separates the foreground. dim_vgg16_matting is a matting model that requires a trimap as input.

- For more information, please refer to: [dim_vgg16_matting](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/Matting)
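- The relation behind matting is alpha compositing: each observed pixel C is a blend of foreground F and background B weighted by its alpha value, C = αF + (1 − α)B. A minimal sketch of replacing the background once an alpha matte is available (the helper below is illustrative, not part of this module):

  ```python
  import numpy as np

  def composite(foreground: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
      """Blend foreground over background using an 8-bit alpha matte (all arrays share H and W)."""
      a = alpha.astype(np.float32)[..., None] / 255.0  # [H, W, 1], scaled to [0, 1]
      return (a * foreground + (1.0 - a) * background).astype(np.uint8)
  ```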
## II. Installation
@@ -46,11 +46,11 @@

- ```shell
  $ hub install dim_vgg16_matting
  ```

- In case of any problems during installation, please refer to: [Windows Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [Linux Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)

## III. Module API Prediction

- ### 1、Command Line Prediction
@@ -58,7 +58,7 @@

- ```shell
  $ hub run dim_vgg16_matting --input_path "/PATH/TO/IMAGE" --trimap_path "/PATH/TO/TRIMAP"
  ```

- This invokes the module from the command line; for more usage, see [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)

- ### 2、Prediction Code Example
@@ -69,16 +69,16 @@
  model = hub.Module(name="dim_vgg16_matting")

  result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["/PATH/TO/TRIMAP"])
  print(result)
  ```

- ### 3、API

- ```python
  def predict(self,
              image_list,
              trimap_list,
              visualization,
              save_path):
  ```
@@ -95,7 +95,7 @@

- result (list(numpy.ndarray)): the model's matting results.
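- A short sketch of persisting the returned mattes, assuming each entry of `result` is a single-channel `numpy.ndarray` as described above:

  ```python
  import cv2

  for i, alpha in enumerate(result):
      cv2.imwrite("alpha_{}.png".format(i), alpha)
  ```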
## IV. Server Deployment

- PaddleHub Serving can deploy an online portrait matting service.
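- A minimal quick start, assuming the standard PaddleHub Serving workflow (the service listens on port 8866 by default):

  ```shell
  $ hub serving start -m dim_vgg16_matting
  ```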
......
# dim_vgg16_matting

|Module Name|dim_vgg16_matting|
| :--- | :---: |
|Category|Matting|
|Network|dim_vgg16|
|Dataset|Baidu self-built dataset|
@@ -17,8 +17,8 @@

- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/144574288-28671577-8d5d-4b20-adb9-fe737015c841.jpg" width = "337" height = "505" hspace='10'/>
<img src="https://user-images.githubusercontent.com/35907364/144779164-47146d3a-58c9-4a38-b968-3530aa9a0137.png" width = "337" height = "505" hspace='10'/>
</p>

- ### Module Introduction
@@ -26,9 +26,9 @@

- Matting is the technique of extracting the foreground from an image by calculating its color and transparency. It is widely used in the film industry for background replacement, image composition, and visual effects. Each pixel in the image has a value that represents its foreground transparency, called alpha; the set of all alpha values in an image is called the alpha matte. The part of the image covered by the matte can be extracted to complete foreground separation. dim_vgg16_matting is a matting model that requires a trimap as input.

- For more information, please refer to: [dim_vgg16_matting](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/Matting)

## II. Installation
@@ -46,11 +46,11 @@

- ```shell
  $ hub install dim_vgg16_matting
  ```

- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)

## III. Module API Prediction

- ### 1、Command Line Prediction
@@ -58,7 +58,7 @@

- ```shell
  $ hub run dim_vgg16_matting --input_path "/PATH/TO/IMAGE" --trimap_path "/PATH/TO/TRIMAP"
  ```

- If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_en/tutorial/cmd_usage.rst)
@@ -70,16 +70,16 @@
  model = hub.Module(name="dim_vgg16_matting")

  result = model.predict(image_list=["/PATH/TO/IMAGE"], trimap_list=["/PATH/TO/TRIMAP"])
  print(result)
  ```

- ### 3、API

- ```python
  def predict(self,
              image_list,
              trimap_list,
              visualization,
              save_path):
  ```
@@ -88,7 +88,7 @@

- **Parameters**

- image_list (list(str | numpy.ndarray)): Image path or image data, ndarray.shape is in the format [H, W, C], BGR.
- trimap_list (list(str | numpy.ndarray)): Trimap path or trimap data, ndarray.shape is in the format [H, W], grayscale.
- visualization (bool): Whether to save the recognition results as picture files, default is False.
- save_path (str): Save path of images, "dim_vgg16_matting_output" by default.
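- A trimap marks each pixel as definite foreground (255), definite background (0), or unknown (128). A hedged sketch of deriving one from a rough binary mask with OpenCV morphology (the kernel size and file names are illustrative assumptions):

  ```python
  import cv2
  import numpy as np

  mask = cv2.imread("/PATH/TO/MASK", cv2.IMREAD_GRAYSCALE)  # 0/255 binary mask
  kernel = np.ones((25, 25), np.uint8)
  fg = cv2.erode(mask, kernel)             # shrunk mask: definite foreground
  unknown = cv2.dilate(mask, kernel) - fg  # band around the boundary
  trimap = fg.copy()
  trimap[unknown > 0] = 128                # mark the uncertain band
  cv2.imwrite("/PATH/TO/TRIMAP", trimap)
  ```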
@@ -96,7 +96,7 @@

- result (list(numpy.ndarray)): The list of model results.
## IV. Server Deployment

- PaddleHub Serving can deploy an online matting service.
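- A hedged client sketch, assuming this module follows the usual PaddleHub Serving convention of base64-encoded inputs posted as JSON to `http://<host>:8866/predict/<module_name>`; the `images` and `trimaps` field names are assumptions, not confirmed from the module source:

  ```python
  import base64
  import json

  import cv2
  import numpy as np
  import requests

  def cv2_to_base64(image: np.ndarray) -> str:
      return base64.b64encode(cv2.imencode('.jpg', image)[1].tobytes()).decode('utf8')

  data = {'images': [cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))],
          'trimaps': [cv2_to_base64(cv2.imread("/PATH/TO/TRIMAP"))]}  # assumed field names
  headers = {"Content-type": "application/json"}
  url = "http://127.0.0.1:8866/predict/dim_vgg16_matting"
  r = requests.post(url=url, headers=headers, data=json.dumps(data))
  print(r.json())  # response layout depends on the module's serving implementation
  ```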
......
# gfm_resnet34_matting

|Module Name|gfm_resnet34_matting|
| :--- | :---: |
|Category|Image Matting|
|Network|gfm_resnet34|
|Dataset|AM-2k|
@@ -17,8 +17,8 @@

- Sample results (original image on the left, matting result on the right):
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/145993777-9b69a85d-d31c-4743-8620-82b2a56ca1e7.jpg" width = "480" height = "350" hspace='10'/>
<img src="https://user-images.githubusercontent.com/35907364/145993809-b0fb4bae-2c64-4868-99fc-500f19343442.png" width = "480" height = "350" hspace='10'/>
</p>

- ### Module Introduction
@@ -26,9 +26,9 @@

- Matting is the technique of extracting the foreground from an image by computing its color and transparency. It can be used for background replacement, image composition, and visual effects, and is widely used in the film industry. Each pixel in an image carries a value representing its foreground transparency, called the alpha value; the set of all alpha values in an image is called the alpha matte, and taking the part of the image covered by the matte separates the foreground. gfm_resnet34_matting generates matting results directly from the input image.

- For more information, please refer to: [gfm_resnet34_matting](https://github.com/JizhiziLi/GFM)
## II. Installation
@@ -46,11 +46,11 @@

- ```shell
  $ hub install gfm_resnet34_matting
  ```

- In case of any problems during installation, please refer to: [Windows Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [Linux Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)

## III. Module API Prediction

- ### 1、Command Line Prediction
@@ -58,7 +58,7 @@

- ```shell
  $ hub run gfm_resnet34_matting --input_path "/PATH/TO/IMAGE"
  ```

- This invokes the module from the command line; for more usage, see [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)

- ### 2、Prediction Code Example
@@ -68,15 +68,15 @@
  import cv2

  model = hub.Module(name="gfm_resnet34_matting")

  result = model.predict([cv2.imread("/PATH/TO/IMAGE")])
  print(result)
  ```

- ### 3、API

- ```python
  def predict(self,
              image_list,
              visualization,
              save_path):
  ```
@@ -92,7 +92,7 @@

- result (list(numpy.ndarray)): the model's matting results.
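- A call sketch showing the optional arguments, with defaults taken from the predict signature in the module code later in this diff (`visualization=True` writes result images, `save_path` defaults to "gfm_resnet34_matting_output"):

  ```python
  result = model.predict(image_list=["/PATH/TO/IMAGE"],
                         visualization=True,
                         save_path="gfm_resnet34_matting_output")
  ```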
## IV. Server Deployment

- PaddleHub Serving can deploy an online animal matting service.
@@ -150,4 +150,3 @@

* 1.0.0

  First release
# gfm_resnet34_matting

|Module Name|gfm_resnet34_matting|
| :--- | :---: |
|Category|Image Matting|
|Network|gfm_resnet34|
|Dataset|AM-2k|
@@ -17,8 +17,8 @@

- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/145993777-9b69a85d-d31c-4743-8620-82b2a56ca1e7.jpg" width = "480" height = "350" hspace='10'/>
<img src="https://user-images.githubusercontent.com/35907364/145993809-b0fb4bae-2c64-4868-99fc-500f19343442.png" width = "480" height = "350" hspace='10'/>
</p>

- ### Module Introduction
@@ -26,9 +26,9 @@

- Matting is the technique of extracting the foreground from an image by calculating its color and transparency. It is widely used in the film industry for background replacement, image composition, and visual effects. Each pixel in the image has a value that represents its foreground transparency, called alpha; the set of all alpha values in an image is called the alpha matte. The part of the image covered by the matte can be extracted to complete foreground separation. gfm_resnet34_matting generates matting results directly from the input image.
- For more information, please refer to: [gfm_resnet34_matting](https://github.com/JizhiziLi/GFM)

## II. Installation
@@ -46,11 +46,11 @@

- ```shell
  $ hub install gfm_resnet34_matting
  ```

- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)

## III. Module API Prediction

- ### 1、Command Line Prediction
@@ -58,7 +58,7 @@

- ```shell
  $ hub run gfm_resnet34_matting --input_path "/PATH/TO/IMAGE"
  ```

- If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_en/tutorial/cmd_usage.rst)
@@ -69,16 +69,16 @@
  import cv2

  model = hub.Module(name="gfm_resnet34_matting")

  result = model.predict([cv2.imread("/PATH/TO/IMAGE")])
  print(result)
  ```

- ### 3、API

- ```python
  def predict(self,
              image_list,
              visualization,
              save_path):
  ```
@@ -94,7 +94,7 @@

- result (list(numpy.ndarray)): The list of model results.
## IV. Server Deployment

- PaddleHub Serving can deploy an online matting service.
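- A minimal quick start, assuming the standard PaddleHub Serving workflow (default port 8866); the request mirrors the prediction API above:

  ```shell
  $ hub serving start -m gfm_resnet34_matting
  ```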
......
@@ -11,47 +11,48 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from typing import List

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from gfm_resnet34_matting.resnet import resnet34


def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> Callable:
    """3x3 convolution with padding"""
    return nn.Conv2D(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias_attr=False)


def conv_up_psp(in_channels: int, out_channels: int, up_sample: float) -> Callable:
    return nn.Sequential(nn.Conv2D(in_channels, out_channels, 3, padding=1), nn.BatchNorm2D(out_channels), nn.ReLU(),
                         nn.Upsample(scale_factor=up_sample, mode='bilinear', align_corners=False))


def build_bb(in_channels: int, mid_channels: int, out_channels: int) -> Callable:
    return nn.Sequential(nn.Conv2D(in_channels, mid_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(mid_channels),
                         nn.ReLU(), nn.Conv2D(mid_channels, out_channels, 3, dilation=2, padding=2),
                         nn.BatchNorm2D(out_channels), nn.ReLU(),
                         nn.Conv2D(out_channels, out_channels, 3, dilation=2, padding=2), nn.BatchNorm2D(out_channels),
                         nn.ReLU())


def build_decoder(in_channels: int, mid_channels_1: int, mid_channels_2: int, out_channels: int, last_bnrelu: bool,
                  upsample_flag: bool) -> Callable:
    layers = []
    layers += [
        nn.Conv2D(in_channels, mid_channels_1, 3, padding=1),
        nn.BatchNorm2D(mid_channels_1),
        nn.ReLU(),
        nn.Conv2D(mid_channels_1, mid_channels_2, 3, padding=1),
        nn.BatchNorm2D(mid_channels_2),
        nn.ReLU(),
        nn.Conv2D(mid_channels_2, out_channels, 3, padding=1)
    ]

    if last_bnrelu:
        layers += [nn.BatchNorm2D(out_channels), nn.ReLU()]

    if upsample_flag:
        layers += [nn.Upsample(scale_factor=2, mode='bilinear')]
@@ -61,6 +62,7 @@ def build_decoder(in_channels: int, mid_channels_1: int, mid_channels_2: int, ou


class BasicBlock(nn.Layer):
    expansion = 1

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
@@ -90,10 +92,8 @@ class PSPModule(nn.Layer):
    def __init__(self, features: paddle.Tensor, out_features: int = 1024, sizes: List[int] = (1, 2, 3, 6)):
        super().__init__()
        #self.stages = []
        self.stages = nn.LayerList([self._make_stage(features, size) for size in sizes])
        self.bottleneck = nn.Conv2D(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, features: paddle.Tensor, size: int) -> Callable:
@@ -103,7 +103,8 @@ class PSPModule(nn.Layer):
    def forward(self, feats: paddle.Tensor) -> paddle.Tensor:
        h, w = feats.shape[2], feats.shape[3]
        priors = [F.upsample(stage(feats), size=(h, w), mode='bilinear', align_corners=True)
                  for stage in self.stages] + [feats]
        bottle = self.bottleneck(paddle.concat(priors, 1))
        return self.relu(bottle)
@@ -113,10 +114,8 @@ class SELayer(nn.Layer):
    def __init__(self, channel: int, reduction: int = 4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias_attr=False), nn.ReLU(),
                                nn.Linear(channel // reduction, channel, bias_attr=False), nn.Sigmoid())

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        b, c, _, _ = x.shape  # paddle tensors expose .shape; torch-style x.size() would fail here
@@ -128,11 +127,11 @@ class SELayer(nn.Layer):


class GFM(nn.Layer):
    """
    The GFM implementation based on PaddlePaddle.

    The original article refers to:
    Bridging Composite and Real: Towards End-to-end Deep Image Matting [IJCV-2021]
    Main network file (GFM).

    Copyright (c) 2021, Jizhizi Li (jili8515@uni.sydney.edu.au)
    Licensed under the MIT License (see LICENSE for details)
    Github repo: https://github.com/JizhiziLi/GFM
@@ -150,18 +149,15 @@ class GFM(nn.Layer):
        self.gd_channel = 2
        if self.backbone == 'r34_2b':
            self.resnet = resnet34()
            self.encoder0 = nn.Sequential(nn.Conv2D(3, 64, 3, padding=1), nn.BatchNorm2D(64), nn.ReLU())
            self.encoder1 = self.resnet.layer1
            self.encoder2 = self.resnet.layer2
            self.encoder3 = self.resnet.layer3
            self.encoder4 = self.resnet.layer4
            self.encoder5 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True), BasicBlock(512, 512),
                                          BasicBlock(512, 512), BasicBlock(512, 512))
            self.encoder6 = nn.Sequential(nn.MaxPool2D(2, 2, ceil_mode=True), BasicBlock(512, 512),
                                          BasicBlock(512, 512), BasicBlock(512, 512))
            self.psp_module = PSPModule(512, 512, (1, 3, 5))
            self.psp6 = conv_up_psp(512, 512, 2)
            self.psp5 = conv_up_psp(512, 512, 4)
@@ -183,28 +179,19 @@ class GFM(nn.Layer):
            self.decoder2_f = build_decoder(256, 128, 128, 64, True, True)
            self.decoder1_f = build_decoder(128, 64, 64, 64, True, False)
            if self.rosta == 'RIM':
                self.decoder0_g_tt = nn.Sequential(nn.Conv2D(64, 3, 3, padding=1))
                self.decoder0_g_ft = nn.Sequential(nn.Conv2D(64, 2, 3, padding=1))
                self.decoder0_g_bt = nn.Sequential(nn.Conv2D(64, 2, 3, padding=1))
                self.decoder0_f_tt = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
                self.decoder0_f_ft = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
                self.decoder0_f_bt = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))
            else:
                self.decoder0_g = nn.Sequential(nn.Conv2D(64, self.gd_channel, 3, padding=1))
                self.decoder0_f = nn.Sequential(nn.Conv2D(64, 1, 3, padding=1))

        if self.backbone == 'r34':
            self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu)
            self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.layer1)
            self.encoder2 = self.resnet.layer2
            self.encoder3 = self.resnet.layer3
            self.encoder4 = self.resnet.layer4
@@ -230,14 +217,11 @@ class GFM(nn.Layer):
                self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
                self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
            else:
                self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
                self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
        elif self.backbone == 'r101':
            self.encoder0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu)
            self.encoder1 = nn.Sequential(self.resnet.maxpool, self.resnet.layer1)
            self.encoder2 = self.resnet.layer2
            self.encoder3 = self.resnet.layer3
            self.encoder4 = self.resnet.layer4
@@ -263,22 +247,16 @@ class GFM(nn.Layer):
                self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
                self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
            else:
                self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
                self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)
        elif self.backbone == 'd121':
            self.encoder0 = nn.Sequential(self.densenet.features.conv0, self.densenet.features.norm0,
                                          self.densenet.features.relu0)
            self.encoder1 = nn.Sequential(self.densenet.features.denseblock1, self.densenet.features.transition1)
            self.encoder2 = nn.Sequential(self.densenet.features.denseblock2, self.densenet.features.transition2)
            self.encoder3 = nn.Sequential(self.densenet.features.denseblock3, self.densenet.features.transition3)
            self.encoder4 = nn.Sequential(self.densenet.features.denseblock4, nn.Conv2D(1024, 512, 3, padding=1),
                                          nn.BatchNorm2D(512), nn.ReLU(), nn.MaxPool2D(2, 2, ceil_mode=True))
            self.psp_module = PSPModule(512, 512, (1, 3, 5))
            self.psp4 = conv_up_psp(512, 256, 2)
            self.psp3 = conv_up_psp(512, 128, 4)
@@ -301,12 +279,10 @@ class GFM(nn.Layer):
                self.decoder0_f_ft = build_decoder(128, 64, 64, 1, False, True)
                self.decoder0_f_bt = build_decoder(128, 64, 64, 1, False, True)
            else:
                self.decoder0_g = build_decoder(128, 64, 64, self.gd_channel, False, True)
                self.decoder0_f = build_decoder(128, 64, 64, 1, False, True)

        if self.rosta == 'RIM':
            self.rim = nn.Sequential(nn.Conv2D(3, 16, 1), SELayer(16), nn.Conv2D(16, 1, 1))

    def forward(self, input: paddle.Tensor) -> List[paddle.Tensor]:
        glance_sigmoid = paddle.zeros(input.shape)
@@ -325,10 +301,8 @@ class GFM(nn.Layer):
            e6 = self.encoder6(e5)
            psp = self.psp_module(e6)
            d6_g = self.decoder6_g(paddle.concat((psp, e6), 1))
            d5_g = self.decoder5_g(paddle.concat((self.psp6(psp), d6_g), 1))
            d4_g = self.decoder4_g(paddle.concat((self.psp5(psp), d5_g), 1))
        else:
            psp = self.psp_module(e4)
            d4_g = self.decoder4_g(paddle.concat((psp, e4), 1))
@@ -343,15 +317,11 @@ class GFM(nn.Layer):
            else:
                d0_g = self.decoder0_g(d1_g)
        elif self.rosta == 'RIM':
            d0_g_tt = self.decoder0_g_tt(paddle.concat((self.psp1(psp), d1_g), 1))
            d0_g_ft = self.decoder0_g_ft(paddle.concat((self.psp1(psp), d1_g), 1))
            d0_g_bt = self.decoder0_g_bt(paddle.concat((self.psp1(psp), d1_g), 1))
        else:
            d0_g = self.decoder0_g(paddle.concat((self.psp1(psp), d1_g), 1))
        if self.rosta == 'RIM':
            glance_sigmoid_tt = F.sigmoid(d0_g_tt)
            glance_sigmoid_ft = F.sigmoid(d0_g_ft)
@@ -389,30 +359,24 @@ class GFM(nn.Layer):
        else:
            focus_sigmoid = F.sigmoid(d0_f)
        if self.rosta == 'RIM':
            fusion_sigmoid_tt = collaborative_matting('TT', glance_sigmoid_tt, focus_sigmoid_tt)
            fusion_sigmoid_ft = collaborative_matting('FT', glance_sigmoid_ft, focus_sigmoid_ft)
            fusion_sigmoid_bt = collaborative_matting('BT', glance_sigmoid_bt, focus_sigmoid_bt)
            fusion_sigmoid = paddle.concat((fusion_sigmoid_tt, fusion_sigmoid_ft, fusion_sigmoid_bt), 1)
            fusion_sigmoid = self.rim(fusion_sigmoid)
            return [[glance_sigmoid_tt, focus_sigmoid_tt, fusion_sigmoid_tt],
                    [glance_sigmoid_ft, focus_sigmoid_ft, fusion_sigmoid_ft],
                    [glance_sigmoid_bt, focus_sigmoid_bt, fusion_sigmoid_bt], fusion_sigmoid]
        else:
            fusion_sigmoid = collaborative_matting(self.rosta, glance_sigmoid, focus_sigmoid)
            return glance_sigmoid, focus_sigmoid, fusion_sigmoid


def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
    # Fuse the coarse "glance" segmentation with the fine "focus" prediction.
    if rosta == 'TT':
        # TT: the glance branch predicts a 3-class trimap; the focus prediction fills the unknown band.
        values = paddle.max(glance_sigmoid, axis=1)
        index = paddle.argmax(glance_sigmoid, axis=1)
        index = index[:, None, :, :].cast(paddle.float32)
        bg_mask = index.clone()
        bg_mask[bg_mask == 2] = 1
        bg_mask = 1 - bg_mask
@@ -428,20 +392,20 @@ def collaborative_matting(rosta, glance_sigmoid, focus_sigmoid):
    elif rosta == 'BT':
        values = paddle.max(glance_sigmoid, axis=1)
        index = paddle.argmax(glance_sigmoid, axis=1)
        index = index[:, None, :, :].cast(paddle.float32)
        fusion_sigmoid = index - focus_sigmoid
        fusion_sigmoid[fusion_sigmoid < 0] = 0
    else:
        values = paddle.max(glance_sigmoid, axis=1)
        index = paddle.argmax(glance_sigmoid, axis=1)
        index = index[:, None, :, :].cast(paddle.float32)
        fusion_sigmoid = index + focus_sigmoid
        fusion_sigmoid[fusion_sigmoid > 1] = 1
    return fusion_sigmoid


if __name__ == "__main__":
    model = GFM()
    x = paddle.ones([1, 3, 256, 256])
    result = model(x)
    print(result)  # print the model output rather than the input tensor
\ No newline at end of file
@@ -11,49 +11,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
from typing import List
from typing import Union

import cv2
import gfm_resnet34_matting.processor as P
import numpy as np
import paddle
import paddle.nn as nn
from gfm_resnet34_matting.gfm import GFM
from PIL import Image
from skimage.transform import resize

from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving


@moduleinfo(name="gfm_resnet34_matting",
            type="CV/matting",
            author="paddlepaddle",
            author_email="",
            summary="gfm_resnet34_matting is an animal matting model.",
            version="1.0.0")
class GFMResNet34(nn.Layer):
    """
    The GFM implementation based on PaddlePaddle.

    The original article refers to:
    Bridging Composite and Real: Towards End-to-end Deep Image Matting [IJCV-2021]
    Main network file (GFM).

    Github repo: https://github.com/JizhiziLi/GFM
    Paper link (Arxiv): https://arxiv.org/abs/2010.16188

    """

    def __init__(self, pretrained: str = None):
        super(GFMResNet34, self).__init__()
        self.model = GFM()
@@ -70,52 +66,50 @@ class GFMResNet34(nn.Layer):
        self.model.set_dict(model_dict)
        print("load pretrained parameters success")

    def preprocess(self, img: Union[str, np.ndarray], h: int, w: int) -> paddle.Tensor:
        if min(h, w) > 1080:
            img = self.resize_by_short(img)
        tensor_img = self.scale_image(img, h, w)
        return tensor_img

    def scale_image(self, img: np.ndarray, h: int, w: int, ratio: float = 1 / 3):
        new_h = min(1600, h - (h % 32))
        new_w = min(1600, w - (w % 32))
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        new_h = min(1600, resize_h - (resize_h % 32))
        new_w = min(1600, resize_w - (resize_w % 32))
        scale_img = resize(img, (new_h, new_w)) * 255
        tensor_img = paddle.to_tensor(scale_img.astype(np.float32)[np.newaxis, :, :, :])
        tensor_img = tensor_img.transpose([0, 3, 1, 2])  # NHWC -> NCHW
        return tensor_img

    def inference_img_scale(self, input: paddle.Tensor) -> List[paddle.Tensor]:
        pred_global, pred_local, pred_fusion = self.model(input)
        pred_global = P.gen_trimap_from_segmap_e2e(pred_global)
        pred_local = pred_local.numpy()[0, 0, :, :]
        pred_fusion = pred_fusion.numpy()[0, 0, :, :]
        return pred_global, pred_local, pred_fusion

    def predict(self, image_list: list, visualization: bool = True, save_path: str = "gfm_resnet34_matting_output"):
        self.model.eval()
        result = []
        with paddle.no_grad():
            for i, img in enumerate(image_list):
                if isinstance(img, str):
                    img = np.array(Image.open(img))[:, :, :3]
                else:
                    img = img[:, :, ::-1]  # ndarray input is treated as BGR; convert to RGB
                h, w, _ = img.shape
                tensor_img = self.preprocess(img, h, w)
                # Two-pass inference: a coarse "glance" prediction at 1/3 scale,
                # then a finer "focus" prediction at 1/2 scale, fused below.
                pred_glance_1, pred_focus_1, pred_fusion_1 = self.inference_img_scale(tensor_img)
                pred_glance_1 = resize(pred_glance_1, (h, w)) * 255.0
                tensor_img = self.scale_image(img, h, w, 1 / 2)
                pred_glance_2, pred_focus_2, pred_fusion_2 = self.inference_img_scale(tensor_img)
                pred_focus_2 = resize(pred_focus_2, (h, w))
                pred_fusion = P.get_masked_local_from_global_test(pred_glance_1, pred_focus_2)
                pred_fusion = (pred_fusion * 255).astype(np.uint8)
                if visualization:
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
@@ -142,11 +136,10 @@ class GFMResNet34(nn.Layer):
        """
        Run as a command.
        """
        self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
                                              prog='hub run {}'.format(self.name),
                                              usage='%(prog)s',
                                              add_help=True)
        self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
        self.arg_config_group = self.parser.add_argument_group(
            title="Config options", description="Run configuration for controlling module behavior, not required.")
@@ -154,7 +147,9 @@ class GFMResNet34(nn.Layer):
        self.add_module_input_arg()
        args = self.parser.parse_args(argvs)
        results = self.predict(image_list=[args.input_path],
                               save_path=args.output_dir,
                               visualization=args.visualization)

        return results
@@ -163,14 +158,17 @@ class GFMResNet34(nn.Layer):
        Add the command config options.
        """
        self.arg_config_group.add_argument('--output_dir',
                                           type=str,
                                           default="gfm_resnet34_matting_output",
                                           help="The directory to save output images.")
        self.arg_config_group.add_argument('--visualization',
                                           type=bool,
                                           default=True,
                                           help="whether to save output as images.")

    def add_module_input_arg(self):
        """
        Add the command input options.
        """
        self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")