Add nn.PSNR and unit tests for PSNR and SSIM.

  mindspore/nn/layer/__init__.py
      Export PSNR alongside ImageGradients and SSIM.
  mindspore/nn/layer/image.py
      Move SSIM._convert_img_dtype_to_float32 to a module-level helper that is
      shared by SSIM and the new PSNR cell; add the PSNR cell.
  tests/ut/python/nn/test_psnr.py, tests/ut/python/nn/test_ssim.py
      New graph-compile and constructor-argument validation tests.

Note: the blob ids on the "index" lines are those of the originally submitted
revision; they were not recomputed for the small review edits to added lines.

diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py
index cf601f03ff448c901ae13ea04f129e043d6cae7c..6c1b19a1104d8c91a69c4e60b9e0fd281c80c32f 100644
--- a/mindspore/nn/layer/__init__.py
+++ b/mindspore/nn/layer/__init__.py
@@ -25,7 +25,7 @@ from .lstm import LSTM
 from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold
 from .embedding import Embedding
 from .pooling import AvgPool2d, MaxPool2d
-from .image import ImageGradients, SSIM
+from .image import ImageGradients, SSIM, PSNR
 
 __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
            'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
@@ -36,5 +36,5 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
            'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot',
            'Embedding',
            'AvgPool2d', 'MaxPool2d', 'Pad', 'Unfold',
-           'ImageGradients', 'SSIM',
+           'ImageGradients', 'SSIM', 'PSNR',
            ]
diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py
index 6121776f59a7de6b63a0a4a4ef201aca7ffad7de..72c4c6d8e2d7bacf0a0fb6f7641b7db621aaeaab 100644
--- a/mindspore/nn/layer/image.py
+++ b/mindspore/nn/layer/image.py
@@ -69,6 +69,18 @@ class ImageGradients(Cell):
         return dy, dx
 
 
+def _convert_img_dtype_to_float32(img, max_val):
+    """convert img dtype to float32"""
+    # Usually max_val is 1.0 or 255.
+    # Scale the pixel values by 1 / max_val if max_val > 1., otherwise just cast.
+    ret = F.cast(img, mstype.float32)
+    max_val = F.scalar_cast(max_val, mstype.float32)
+    if max_val > 1.:
+        scale = 1. / max_val
+        ret = ret * scale
+    return ret
+
+
 @constexpr
 def _gauss_kernel_helper(filter_size):
     """gauss kernel helper"""
@@ -134,9 +146,9 @@ class SSIM(Cell):
         self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
 
     def construct(self, img1, img2):
-        max_val = self._convert_img_dtype_to_float32(self.max_val, self.max_val)
-        img1 = self._convert_img_dtype_to_float32(img1, self.max_val)
-        img2 = self._convert_img_dtype_to_float32(img2, self.max_val)
+        max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
+        img1 = _convert_img_dtype_to_float32(img1, self.max_val)
+        img2 = _convert_img_dtype_to_float32(img2, self.max_val)
 
         kernel = self._fspecial_gauss(self.filter_size, self.filter_sigma)
         kernel = P.Tile()(kernel, (1, P.Shape()(img1)[1], 1, 1))
@@ -145,21 +157,10 @@ class SSIM(Cell):
 
         return mean_ssim
 
-    def _convert_img_dtype_to_float32(self, img, max_val):
-        """convert img dtype to float32"""
-        # Ususally max_val is 1.0 or 255, we will do the scaling if max_val > 1.
-        # We will scale img pixel value if max_val > 1. and just cast otherwise.
-        ret = P.Cast()(img, mstype.float32)
-        max_val = F.scalar_cast(max_val, mstype.float32)
-        if max_val > 1.:
-            scale = 1./max_val
-            ret = ret*scale
-        return ret
-
     def _calculate_mean_ssim(self, x, y, kernel, max_val, k1, k2):
         """calculate mean ssim"""
-        c1 = (k1*max_val)*(k1*max_val)
-        c2 = (k2*max_val)*(k2*max_val)
+        c1 = (k1 * max_val) * (k1 * max_val)
+        c2 = (k2 * max_val) * (k2 * max_val)
 
         # SSIM luminance formula
         # (2 * mean_{x} * mean_{y} + c1) / (mean_{x}**2 + mean_{y}**2 + c1)
@@ -195,3 +196,52 @@ class SSIM(Cell):
         g = P.Softmax()(g)
         ret = F.reshape(g, (1, 1, filter_size, filter_size))
         return ret
+
+
+class PSNR(Cell):
+    r"""
+    Returns Peak Signal-to-Noise Ratio of two image batches.
+
+    It produces a PSNR value for each image in batch.
+    Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`.
+    :math:`MAX` represents the dynamic range of pixel values.
+
+    .. math::
+
+        MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\
+        PSNR&=10*log_{10}(\frac{MAX^2}{MSE})
+
+    Args:
+        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
+            Default: 1.0.
+
+    Inputs:
+        - **img1** (Tensor) - The first image batch with format 'NCHW'. It should be the same shape and dtype as img2.
+        - **img2** (Tensor) - The second image batch with format 'NCHW'. It should be the same shape and dtype as img1.
+
+    Outputs:
+        Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.
+
+    Examples:
+        >>> net = nn.PSNR()
+        >>> img1 = Tensor(np.random.random((1,3,16,16)))
+        >>> img2 = Tensor(np.random.random((1,3,16,16)))
+        >>> psnr = net(img1, img2)
+
+    """
+    def __init__(self, max_val=1.0):
+        super(PSNR, self).__init__()
+        validator.check_type('max_val', max_val, [int, float])
+        validator.check('max_val', max_val, '', 0.0, Rel.GT)
+        self.max_val = max_val
+
+    def construct(self, img1, img2):
+        max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
+        img1 = _convert_img_dtype_to_float32(img1, self.max_val)
+        img2 = _convert_img_dtype_to_float32(img2, self.max_val)
+
+        mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1))
+        # 10*log_10(max_val^2/MSE)
+        psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)
+
+        return psnr
diff --git a/tests/ut/python/nn/test_psnr.py b/tests/ut/python/nn/test_psnr.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a908b308dd105fdbb18c08e06492b86910cefb0
--- /dev/null
+++ b/tests/ut/python/nn/test_psnr.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+test psnr
+"""
+import numpy as np
+import pytest
+import mindspore.nn as nn
+from mindspore.common.api import _executor
+from mindspore import Tensor
+
+
+class PSNRNet(nn.Cell):
+    def __init__(self, max_val=1.0):
+        super(PSNRNet, self).__init__()
+        self.net = nn.PSNR(max_val)
+
+    def construct(self, img1, img2):
+        return self.net(img1, img2)
+
+
+def test_compile_psnr():
+    max_val = 1.0
+    net = PSNRNet(max_val)
+    img1 = Tensor(np.random.random((8, 3, 16, 16)))
+    img2 = Tensor(np.random.random((8, 3, 16, 16)))
+    _executor.compile(net, img1, img2)
+
+def test_compile_psnr_grayscale():
+    max_val = 255
+    net = PSNRNet(max_val)
+    img1 = Tensor(np.random.randint(0, 256, (8, 1, 16, 16), np.uint8))
+    img2 = Tensor(np.random.randint(0, 256, (8, 1, 16, 16), np.uint8))
+    _executor.compile(net, img1, img2)
+
+def test_psnr_max_val_negative():
+    max_val = -1
+    with pytest.raises(ValueError):
+        PSNRNet(max_val)
+
+def test_psnr_max_val_bool():
+    max_val = True
+    with pytest.raises(ValueError):
+        PSNRNet(max_val)
+
+def test_psnr_max_val_zero():
+    max_val = 0
+    with pytest.raises(ValueError):
+        PSNRNet(max_val)
diff --git a/tests/ut/python/nn/test_ssim.py b/tests/ut/python/nn/test_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..a698b59f69cacc81c6fb958fe48e137c4e59d49f
--- /dev/null
+++ b/tests/ut/python/nn/test_ssim.py
@@ -0,0 +1,95 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+test ssim
+"""
+import numpy as np
+import pytest
+import mindspore.nn as nn
+from mindspore.common.api import _executor
+from mindspore import Tensor
+
+
+class SSIMNet(nn.Cell):
+    def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
+        super(SSIMNet, self).__init__()
+        self.net = nn.SSIM(max_val, filter_size, filter_sigma, k1, k2)
+
+    def construct(self, img1, img2):
+        return self.net(img1, img2)
+
+
+def test_compile():
+    net = SSIMNet()
+    img1 = Tensor(np.random.random((8, 3, 16, 16)))
+    img2 = Tensor(np.random.random((8, 3, 16, 16)))
+    _executor.compile(net, img1, img2)
+
+def test_compile_grayscale():
+    max_val = 255
+    net = SSIMNet(max_val=max_val)
+    img1 = Tensor(np.random.randint(0, 256, (8, 1, 16, 16), np.uint8))
+    img2 = Tensor(np.random.randint(0, 256, (8, 1, 16, 16), np.uint8))
+    _executor.compile(net, img1, img2)
+
+def test_ssim_max_val_negative():
+    max_val = -1
+    with pytest.raises(ValueError):
+        SSIMNet(max_val)
+
+def test_ssim_max_val_bool():
+    max_val = True
+    with pytest.raises(ValueError):
+        SSIMNet(max_val)
+
+def test_ssim_max_val_zero():
+    max_val = 0
+    with pytest.raises(ValueError):
+        SSIMNet(max_val)
+
+def test_ssim_filter_size_float():
+    with pytest.raises(ValueError):
+        SSIMNet(filter_size=1.1)
+
+def test_ssim_filter_size_zero():
+    with pytest.raises(ValueError):
+        SSIMNet(filter_size=0)
+
+def test_ssim_filter_sigma_zero():
+    with pytest.raises(ValueError):
+        SSIMNet(filter_sigma=0.0)
+
+def test_ssim_filter_sigma_negative():
+    with pytest.raises(ValueError):
+        SSIMNet(filter_sigma=-0.1)
+
+def test_ssim_k1_k2_wrong_value():
+    with pytest.raises(ValueError):
+        SSIMNet(k1=1.1)
+    with pytest.raises(ValueError):
+        SSIMNet(k1=1.0)
+    with pytest.raises(ValueError):
+        SSIMNet(k1=0.0)
+    with pytest.raises(ValueError):
+        SSIMNet(k1=-1.0)
+
+    with pytest.raises(ValueError):
+        SSIMNet(k2=1.1)
+    with pytest.raises(ValueError):
+        SSIMNet(k2=1.0)
+    with pytest.raises(ValueError):
+        SSIMNet(k2=0.0)
+    with pytest.raises(ValueError):
+        SSIMNet(k2=-1.0)