提交 c88edfb3 编写于 作者: Z zhaozhenlong

psnr check two input same shape and type

上级 3b625ac9
......@@ -95,6 +95,11 @@ def _gauss_kernel_helper(filter_size):
g = Tensor(g)
return filter_size, g
@constexpr
def _check_input_4d(input_shape, param_name, func_name):
if len(input_shape) != 4:
raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}")
return True
class SSIM(Cell):
r"""
......@@ -146,6 +151,9 @@ class SSIM(Cell):
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", "SSIM")
_check_input_4d(F.shape(img2), "img2", "SSIM")
P.SameTypeShape()(img1, img2)
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
img2 = _convert_img_dtype_to_float32(img2, self.max_val)
......@@ -236,6 +244,9 @@ class PSNR(Cell):
self.max_val = max_val
def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", "PSNR")
_check_input_4d(F.shape(img2), "img2", "PSNR")
P.SameTypeShape()(img1, img2)
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
img2 = _convert_img_dtype_to_float32(img2, self.max_val)
......
......@@ -18,10 +18,12 @@ test psnr
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.common.api import _executor
from mindspore import Tensor
class PSNRNet(nn.Cell):
def __init__(self, max_val=1.0):
super(PSNRNet, self).__init__()
......@@ -59,3 +61,38 @@ def test_psnr_max_val_zero():
max_val = 0
with pytest.raises(ValueError):
net = PSNRNet(max_val)
def test_psnr_different_shape():
shape_1 = (8, 3, 16, 16)
shape_2 = (8, 3, 8, 8)
img1 = Tensor(np.random.random(shape_1))
img2 = Tensor(np.random.random(shape_2))
net = PSNRNet()
with pytest.raises(ValueError):
_executor.compile(net, img1, img2)
def test_psnr_different_dtype():
dtype_1 = mstype.float32
dtype_2 = mstype.float16
img1 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_1)
img2 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_2)
net = PSNRNet()
with pytest.raises(TypeError):
_executor.compile(net, img1, img2)
def test_psnr_invalid_5d_input():
shape_1 = (8, 3, 16, 16)
shape_2 = (8, 3, 8, 8)
invalid_shape = (8, 3, 16, 16, 1)
img1 = Tensor(np.random.random(shape_1))
invalid_img1 = Tensor(np.random.random(invalid_shape))
img2 = Tensor(np.random.random(shape_2))
invalid_img2 = Tensor(np.random.random(invalid_shape))
net = PSNRNet()
with pytest.raises(ValueError):
_executor.compile(net, invalid_img1, img2)
with pytest.raises(ValueError):
_executor.compile(net, img1, invalid_img2)
with pytest.raises(ValueError):
_executor.compile(net, invalid_img1, invalid_img2)
......@@ -18,6 +18,7 @@ test ssim
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common.api import _executor
from mindspore import Tensor
......@@ -93,3 +94,38 @@ def test_ssim_k1_k2_wrong_value():
net = SSIMNet(k2=0.0)
with pytest.raises(ValueError):
net = SSIMNet(k2=-1.0)
def test_ssim_different_shape():
shape_1 = (8, 3, 16, 16)
shape_2 = (8, 3, 8, 8)
img1 = Tensor(np.random.random(shape_1))
img2 = Tensor(np.random.random(shape_2))
net = SSIMNet()
with pytest.raises(ValueError):
_executor.compile(net, img1, img2)
def test_ssim_different_dtype():
dtype_1 = mstype.float32
dtype_2 = mstype.float16
img1 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_1)
img2 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_2)
net = SSIMNet()
with pytest.raises(TypeError):
_executor.compile(net, img1, img2)
def test_ssim_invalid_5d_input():
shape_1 = (8, 3, 16, 16)
shape_2 = (8, 3, 8, 8)
invalid_shape = (8, 3, 16, 16, 1)
img1 = Tensor(np.random.random(shape_1))
invalid_img1 = Tensor(np.random.random(invalid_shape))
img2 = Tensor(np.random.random(shape_2))
invalid_img2 = Tensor(np.random.random(invalid_shape))
net = SSIMNet()
with pytest.raises(ValueError):
_executor.compile(net, invalid_img1, img2)
with pytest.raises(ValueError):
_executor.compile(net, img1, invalid_img2)
with pytest.raises(ValueError):
_executor.compile(net, invalid_img1, invalid_img2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册