Unverified commit 77b8bac3 authored by qingqing01, committed by GitHub

Add MiDaSv2 in ppgan.apps (#118)

* Add MiDaSv2 in ppgan.apps
* remove ppgan/apps/midas/run.py
Parent 2cc72be9
@@ -387,3 +387,47 @@ ppgan.apps.AnimeGANPredictor(output_path='output_dir',weight_path=None,use_adjus
> ```
> **Returns:**
> > - anime_image (numpy.ndarray): the stylized scenery image.
## ppgan.apps.MiDaSPredictor
```python
ppgan.apps.MiDaSPredictor(output=None, weight_path=None)
```
> MiDaSv2, a monocular depth estimation model; see https://github.com/intel-isl/MiDaS. Paper: Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer, https://arxiv.org/abs/1907.01341v3
> **Example**
>
> ```python
> from ppgan.apps import MiDaSPredictor
> # if `output` is set, the depth map is written to output/MiDaS as .pfm and .png files
> model = MiDaSPredictor()
> prediction = model.run('test.png')  # 'test.png' is a placeholder input path
> ```
>
> Render the depth map in color:
>
> ```python
> import numpy as np
> import PIL.Image as Image
> import matplotlib as mpl
> import matplotlib.cm as cm
>
> vmax = np.percentile(prediction, 95)
> normalizer = mpl.colors.Normalize(vmin=prediction.min(), vmax=vmax)
> mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
> colormapped_im = (mapper.to_rgba(prediction)[:, :, :3] * 255).astype(np.uint8)
> im = Image.fromarray(colormapped_im)
> im.save('test_disp.jpeg')
> ```
>
> **Parameters:**
>
> > - output (str): output directory; if None, the .pfm and .png depth maps are not saved.
> > - weight_path (str): path to the model weights; if None (default), the built-in pretrained model is downloaded automatically.
> **Returns:**
> > - prediction (numpy.ndarray): the predicted depth map.
> > - pfm_f (str): path of the saved .pfm file, returned only when output is set.
> > - png_f (str): path of the saved .png file, returned only when output is set.
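>
> A save-to-disk sketch (the input filename below is a placeholder):
>
> ```python
> from ppgan.apps import MiDaSPredictor
>
> # with `output` set, run() also writes <output>/MiDaS/<name>.pfm and .png
> model = MiDaSPredictor(output='output')
> prediction, pfm_f, png_f = model.run('test.png')
> ```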
@@ -20,3 +20,4 @@ from .edvr_predictor import EDVRPredictor
from .first_order_predictor import FirstOrderPredictor
from .face_parse_predictor import FaceParsePredictor
from .animegan_predictor import AnimeGANPredictor
from .midas_predictor import MiDaSPredictor
@@ -26,7 +26,7 @@ from ppgan.utils.download import get_path_from_url
class AnimeGANPredictor(BasePredictor):
def __init__(self,
-                 output_path='output_dir',
+                 output_path='output',
weight_path=None,
use_adjust_brightness=True):
self.output_path = output_path
......
## Monocular Depth Estimation
The implementation of MiDaSv2 refers to https://github.com/intel-isl/MiDaS.
@article{Ranftl2020,
author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun},
title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer},
journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
year = {2020},
}
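
A minimal usage sketch (the image path is a placeholder; pretrained weights are downloaded automatically on first use):

```python
from ppgan.apps import MiDaSPredictor

model = MiDaSPredictor(output='output')            # omit `output` to skip writing files
prediction, pfm_f, png_f = model.run('input.jpg')  # placeholder image path
```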
# Refer https://github.com/intel-isl/MiDaS
import paddle
import paddle.nn as nn
def _make_encoder(backbone,
features,
use_pretrained,
groups=1,
expand=False,
exportable=True):
if backbone == "resnext101_wsl":
# resnext101_wsl
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048],
features,
groups=groups,
expand=expand)
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Layer()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
    if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2D(in_shape[0],
out_shape1,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False,
groups=groups)
scratch.layer2_rn = nn.Conv2D(in_shape[1],
out_shape2,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False,
groups=groups)
scratch.layer3_rn = nn.Conv2D(in_shape[2],
out_shape3,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False,
groups=groups)
scratch.layer4_rn = nn.Conv2D(in_shape[3],
out_shape4,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False,
groups=groups)
return scratch
def _make_resnet_backbone(resnet):
pretrained = nn.Layer()
pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
resnet.maxpool, resnet.layer1)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
from .resnext import resnext101_32x8d_wsl
resnet = resnext101_32x8d_wsl()
return _make_resnet_backbone(resnet)
class ResidualConvUnit(nn.Layer):
"""Residual convolution module.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2D(features,
features,
kernel_size=3,
stride=1,
padding=1,
bias_attr=True)
self.conv2 = nn.Conv2D(features,
features,
kernel_size=3,
stride=1,
padding=1,
bias_attr=True)
self.relu = nn.ReLU()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
x = self.relu(x)
out = self.conv1(x)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Layer):
"""Feature fusion block.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(output,
scale_factor=2,
mode="bilinear",
align_corners=True)
return output
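# Usage sketch (illustrative only, not part of this module): `_make_encoder`
# returns the pretrained ResNeXt backbone plus the `scratch` layer holding the
# 3x3 reduction convs; `FeatureFusionBlock` then merges a decoder path with an
# encoder skip connection and upsamples the result 2x:
#
#   pretrained, scratch = _make_encoder("resnext101_wsl", features=256,
#                                       use_pretrained=False)
#   fuse = FeatureFusionBlock(256)
#   out = fuse(deeper_path, skip)  # [N, 256, H, W] inputs -> [N, 256, 2H, 2W]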
# Refer https://github.com/intel-isl/MiDaS
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
"""
import numpy as np
import paddle
import paddle.nn as nn
from .blocks import FeatureFusionBlock, _make_encoder
class BaseModel(paddle.nn.Layer):
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = paddle.load(path)
self.set_dict(parameters)
class MidasNet(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=256, non_negative=True):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
            non_negative (bool, optional): clamp the output to be non-negative. Defaults to True.
"""
print("Loading weights: ", path)
super(MidasNet, self).__init__()
        use_pretrained = path is not None
self.pretrained, self.scratch = _make_encoder(
backbone="resnext101_wsl",
features=features,
use_pretrained=use_pretrained)
self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
self.scratch.refinenet2 = FeatureFusionBlock(features)
self.scratch.refinenet1 = FeatureFusionBlock(features)
output_conv = [
nn.Conv2D(features, 128, kernel_size=3, stride=1, padding=1),
nn.Upsample(scale_factor=2, mode="bilinear"),
nn.Conv2D(128, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2D(32, 1, kernel_size=1, stride=1, padding=0),
            # clamp predictions to non-negative values when requested
            nn.ReLU() if non_negative else nn.Identity(),
        ]
self.scratch.output_conv = nn.Sequential(*output_conv)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return paddle.squeeze(out, axis=1)
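# Smoke-test sketch (illustrative, not shipped code): with a 384x384 input the
# encoder bottoms out at 1/32 resolution, and the four fusion blocks plus the
# final upsample restore the full resolution.
#
#   import paddle
#   net = MidasNet(path=None, features=256)     # random weights, no checkpoint
#   depth = net(paddle.rand([1, 3, 384, 384]))  # shape [1, 384, 384]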
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle.vision.models.resnet import ResNet
from paddle.vision.models.resnet import BottleneckBlock
from paddle.utils.download import get_weights_path_from_url
__all__ = ['resnext101_32x8d_wsl']
class ResNetEx(ResNet):
"""ResNet extention model, support ResNeXt.
"""
def __init__(self,
block,
depth,
num_classes=1000,
with_pool=True,
groups=1,
width_per_group=64):
self.groups = groups
self.base_width = width_per_group
super(ResNetEx, self).__init__(block, depth, num_classes, with_pool)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(self.inplanes,
planes * block.expansion,
1,
stride=stride,
bias_attr=False),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _resnext(arch, Block, depth, **kwargs):
model = ResNetEx(Block, depth, **kwargs)
return model
def resnext101_32x8d_wsl(**kwargs):
"""ResNet101 32x8d wsl model
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnext('resnet101_32x8d', BottleneckBlock, 101, **kwargs)
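# Usage sketch (illustrative): on its own the backbone behaves like a standard
# paddle.vision classifier.
#
#   import paddle
#   net = resnext101_32x8d_wsl()
#   logits = net(paddle.rand([1, 3, 224, 224]))  # shape [1, 1000]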
# Refer https://github.com/intel-isl/MiDaS
import numpy as np
import cv2
import math
class Resize(object):
"""Resize sample to given size (width, height).
"""
def __init__(self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) *
self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) *
self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
            # scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(
f"resize_method {self.__resize_method} not implemented")
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(scale_height * height,
min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width,
min_val=self.__width)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(scale_height * height,
max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width,
max_val=self.__width)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(
f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(sample["image"].shape[1],
sample["image"].shape[0])
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(sample["depth"], (width, height),
interpolation=cv2.INTER_NEAREST)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "disparity" in sample:
disparity = sample["disparity"].astype(np.float32)
sample["disparity"] = np.ascontiguousarray(disparity)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
return sample
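# Composition sketch (mirrors how MiDaSPredictor wires these transforms up):
#
#   from paddle.vision.transforms import Compose
#   transform = Compose([
#       Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
#              ensure_multiple_of=32, resize_method="upper_bound"),
#       NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#       PrepareForNet(),
#   ])
#   sample = transform({"image": rgb01})  # rgb01: HxWx3 float RGB in [0, 1]
#   chw = sample["image"]                 # contiguous CHW float32 array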
# Refer https://github.com/intel-isl/MiDaS
"""Utils for monoDepth.
"""
import sys
import re
import numpy as np
import cv2
def write_pfm(path, image, scale=1):
"""Write pfm file.
Args:
        path (str): path to file
image (array): data
scale (int, optional): Scale. Defaults to 1.
"""
with open(path, "wb") as file:
color = None
if image.dtype.name != "float32":
raise Exception("Image dtype must be float32.")
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif (len(image.shape) == 2
or len(image.shape) == 3 and image.shape[2] == 1): # greyscale
color = False
else:
raise Exception(
"Image must have H x W x 3, H x W x 1 or H x W dimensions.")
file.write("PF\n" if color else "Pf\n".encode())
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
endian = image.dtype.byteorder
if endian == "<" or endian == "=" and sys.byteorder == "little":
scale = -scale
file.write("%f\n".encode() % scale)
image.tofile(file)
def read_image(path):
"""Read image and output RGB image (0-1).
Args:
path (str): path to file
Returns:
array: RGB image (0-1)
"""
img = cv2.imread(path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
return img
def write_depth(path, depth, bits=1):
"""Write depth map to pfm and png file.
Args:
        path (str): file path without extension
        depth (array): depth map
        bits (int, optional): 1 writes an 8-bit png, 2 a 16-bit png. Defaults to 1.
"""
write_pfm(path + ".pfm", depth.astype(np.float32))
depth_min = depth.min()
depth_max = depth.max()
max_val = (2**(8 * bits)) - 1
if depth_max - depth_min > np.finfo("float").eps:
out = max_val * (depth - depth_min) / (depth_max - depth_min)
else:
        out = np.zeros(depth.shape, dtype=depth.dtype)
if bits == 1:
cv2.imwrite(path + ".png", out.astype("uint8"))
elif bits == 2:
cv2.imwrite(path + ".png", out.astype("uint16"))
return path + '.pfm', path + ".png"
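# Usage sketch (illustrative): round-trip a prediction to disk.
#
#   import numpy as np
#   depth = np.random.rand(480, 640).astype(np.float32)  # stand-in prediction
#   pfm_f, png_f = write_depth('output/sample', depth, bits=2)
#   # -> 'output/sample.pfm' (raw float) and 'output/sample.png' (16-bit, normalized)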
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import cv2
import paddle
from paddle.vision.transforms import Compose
from ppgan.utils.download import get_path_from_url
from .base_predictor import BasePredictor
from .midas.transforms import Resize, NormalizeImage, PrepareForNet
from .midas.midas_net import MidasNet
from .midas.utils import write_depth
class MiDaSPredictor(BasePredictor):
def __init__(self, output=None, weight_path=None):
"""
        output (str|None): output directory; if None, the depth maps are
            not written to .pfm and .png files.
        weight_path (str|None): path to the weights; if None, the default
            MiDaSv2.1 model is downloaded automatically.
"""
self.output_path = os.path.join(output, 'MiDaS') if output else None
self.net_h, self.net_w = 384, 384
if weight_path is None:
midasv2_weight_url = 'https://paddlegan.bj.bcebos.com/applications/midas.pdparams'
weight_path = get_path_from_url(midasv2_weight_url)
self.weight_path = weight_path
self.model = self.load_checkpoints()
self.transform = Compose([
Resize(
self.net_w,
self.net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method="upper_bound",
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
PrepareForNet(),
])
def load_checkpoints(self):
model = MidasNet(self.weight_path, non_negative=True)
model.eval()
return model
def run(self, img):
"""
        img (str|np.ndarray|Image.Image): input image; it can be an image
            file path, a numpy.ndarray, or a PIL Image.Image.
"""
        # remember the original path (if any) so the output file can reuse its name
        img_path = img if isinstance(img, str) else None
        if img_path is not None:
            img = cv2.imread(img_path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
img_input = self.transform({"image": img})["image"]
with paddle.no_grad():
sample = paddle.to_tensor(img_input).unsqueeze(0)
prediction = self.model.forward(sample)
prediction = (paddle.nn.functional.interpolate(
prediction.unsqueeze(1),
size=img.shape[:2],
mode="bicubic",
align_corners=False,
).squeeze().numpy())
if self.output_path:
os.makedirs(self.output_path, exist_ok=True)
            img_name = img_path if img_path is not None else 'depth'
filename = os.path.join(
self.output_path,
os.path.splitext(os.path.basename(img_name))[0])
pfm_f, png_f = write_depth(filename, prediction, bits=2)
return prediction, pfm_f, png_f
return prediction
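# Usage sketch (the image path is a placeholder):
#
#   predictor = MiDaSPredictor(output='output')
#   prediction, pfm_f, png_f = predictor.run('input.jpg')
#   # with output=None, run() returns only the prediction array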