提交 ea681bfb 编写于 作者: Z zhaozhenlong

fix ssim filter size check

上级 f48ba776
......@@ -104,6 +104,12 @@ def _check_input_4d(input_shape, param_name, func_name):
raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}")
return True
@constexpr
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
_check_input_4d(input_shape, param_name, func_name)
validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name)
validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name)
class SSIM(Cell):
r"""
Returns SSIM index between img1 and img2.
......@@ -154,8 +160,7 @@ class SSIM(Cell):
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", self.cls_name)
_check_input_4d(F.shape(img2), "img2", self.cls_name)
_check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
P.SameTypeShape()(img1, img2)
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
......
......@@ -1754,6 +1754,10 @@ raise_set = [
'block': (P.PReLU(), {'exception': ValueError}),
'desc_inputs': [[2], [1]],
'desc_bprop': [[1]]}),
('SSIM', {
'block': (nn.SSIM(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.ones((1, 3, 8, 8)), mstype.float32),
Tensor(np.ones((1, 3, 8, 8)), mstype.float32)]}),
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册