From 77b8bac378fa8ece7799370bf7222ad9cb7ec9b7 Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Fri, 11 Dec 2020 10:15:11 +0800
Subject: [PATCH] Add MiDaSv2 in ppgan.apps (#118)

* Add MiDaSv2 in ppgan.apps

* remove ppgan/apps/midas/run.py
---
 docs/zh_CN/apis/apps.md          |  44 ++++++++
 ppgan/apps/__init__.py           |   1 +
 ppgan/apps/animegan_predictor.py |   2 +-
 ppgan/apps/midas/README.md       |  12 ++
 ppgan/apps/midas/__init__.py     |   0
 ppgan/apps/midas/blocks.py       | 164 +++++++++++++++++++++
 ppgan/apps/midas/midas_net.py    |  92 +++++++++++++++
 ppgan/apps/midas/resnext.py      |  86 ++++++++++++++
 ppgan/apps/midas/transforms.py   | 185 +++++++++++++++++++++++++++++++
 ppgan/apps/midas/utils.py        |  88 +++++++++++++++
 ppgan/apps/midas_predictor.py    |  98 ++++++++++++++++
 11 files changed, 771 insertions(+), 1 deletion(-)
 create mode 100644 ppgan/apps/midas/README.md
 create mode 100644 ppgan/apps/midas/__init__.py
 create mode 100644 ppgan/apps/midas/blocks.py
 create mode 100644 ppgan/apps/midas/midas_net.py
 create mode 100644 ppgan/apps/midas/resnext.py
 create mode 100644 ppgan/apps/midas/transforms.py
 create mode 100644 ppgan/apps/midas/utils.py
 create mode 100644 ppgan/apps/midas_predictor.py

diff --git a/docs/zh_CN/apis/apps.md b/docs/zh_CN/apis/apps.md
index 034e3a5..1ec7426 100644
--- a/docs/zh_CN/apis/apps.md
+++ b/docs/zh_CN/apis/apps.md
@@ -387,3 +387,47 @@ ppgan.apps.AnimeGANPredictor(output_path='output_dir',weight_path=None,use_adjus
 > ```
 > **Returns:**
 > > - anime_image(numpy.ndarray): the stylized landscape image
+
+
+## ppgan.apps.MiDaSPredictor
+
+```python
+ppgan.apps.MiDaSPredictor(output=None, weight_path=None)
+```
+
+> MiDaSv2 monocular depth estimation model, adapted from https://github.com/intel-isl/MiDaS. Paper: "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer", https://arxiv.org/abs/1907.01341v3
+
+> **Example**
+>
+> ```python
+> from ppgan.apps import MiDaSPredictor
+> # if output is set, the depth map is also written to output/MiDaS as .pfm and .png files
+> model = MiDaSPredictor()
+> prediction = model.run('test.png')  # path to an input image
+> ```
+>
+> Rendering the depth map in color:
+>
+> ```python
+> import numpy as np
+> import PIL.Image as Image
+> import matplotlib as mpl
+> import matplotlib.cm as cm
+>
+> vmax = np.percentile(prediction, 95)
+> normalizer = mpl.colors.Normalize(vmin=prediction.min(), vmax=vmax)
+> mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
+> colormapped_im = (mapper.to_rgba(prediction)[:, :, :3] * 255).astype(np.uint8)
+> im = Image.fromarray(colormapped_im)
+> im.save('test_disp.jpeg')
+> ```
+>
+> **Parameters:**
+>
+> > - output (str): output directory; if None, the .pfm and .png depth map files are not saved.
+> > - weight_path (str): path to the model weights; the default None downloads the built-in pretrained model automatically.
+
+> **Returns:**
+> > - prediction (numpy.ndarray): the predicted depth map.
+> > - pfm_f (str): path of the saved .pfm file, returned only when output is set.
+> > - png_f (str): path of the saved .png file, returned only when output is set.
diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py
index 5115d05..8704748 100644
--- a/ppgan/apps/__init__.py
+++ b/ppgan/apps/__init__.py
@@ -20,3 +20,4 @@ from .edvr_predictor import EDVRPredictor
 from .first_order_predictor import FirstOrderPredictor
 from .face_parse_predictor import FaceParsePredictor
 from .animegan_predictor import AnimeGANPredictor
+from .midas_predictor import MiDaSPredictor
diff --git a/ppgan/apps/animegan_predictor.py b/ppgan/apps/animegan_predictor.py
index b3579b5..8c5655d 100644
--- a/ppgan/apps/animegan_predictor.py
+++ b/ppgan/apps/animegan_predictor.py
@@ -26,7 +26,7 @@ from ppgan.utils.download import get_path_from_url
 
 class AnimeGANPredictor(BasePredictor):
     def __init__(self,
-                 output_path='output_dir',
+                 output_path='output',
                 weight_path=None,
                 use_adjust_brightness=True):
        self.output_path = output_path
diff --git a/ppgan/apps/midas/README.md b/ppgan/apps/midas/README.md
new file mode 100644
index 0000000..3b7e6e1
--- /dev/null
+++ b/ppgan/apps/midas/README.md
@@ -0,0 +1,12 @@
+## Monocular Depth Estimation
+
+
+The implementation of MiDaSv2 refers to https://github.com/intel-isl/MiDaS.
+
+
+@article{Ranftl2020,
+	author  = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun},
+	title   = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer},
+	journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
+	year    = {2020},
+}
diff --git a/ppgan/apps/midas/__init__.py b/ppgan/apps/midas/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ppgan/apps/midas/blocks.py b/ppgan/apps/midas/blocks.py
new file mode 100644
index 0000000..bd2c761
--- /dev/null
+++ b/ppgan/apps/midas/blocks.py
@@ -0,0 +1,164 @@
+# Adapted from https://github.com/intel-isl/MiDaS
+
+import paddle
+import paddle.nn as nn
+
+
+def _make_encoder(backbone,
+                  features,
+                  use_pretrained,
+                  groups=1,
+                  expand=False,
+                  exportable=True):
+    if backbone == "resnext101_wsl":
+        # resnext101_wsl
+        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048],
+                                features,
+                                groups=groups,
+                                expand=expand)
+    else:
+        raise NotImplementedError(f"Backbone '{backbone}' not implemented")
+    return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+    scratch = nn.Layer()
+
+    out_shape1 = out_shape
+    out_shape2 = out_shape
+    out_shape3 = out_shape
+    out_shape4 = out_shape
+    if expand:
+        out_shape1 = out_shape
+        out_shape2 = out_shape * 2
+        out_shape3 = out_shape * 4
+        out_shape4 = out_shape * 8
+
+    scratch.layer1_rn = nn.Conv2D(in_shape[0],
+                                  out_shape1,
+                                  kernel_size=3,
+                                  stride=1,
+                                  padding=1,
+                                  bias_attr=False,
+                                  groups=groups)
+    scratch.layer2_rn = nn.Conv2D(in_shape[1],
+                                  out_shape2,
+                                  kernel_size=3,
+                                  stride=1,
+                                  padding=1,
+                                  bias_attr=False,
+                                  groups=groups)
+    scratch.layer3_rn = nn.Conv2D(in_shape[2],
+                                  out_shape3,
+                                  kernel_size=3,
+                                  stride=1,
+                                  padding=1,
+                                  bias_attr=False,
+                                  groups=groups)
+    scratch.layer4_rn = nn.Conv2D(in_shape[3],
+                                  out_shape4,
+                                  kernel_size=3,
+                                  stride=1,
+                                  padding=1,
+                                  bias_attr=False,
+                                  groups=groups)
+
+    return scratch
+
+
+def _make_resnet_backbone(resnet):
+    pretrained = nn.Layer()
+    pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
+                                      resnet.maxpool, resnet.layer1)
+
+    pretrained.layer2 = resnet.layer2
+    pretrained.layer3 = resnet.layer3
+    pretrained.layer4 = resnet.layer4
+
+    return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+    from .resnext import resnext101_32x8d_wsl
+    # use_pretrained is unused here: weights are loaded later via MidasNet.load
+    resnet = resnext101_32x8d_wsl()
+    return _make_resnet_backbone(resnet)
+
+
+class ResidualConvUnit(nn.Layer):
+    """Residual convolution module.
+    """
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.conv1 = nn.Conv2D(features,
+                               features,
+                               kernel_size=3,
+                               stride=1,
+                               padding=1,
+                               bias_attr=True)
+
+        self.conv2 = nn.Conv2D(features,
+                               features,
+                               kernel_size=3,
+                               stride=1,
+                               padding=1,
+                               bias_attr=True)
+
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        x = self.relu(x)
+        out = self.conv1(x)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        return out + x
+
+
+class FeatureFusionBlock(nn.Layer):
+    """Feature fusion block.
+    """
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock, self).__init__()
+
+        self.resConfUnit1 = ResidualConvUnit(features)
+        self.resConfUnit2 = ResidualConvUnit(features)
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            output += self.resConfUnit1(xs[1])
+
+        output = self.resConfUnit2(output)
+        output = nn.functional.interpolate(output,
+                                           scale_factor=2,
+                                           mode="bilinear",
+                                           align_corners=True)
+
+        return output
diff --git a/ppgan/apps/midas/midas_net.py b/ppgan/apps/midas/midas_net.py
new file mode 100644
index 0000000..ef0a00c
--- /dev/null
+++ b/ppgan/apps/midas/midas_net.py
@@ -0,0 +1,92 @@
+# Adapted from https://github.com/intel-isl/MiDaS
+"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
+"""
+import numpy as np
+import paddle
+import paddle.nn as nn
+
+from .blocks import FeatureFusionBlock, _make_encoder
+
+
+class BaseModel(paddle.nn.Layer):
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        parameters = paddle.load(path)
+        self.set_dict(parameters)
+
+
+class MidasNet(BaseModel):
+    """Network for monocular depth estimation.
+    """
+    def __init__(self, path=None, features=256, non_negative=True):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            non_negative (bool, optional): Clamp the output to be non-negative. Defaults to True.
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet, self).__init__()
+
+        use_pretrained = path is not None
+
+        self.pretrained, self.scratch = _make_encoder(
+            backbone="resnext101_wsl",
+            features=features,
+            use_pretrained=use_pretrained)
+
+        self.scratch.refinenet4 = FeatureFusionBlock(features)
+        self.scratch.refinenet3 = FeatureFusionBlock(features)
+        self.scratch.refinenet2 = FeatureFusionBlock(features)
+        self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+        output_conv = [
+            nn.Conv2D(features, 128, kernel_size=3, stride=1, padding=1),
+            nn.Upsample(scale_factor=2, mode="bilinear"),
+            nn.Conv2D(128, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv2D(32, 1, kernel_size=1, stride=1, padding=0),
+            # clamp the prediction to be non-negative
+            nn.ReLU() if non_negative else nn.Identity(),
+        ]
+
+        self.scratch.output_conv = nn.Sequential(*output_conv)
+
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return paddle.squeeze(out, axis=1)
diff --git a/ppgan/apps/midas/resnext.py b/ppgan/apps/midas/resnext.py
new file mode 100644
index 0000000..198dfcc
--- /dev/null
+++ b/ppgan/apps/midas/resnext.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+from paddle.vision.models.resnet import ResNet
+from paddle.vision.models.resnet import BottleneckBlock
+
+from paddle.utils.download import get_weights_path_from_url
+
+__all__ = ['resnext101_32x8d_wsl']
+
+
+class ResNetEx(ResNet):
+    """ResNet extension model with ResNeXt support.
+ """ + def __init__(self, + block, + depth, + num_classes=1000, + with_pool=True, + groups=1, + width_per_group=64): + self.groups = groups + self.base_width = width_per_group + + super(ResNetEx, self).__init__(block, depth, num_classes, with_pool) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D(self.inplanes, + planes * block.expansion, + 1, + stride=stride, + bias_attr=False), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + +def _resnext(arch, Block, depth, **kwargs): + model = ResNetEx(Block, depth, **kwargs) + return model + + +def resnext101_32x8d_wsl(**kwargs): + """ResNet101 32x8d wsl model + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnext('resnet101_32x8d', BottleneckBlock, 101, **kwargs) diff --git a/ppgan/apps/midas/transforms.py b/ppgan/apps/midas/transforms.py new file mode 100644 index 0000000..530c552 --- /dev/null +++ b/ppgan/apps/midas/transforms.py @@ -0,0 +1,185 @@ +# Refer https://github.com/intel-isl/MiDaS + +import numpy as np +import cv2 +import math + + +class Resize(object): + """Resize sample to given size (width, height). + """ + def __init__(self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * + self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * + self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, + min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, + min_val=self.__width) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, + max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, + max_val=self.__width) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size(sample["image"].shape[1], + sample["image"].shape[0]) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize(sample["depth"], (width, height), + interpolation=cv2.INTER_NEAREST) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. 
+ """ + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/ppgan/apps/midas/utils.py b/ppgan/apps/midas/utils.py new file mode 100644 index 0000000..3054a49 --- /dev/null +++ b/ppgan/apps/midas/utils.py @@ -0,0 +1,88 @@ +# Refer https://github.com/intel-isl/MiDaS +"""Utils for monoDepth. +""" +import sys +import re +import numpy as np +import cv2 + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif (len(image.shape) == 2 + or len(image.shape) == 3 and image.shape[2] == 1): # greyscale + color = False + else: + raise Exception( + "Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + return img + + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8 * bits)) - 1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + return path + '.pfm', path + ".png" diff --git a/ppgan/apps/midas_predictor.py b/ppgan/apps/midas_predictor.py new file mode 100644 index 0000000..82b0f46 --- /dev/null +++ b/ppgan/apps/midas_predictor.py @@ -0,0 +1,98 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import cv2
+
+import paddle
+from paddle.vision.transforms import Compose
+
+from ppgan.utils.download import get_path_from_url
+from .base_predictor import BasePredictor
+from .midas.transforms import Resize, NormalizeImage, PrepareForNet
+from .midas.midas_net import MidasNet
+from .midas.utils import write_depth
+
+
+class MiDaSPredictor(BasePredictor):
+    def __init__(self, output=None, weight_path=None):
+        """
+        output (str|None): output path; if None, the depth map is not
+            written to .pfm and .png files.
+        weight_path (str|None): weight path; if None, the default
+            pretrained MiDaSv2 model is downloaded.
+        """
+        self.output_path = os.path.join(output, 'MiDaS') if output else None
+
+        self.net_h, self.net_w = 384, 384
+        if weight_path is None:
+            midasv2_weight_url = 'https://paddlegan.bj.bcebos.com/applications/midas.pdparams'
+            weight_path = get_path_from_url(midasv2_weight_url)
+        self.weight_path = weight_path
+
+        self.model = self.load_checkpoints()
+
+        self.transform = Compose([
+            Resize(
+                self.net_w,
+                self.net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method="upper_bound",
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            NormalizeImage(mean=[0.485, 0.456, 0.406],
+                           std=[0.229, 0.224, 0.225]),
+            PrepareForNet(),
+        ])
+
+    def load_checkpoints(self):
+        model = MidasNet(self.weight_path, non_negative=True)
+        model.eval()
+        return model
+
+    def run(self, img):
+        """
+        img (str|numpy.ndarray): input image; either a path to an image
+            file, or an RGB image as a numpy.ndarray scaled to [0, 1].
+        """
+        if isinstance(img, str):
+            # remember the path before img is replaced by the pixel data
+            img_name = img
+            img = cv2.imread(img)
+            if img.ndim == 2:
+                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+        else:
+            img_name = 'depth'
+
+        img_input = self.transform({"image": img})["image"]
+
+        with paddle.no_grad():
+            sample = paddle.to_tensor(img_input).unsqueeze(0)
+            prediction = self.model(sample)
+            prediction = (paddle.nn.functional.interpolate(
+                prediction.unsqueeze(1),
+                size=img.shape[:2],
+                mode="bicubic",
+                align_corners=False,
+            ).squeeze().numpy())
+
+        if self.output_path:
+            os.makedirs(self.output_path, exist_ok=True)
+            filename = os.path.join(
+                self.output_path,
+                os.path.splitext(os.path.basename(img_name))[0])
+            pfm_f, png_f = write_depth(filename, prediction, bits=2)
+            return prediction, pfm_f, png_f
+        return prediction
-- 
GitLab
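
As a quick sanity check for the `.pfm` files this patch writes, the sketch below is a minimal reader that mirrors the conventions of `write_pfm` above: rows are stored bottom-up (hence the `np.flipud`), and a negative scale value marks little-endian float data. It is not part of the patch, and the path `output/MiDaS/test.pfm` is only a placeholder for a file produced by running `MiDaSPredictor` with `output` set.

```python
import re

import numpy as np


def read_pfm(path):
    """Minimal PFM reader matching the conventions of write_pfm."""
    with open(path, "rb") as f:
        header = f.readline().decode().rstrip()
        if header == "PF":
            channels = 3  # color
        elif header == "Pf":
            channels = 1  # greyscale, what write_depth produces
        else:
            raise ValueError("Not a PFM file.")

        dims = re.match(r"^(\d+)\s+(\d+)\s*$", f.readline().decode())
        if not dims:
            raise ValueError("Malformed PFM header.")
        width, height = int(dims.group(1)), int(dims.group(2))

        scale = float(f.readline().decode().rstrip())
        # write_pfm negates the scale for little-endian data
        endian = "<" if scale < 0 else ">"

        data = np.fromfile(f, endian + "f")
        shape = (height, width, 3) if channels == 3 else (height, width)
        # undo the np.flipud applied by write_pfm
        return np.flipud(data.reshape(shape)), abs(scale)


depth, scale = read_pfm("output/MiDaS/test.pfm")  # placeholder path
print(depth.shape, depth.min(), depth.max())
```

The sign-of-scale trick is part of the PFM format itself rather than something specific to MiDaS, which is why the reader only needs the header fields to reconstruct the array exactly as `write_depth` saved it.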