提交 c88edfb3 编写于 作者: Z zhaozhenlong

psnr check two input same shape and type

上级 3b625ac9
...@@ -95,6 +95,11 @@ def _gauss_kernel_helper(filter_size): ...@@ -95,6 +95,11 @@ def _gauss_kernel_helper(filter_size):
g = Tensor(g) g = Tensor(g)
return filter_size, g return filter_size, g
@constexpr
def _check_input_4d(input_shape, param_name, func_name):
if len(input_shape) != 4:
raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}")
return True
class SSIM(Cell): class SSIM(Cell):
r""" r"""
...@@ -146,6 +151,9 @@ class SSIM(Cell): ...@@ -146,6 +151,9 @@ class SSIM(Cell):
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def construct(self, img1, img2): def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", "SSIM")
_check_input_4d(F.shape(img2), "img2", "SSIM")
P.SameTypeShape()(img1, img2)
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val) img1 = _convert_img_dtype_to_float32(img1, self.max_val)
img2 = _convert_img_dtype_to_float32(img2, self.max_val) img2 = _convert_img_dtype_to_float32(img2, self.max_val)
...@@ -236,6 +244,9 @@ class PSNR(Cell): ...@@ -236,6 +244,9 @@ class PSNR(Cell):
self.max_val = max_val self.max_val = max_val
def construct(self, img1, img2): def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", "PSNR")
_check_input_4d(F.shape(img2), "img2", "PSNR")
P.SameTypeShape()(img1, img2)
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val) img1 = _convert_img_dtype_to_float32(img1, self.max_val)
img2 = _convert_img_dtype_to_float32(img2, self.max_val) img2 = _convert_img_dtype_to_float32(img2, self.max_val)
......
...@@ -18,10 +18,12 @@ test psnr ...@@ -18,10 +18,12 @@ test psnr
import numpy as np import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore import Tensor from mindspore import Tensor
class PSNRNet(nn.Cell): class PSNRNet(nn.Cell):
def __init__(self, max_val=1.0): def __init__(self, max_val=1.0):
super(PSNRNet, self).__init__() super(PSNRNet, self).__init__()
...@@ -59,3 +61,38 @@ def test_psnr_max_val_zero(): ...@@ -59,3 +61,38 @@ def test_psnr_max_val_zero():
max_val = 0 max_val = 0
with pytest.raises(ValueError): with pytest.raises(ValueError):
net = PSNRNet(max_val) net = PSNRNet(max_val)
def test_psnr_different_shape():
shape_1 = (8, 3, 16, 16)
shape_2 = (8, 3, 8, 8)
img1 = Tensor(np.random.random(shape_1))
img2 = Tensor(np.random.random(shape_2))
net = PSNRNet()
with pytest.raises(ValueError):
_executor.compile(net, img1, img2)
def test_psnr_different_dtype():
dtype_1 = mstype.float32
dtype_2 = mstype.float16
img1 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_1)
img2 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_2)
net = PSNRNet()
with pytest.raises(TypeError):
_executor.compile(net, img1, img2)
def test_psnr_invalid_5d_input():
shape_1 = (8, 3, 16, 16)
shape_2 = (8, 3, 8, 8)
invalid_shape = (8, 3, 16, 16, 1)
img1 = Tensor(np.random.random(shape_1))
invalid_img1 = Tensor(np.random.random(invalid_shape))
img2 = Tensor(np.random.random(shape_2))
invalid_img2 = Tensor(np.random.random(invalid_shape))
net = PSNRNet()
with pytest.raises(ValueError):
_executor.compile(net, invalid_img1, img2)
with pytest.raises(ValueError):
_executor.compile(net, img1, invalid_img2)
with pytest.raises(ValueError):
_executor.compile(net, invalid_img1, invalid_img2)
...@@ -18,6 +18,7 @@ test ssim ...@@ -18,6 +18,7 @@ test ssim
import numpy as np import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore import Tensor from mindspore import Tensor
...@@ -93,3 +94,38 @@ def test_ssim_k1_k2_wrong_value(): ...@@ -93,3 +94,38 @@ def test_ssim_k1_k2_wrong_value():
net = SSIMNet(k2=0.0) net = SSIMNet(k2=0.0)
with pytest.raises(ValueError): with pytest.raises(ValueError):
net = SSIMNet(k2=-1.0) net = SSIMNet(k2=-1.0)
def test_ssim_different_shape():
shape_1 = (8, 3, 16, 16)
shape_2 = (8, 3, 8, 8)
img1 = Tensor(np.random.random(shape_1))
img2 = Tensor(np.random.random(shape_2))
net = SSIMNet()
with pytest.raises(ValueError):
_executor.compile(net, img1, img2)
def test_ssim_different_dtype():
dtype_1 = mstype.float32
dtype_2 = mstype.float16
img1 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_1)
img2 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_2)
net = SSIMNet()
with pytest.raises(TypeError):
_executor.compile(net, img1, img2)
def test_ssim_invalid_5d_input():
shape_1 = (8, 3, 16, 16)
shape_2 = (8, 3, 8, 8)
invalid_shape = (8, 3, 16, 16, 1)
img1 = Tensor(np.random.random(shape_1))
invalid_img1 = Tensor(np.random.random(invalid_shape))
img2 = Tensor(np.random.random(shape_2))
invalid_img2 = Tensor(np.random.random(invalid_shape))
net = SSIMNet()
with pytest.raises(ValueError):
_executor.compile(net, invalid_img1, img2)
with pytest.raises(ValueError):
_executor.compile(net, img1, invalid_img2)
with pytest.raises(ValueError):
_executor.compile(net, invalid_img1, invalid_img2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册