提交 aa8fbcc0 编写于 作者: Z zhaozhenlong

add cell psnr

上级 c0c0b098
......@@ -25,7 +25,7 @@ from .lstm import LSTM
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold
from .embedding import Embedding
from .pooling import AvgPool2d, MaxPool2d
from .image import ImageGradients, SSIM
from .image import ImageGradients, SSIM, PSNR
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
......@@ -36,5 +36,5 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot',
'Embedding',
'AvgPool2d', 'MaxPool2d', 'Pad', 'Unfold',
'ImageGradients', 'SSIM',
'ImageGradients', 'SSIM', 'PSNR',
]
......@@ -69,6 +69,18 @@ class ImageGradients(Cell):
return dy, dx
def _convert_img_dtype_to_float32(img, max_val):
"""convert img dtype to float32"""
# Ususally max_val is 1.0 or 255, we will do the scaling if max_val > 1.
# We will scale img pixel value if max_val > 1. and just cast otherwise.
ret = F.cast(img, mstype.float32)
max_val = F.scalar_cast(max_val, mstype.float32)
if max_val > 1.:
scale = 1. / max_val
ret = ret * scale
return ret
@constexpr
def _gauss_kernel_helper(filter_size):
"""gauss kernel helper"""
......@@ -134,9 +146,9 @@ class SSIM(Cell):
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def construct(self, img1, img2):
max_val = self._convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = self._convert_img_dtype_to_float32(img1, self.max_val)
img2 = self._convert_img_dtype_to_float32(img2, self.max_val)
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
img2 = _convert_img_dtype_to_float32(img2, self.max_val)
kernel = self._fspecial_gauss(self.filter_size, self.filter_sigma)
kernel = P.Tile()(kernel, (1, P.Shape()(img1)[1], 1, 1))
......@@ -145,21 +157,10 @@ class SSIM(Cell):
return mean_ssim
def _convert_img_dtype_to_float32(self, img, max_val):
"""convert img dtype to float32"""
# Ususally max_val is 1.0 or 255, we will do the scaling if max_val > 1.
# We will scale img pixel value if max_val > 1. and just cast otherwise.
ret = P.Cast()(img, mstype.float32)
max_val = F.scalar_cast(max_val, mstype.float32)
if max_val > 1.:
scale = 1./max_val
ret = ret*scale
return ret
def _calculate_mean_ssim(self, x, y, kernel, max_val, k1, k2):
"""calculate mean ssim"""
c1 = (k1*max_val)*(k1*max_val)
c2 = (k2*max_val)*(k2*max_val)
c1 = (k1 * max_val) * (k1 * max_val)
c2 = (k2 * max_val) * (k2 * max_val)
# SSIM luminance formula
# (2 * mean_{x} * mean_{y} + c1) / (mean_{x}**2 + mean_{y}**2 + c1)
......@@ -195,3 +196,52 @@ class SSIM(Cell):
g = P.Softmax()(g)
ret = F.reshape(g, (1, 1, filter_size, filter_size))
return ret
class PSNR(Cell):
r"""
Returns Peak Signal-to-Noise Ratio of two image batches.
It produces a PSNR value for each image in batch.
Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`.
:math:`MAX` represents the dynamic range of pixel values.
.. math::
MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\
PSNR&=10*log_{10}(\frac{MAX^2}{MSE})
Args:
max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
Default: 1.0.
Inputs:
- **img1** (Tensor) - The first image batch with format 'NCHW'. It should be the same shape and dtype as img2.
- **img2** (Tensor) - The second image batch with format 'NCHW'. It should be the same shape and dtype as img1.
Outputs:
Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.
Examples:
>>> net = nn.PSNR()
>>> img1 = Tensor(np.random.random((1,3,16,16)))
>>> img2 = Tensor(np.random.random((1,3,16,16)))
>>> psnr = net(img1, img2)
"""
def __init__(self, max_val=1.0):
super(PSNR, self).__init__()
validator.check_type('max_val', max_val, [int, float])
validator.check('max_val', max_val, '', 0.0, Rel.GT)
self.max_val = max_val
def construct(self, img1, img2):
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
img2 = _convert_img_dtype_to_float32(img2, self.max_val)
mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1))
# 10*log_10(max_val^2/MSE)
psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)
return psnr
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test psnr
"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore import Tensor
class PSNRNet(nn.Cell):
def __init__(self, max_val=1.0):
super(PSNRNet, self).__init__()
self.net = nn.PSNR(max_val)
def construct(self, img1, img2):
return self.net(img1, img2)
def test_compile_psnr():
max_val = 1.0
net = PSNRNet(max_val)
img1 = Tensor(np.random.random((8, 3, 16, 16)))
img2 = Tensor(np.random.random((8, 3, 16, 16)))
_executor.compile(net, img1, img2)
def test_compile_psnr_grayscale():
max_val = 255
net = PSNRNet(max_val)
img1 = Tensor(np.random.randint(0, 256, (8, 1, 16, 16), np.uint8))
img2 = Tensor(np.random.randint(0, 256, (8, 1, 16, 16), np.uint8))
_executor.compile(net, img1, img2)
def test_psnr_max_val_negative():
max_val = -1
with pytest.raises(ValueError):
net = PSNRNet(max_val)
def test_psnr_max_val_bool():
max_val = True
with pytest.raises(ValueError):
net = PSNRNet(max_val)
def test_psnr_max_val_zero():
max_val = 0
with pytest.raises(ValueError):
net = PSNRNet(max_val)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test ssim
"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore import Tensor
class SSIMNet(nn.Cell):
def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
super(SSIMNet, self).__init__()
self.net = nn.SSIM(max_val, filter_size, filter_sigma, k1, k2)
def construct(self, img1, img2):
return self.net(img1, img2)
def test_compile():
net = SSIMNet()
img1 = Tensor(np.random.random((8, 3, 16, 16)))
img2 = Tensor(np.random.random((8, 3, 16, 16)))
_executor.compile(net, img1, img2)
def test_compile_grayscale():
max_val = 255
net = SSIMNet(max_val = max_val)
img1 = Tensor(np.random.randint(0, 256, (8, 1, 16, 16), np.uint8))
img2 = Tensor(np.random.randint(0, 256, (8, 1, 16, 16), np.uint8))
_executor.compile(net, img1, img2)
def test_ssim_max_val_negative():
max_val = -1
with pytest.raises(ValueError):
net = SSIMNet(max_val)
def test_ssim_max_val_bool():
max_val = True
with pytest.raises(ValueError):
net = SSIMNet(max_val)
def test_ssim_max_val_zero():
max_val = 0
with pytest.raises(ValueError):
net = SSIMNet(max_val)
def test_ssim_filter_size_float():
with pytest.raises(ValueError):
net = SSIMNet(filter_size=1.1)
def test_ssim_filter_size_zero():
with pytest.raises(ValueError):
net = SSIMNet(filter_size=0)
def test_ssim_filter_sigma_zero():
with pytest.raises(ValueError):
net = SSIMNet(filter_sigma=0.0)
def test_ssim_filter_sigma_negative():
with pytest.raises(ValueError):
net = SSIMNet(filter_sigma=-0.1)
def test_ssim_k1_k2_wrong_value():
with pytest.raises(ValueError):
net = SSIMNet(k1=1.1)
with pytest.raises(ValueError):
net = SSIMNet(k1=1.0)
with pytest.raises(ValueError):
net = SSIMNet(k1=0.0)
with pytest.raises(ValueError):
net = SSIMNet(k1=-1.0)
with pytest.raises(ValueError):
net = SSIMNet(k2=1.1)
with pytest.raises(ValueError):
net = SSIMNet(k2=1.0)
with pytest.raises(ValueError):
net = SSIMNet(k2=0.0)
with pytest.raises(ValueError):
net = SSIMNet(k2=-1.0)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册