提交 6089d58d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2171 ssim input shape h and w should be greater than or equal to filter_size

Merge pull request !2171 from zhaozhenlong/fix-issues-quant-not-exposed-ssim-ksize-check
...@@ -104,6 +104,12 @@ def _check_input_4d(input_shape, param_name, func_name): ...@@ -104,6 +104,12 @@ def _check_input_4d(input_shape, param_name, func_name):
raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}") raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}")
return True return True
@constexpr
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
_check_input_4d(input_shape, param_name, func_name)
validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name)
validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name)
class SSIM(Cell): class SSIM(Cell):
r""" r"""
Returns SSIM index between img1 and img2. Returns SSIM index between img1 and img2.
...@@ -154,8 +160,7 @@ class SSIM(Cell): ...@@ -154,8 +160,7 @@ class SSIM(Cell):
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def construct(self, img1, img2): def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", self.cls_name) _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
_check_input_4d(F.shape(img2), "img2", self.cls_name)
P.SameTypeShape()(img1, img2) P.SameTypeShape()(img1, img2)
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val) img1 = _convert_img_dtype_to_float32(img1, self.max_val)
......
...@@ -1754,6 +1754,10 @@ raise_set = [ ...@@ -1754,6 +1754,10 @@ raise_set = [
'block': (P.PReLU(), {'exception': ValueError}), 'block': (P.PReLU(), {'exception': ValueError}),
'desc_inputs': [[2], [1]], 'desc_inputs': [[2], [1]],
'desc_bprop': [[1]]}), 'desc_bprop': [[1]]}),
('SSIM', {
'block': (nn.SSIM(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.ones((1, 3, 8, 8)), mstype.float32),
Tensor(np.ones((1, 3, 8, 8)), mstype.float32)]}),
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册