Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9b21420b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9b21420b
编写于
7月 07, 2020
作者:
L
leilei_snow
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update SSIM loss, add MSSSIM loss feature; add their ut testcases.
上级
03ef509e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
299 addition
and
80 deletion
+299
-80
mindspore/nn/layer/image.py
mindspore/nn/layer/image.py
+164
-60
tests/ut/python/nn/test_msssim.py
tests/ut/python/nn/test_msssim.py
+135
-0
tests/ut/python/nn/test_ssim.py
tests/ut/python/nn/test_ssim.py
+0
-20
未找到文件。
mindspore/nn/layer/image.py
浏览文件 @
9b21420b
...
...
@@ -21,9 +21,13 @@ from mindspore.ops import functional as F
from
mindspore.ops.primitive
import
constexpr
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
.conv
import
Conv2d
from
.container
import
CellList
from
.pooling
import
AvgPool2d
from
.activation
import
ReLU
from
..cell
import
Cell
__all__
=
[
'ImageGradients'
,
'SSIM'
,
'PSNR'
,
'CentralCrop'
]
__all__
=
[
'ImageGradients'
,
'SSIM'
,
'
MSSSIM'
,
'
PSNR'
,
'CentralCrop'
]
class
ImageGradients
(
Cell
):
r
"""
...
...
@@ -83,21 +87,6 @@ def _convert_img_dtype_to_float32(img, max_val):
ret
=
ret
*
scale
return
ret
@
constexpr
def
_gauss_kernel_helper
(
filter_size
):
"""gauss kernel helper"""
filter_size
=
F
.
scalar_cast
(
filter_size
,
mstype
.
int32
)
coords
=
()
for
i
in
range
(
filter_size
):
i_cast
=
F
.
scalar_cast
(
i
,
mstype
.
float32
)
offset
=
F
.
scalar_cast
(
filter_size
-
1
,
mstype
.
float32
)
/
2.0
element
=
i_cast
-
offset
coords
=
coords
+
(
element
,)
g
=
np
.
square
(
coords
).
astype
(
np
.
float32
)
g
=
Tensor
(
g
)
return
filter_size
,
g
@
constexpr
def
_check_input_4d
(
input_shape
,
param_name
,
func_name
):
if
len
(
input_shape
)
!=
4
:
...
...
@@ -110,9 +99,65 @@ def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
validator
.
check
(
param_name
+
" shape[2]"
,
input_shape
[
2
],
"filter_size"
,
filter_size
,
Rel
.
GE
,
func_name
)
validator
.
check
(
param_name
+
" shape[3]"
,
input_shape
[
3
],
"filter_size"
,
filter_size
,
Rel
.
GE
,
func_name
)
@
constexpr
def
_check_input_dtype
(
input_dtype
,
param_name
,
allow_dtypes
,
cls_name
):
validator
.
check_type_name
(
param_name
,
input_dtype
,
allow_dtypes
,
cls_name
)
def
_conv2d
(
in_channels
,
out_channels
,
kernel_size
,
weight
,
stride
=
1
,
padding
=
0
):
return
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
weight_init
=
weight
,
padding
=
padding
,
pad_mode
=
"valid"
)
def
_create_window
(
size
,
sigma
):
x_data
,
y_data
=
np
.
mgrid
[
-
size
//
2
+
1
:
size
//
2
+
1
,
-
size
//
2
+
1
:
size
//
2
+
1
]
x_data
=
np
.
expand_dims
(
x_data
,
axis
=-
1
).
astype
(
np
.
float32
)
x_data
=
np
.
expand_dims
(
x_data
,
axis
=-
1
)
**
2
y_data
=
np
.
expand_dims
(
y_data
,
axis
=-
1
).
astype
(
np
.
float32
)
y_data
=
np
.
expand_dims
(
y_data
,
axis
=-
1
)
**
2
sigma
=
2
*
sigma
**
2
g
=
np
.
exp
(
-
(
x_data
+
y_data
)
/
sigma
)
return
np
.
transpose
(
g
/
np
.
sum
(
g
),
(
2
,
3
,
0
,
1
))
def
_split_img
(
x
):
_
,
c
,
_
,
_
=
F
.
shape
(
x
)
img_split
=
P
.
Split
(
1
,
c
)
output
=
img_split
(
x
)
return
output
,
c
def
_compute_per_channel_loss
(
c1
,
c2
,
img1
,
img2
,
conv
):
"""computes ssim index between img1 and img2 per single channel"""
dot_img
=
img1
*
img2
mu1
=
conv
(
img1
)
mu2
=
conv
(
img2
)
mu1_sq
=
mu1
*
mu1
mu2_sq
=
mu2
*
mu2
mu1_mu2
=
mu1
*
mu2
sigma1_tmp
=
conv
(
img1
*
img1
)
sigma1_sq
=
sigma1_tmp
-
mu1_sq
sigma2_tmp
=
conv
(
img2
*
img2
)
sigma2_sq
=
sigma2_tmp
-
mu2_sq
sigma12_tmp
=
conv
(
dot_img
)
sigma12
=
sigma12_tmp
-
mu1_mu2
a
=
(
2
*
mu1_mu2
+
c1
)
b
=
(
mu1_sq
+
mu2_sq
+
c1
)
v1
=
2
*
sigma12
+
c2
v2
=
sigma1_sq
+
sigma2_sq
+
c2
ssim
=
(
a
*
v1
)
/
(
b
*
v2
)
cs
=
v1
/
v2
return
ssim
,
cs
def
_compute_multi_channel_loss
(
c1
,
c2
,
img1
,
img2
,
conv
,
concat
,
mean
):
"""computes ssim index between img1 and img2 per color channel"""
split_img1
,
c
=
_split_img
(
img1
)
split_img2
,
_
=
_split_img
(
img2
)
multi_ssim
=
()
multi_cs
=
()
for
i
in
range
(
c
):
ssim_per_channel
,
cs_per_channel
=
_compute_per_channel_loss
(
c1
,
c2
,
split_img1
[
i
],
split_img2
[
i
],
conv
)
multi_ssim
+=
(
ssim_per_channel
,)
multi_cs
+=
(
cs_per_channel
,)
multi_ssim
=
concat
(
multi_ssim
)
multi_cs
=
concat
(
multi_cs
)
ssim
=
mean
(
multi_ssim
,
(
2
,
3
))
cs
=
mean
(
multi_cs
,
(
2
,
3
))
return
ssim
,
cs
class
SSIM
(
Cell
):
r
"""
...
...
@@ -157,67 +202,126 @@ class SSIM(Cell):
self
.
max_val
=
max_val
self
.
filter_size
=
validator
.
check_integer
(
'filter_size'
,
filter_size
,
1
,
Rel
.
GE
,
self
.
cls_name
)
self
.
filter_sigma
=
validator
.
check_float_positive
(
'filter_sigma'
,
filter_sigma
,
self
.
cls_name
)
validator
.
check_value_type
(
'k1'
,
k1
,
[
float
],
self
.
cls_name
)
self
.
k1
=
validator
.
check_number_range
(
'k1'
,
k1
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
,
self
.
cls_name
)
validator
.
check_value_type
(
'k2'
,
k2
,
[
float
],
self
.
cls_name
)
self
.
k2
=
validator
.
check_number_range
(
'k2'
,
k2
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
,
self
.
cls_name
)
self
.
mean
=
P
.
DepthwiseConv2dNative
(
channel_multiplier
=
1
,
kernel_size
=
filter_size
)
self
.
k1
=
validator
.
check_value_type
(
'k1'
,
k1
,
[
float
],
self
.
cls_name
)
self
.
k2
=
validator
.
check_value_type
(
'k2'
,
k2
,
[
float
],
self
.
cls_name
)
window
=
_create_window
(
filter_size
,
filter_sigma
)
self
.
conv
=
_conv2d
(
1
,
1
,
filter_size
,
Tensor
(
window
))
self
.
conv
.
weight
.
requires_grad
=
False
self
.
reduce_mean
=
P
.
ReduceMean
()
self
.
concat
=
P
.
Concat
(
axis
=
1
)
def
construct
(
self
,
img1
,
img2
):
_check_input_dtype
(
F
.
dtype
(
img1
),
"img1"
,
[
mstype
.
float32
,
mstype
.
float16
],
self
.
cls_name
)
_check_input_filter_size
(
F
.
shape
(
img1
),
"img1"
,
self
.
filter_size
,
self
.
cls_name
)
P
.
SameTypeShape
()(
img1
,
img2
)
max_val
=
_convert_img_dtype_to_float32
(
self
.
max_val
,
self
.
max_val
)
img1
=
_convert_img_dtype_to_float32
(
img1
,
self
.
max_val
)
img2
=
_convert_img_dtype_to_float32
(
img2
,
self
.
max_val
)
kernel
=
self
.
_fspecial_gauss
(
self
.
filter_size
,
self
.
filter_sigma
)
kernel
=
P
.
Tile
()(
kernel
,
(
1
,
P
.
Shape
()(
img1
)[
1
],
1
,
1
))
c1
=
(
self
.
k1
*
max_val
)
**
2
c2
=
(
self
.
k2
*
max_val
)
**
2
ssim_ave_channel
,
_
=
_compute_multi_channel_loss
(
c1
,
c2
,
img1
,
img2
,
self
.
conv
,
self
.
concat
,
self
.
reduce_mean
)
loss
=
self
.
reduce_mean
(
ssim_ave_channel
,
-
1
)
return
loss
def
_downsample
(
img1
,
img2
,
op
):
a
=
op
(
img1
)
b
=
op
(
img2
)
return
a
,
b
class
MSSSIM
(
Cell
):
r
"""
Returns MS-SSIM index between img1 and img2.
Its implementation is based on Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. `Multiscale structural similarity
for image quality assessment <https://ieeexplore.ieee.org/document/1292216>`_.
Signals, Systems and Computers, 2004.
mean_ssim
=
self
.
_calculate_mean_ssim
(
img1
,
img2
,
kernel
,
max_val
,
self
.
k1
,
self
.
k2
)
.. math::
return
mean_ssim
l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
MSSSIM(x,y)&=l^alpha_M*{\prod_{1\leq j\leq M} (c^beta_j*s^gamma_j)}.
def
_calculate_mean_ssim
(
self
,
x
,
y
,
kernel
,
max_val
,
k1
,
k2
):
"""calculate mean ssim"""
c1
=
(
k1
*
max_val
)
*
(
k1
*
max_val
)
c2
=
(
k2
*
max_val
)
*
(
k2
*
max_val
)
Args:
max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
Default: 1.0.
power_factors (Union[tuple, list]): Iterable of weights for each of the scales.
Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). Default values obtained by Wang et al.
filter_size (int): The size of the Gaussian filter. Default: 11.
filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.
# SSIM luminance formula
# (2 * mean_{x} * mean_{y} + c1) / (mean_{x}**2 + mean_{y}**2 + c1)
mean_x
=
self
.
mean
(
x
,
kernel
)
mean_y
=
self
.
mean
(
y
,
kernel
)
square_sum
=
F
.
square
(
mean_x
)
+
F
.
square
(
mean_y
)
luminance
=
(
2
*
mean_x
*
mean_y
+
c1
)
/
(
square_sum
+
c1
)
Inputs:
- **img1** (Tensor) - The first image batch with format 'NCHW'. It should be the same shape and dtype as img2.
- **img2** (Tensor) - The second image batch with format 'NCHW'. It should be the same shape and dtype as img1.
# SSIM contrast*structure formula (when c3 = c2/2)
# (2 * conv_{xy} + c2) / (conv_{xx} + conv_{yy} + c2), equals to
# (2 * (mean_{xy} - mean_{x}*mean_{y}) + c2) / (mean_{xx}-mean_{x}**2 + mean_{yy}-mean_{y}**2 + c2)
mean_xy
=
self
.
mean
(
x
*
y
,
kernel
)
mean_square_add
=
self
.
mean
(
F
.
square
(
x
)
+
F
.
square
(
y
),
kernel
)
Outputs:
Tensor, has the same dtype as img1. It is a 1-D tensor with shape N, where N is the batch num of img1.
cs
=
(
2
*
(
mean_xy
-
mean_x
*
mean_y
)
+
c2
)
/
(
mean_square_add
-
square_sum
+
c2
)
Examples:
>>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
>>> img1 = Tensor(np.random.random((1,3,128,128)))
>>> img2 = Tensor(np.random.random((1,3,128,128)))
>>> msssim = net(img1, img2)
"""
def
__init__
(
self
,
max_val
=
1.0
,
power_factors
=
(
0.0448
,
0.2856
,
0.3001
,
0.2363
,
0.1333
),
filter_size
=
11
,
filter_sigma
=
1.5
,
k1
=
0.01
,
k2
=
0.03
):
super
(
MSSSIM
,
self
).
__init__
()
validator
.
check_value_type
(
'max_val'
,
max_val
,
[
int
,
float
],
self
.
cls_name
)
validator
.
check_number
(
'max_val'
,
max_val
,
0.0
,
Rel
.
GT
,
self
.
cls_name
)
self
.
max_val
=
max_val
validator
.
check_value_type
(
'power_factors'
,
power_factors
,
[
tuple
,
list
],
self
.
cls_name
)
self
.
filter_size
=
validator
.
check_integer
(
'filter_size'
,
filter_size
,
1
,
Rel
.
GE
,
self
.
cls_name
)
self
.
filter_sigma
=
validator
.
check_float_positive
(
'filter_sigma'
,
filter_sigma
,
self
.
cls_name
)
self
.
k1
=
validator
.
check_value_type
(
'k1'
,
k1
,
[
float
],
self
.
cls_name
)
self
.
k2
=
validator
.
check_value_type
(
'k2'
,
k2
,
[
float
],
self
.
cls_name
)
window
=
_create_window
(
filter_size
,
filter_sigma
)
self
.
level
=
len
(
power_factors
)
self
.
conv
=
[]
for
i
in
range
(
self
.
level
):
self
.
conv
.
append
(
_conv2d
(
1
,
1
,
filter_size
,
Tensor
(
window
)))
self
.
conv
[
i
].
weight
.
requires_grad
=
False
self
.
multi_convs_list
=
CellList
(
self
.
conv
)
self
.
weight_tensor
=
Tensor
(
power_factors
,
mstype
.
float32
)
self
.
avg_pool
=
AvgPool2d
(
kernel_size
=
2
,
stride
=
2
,
pad_mode
=
'valid'
)
self
.
relu
=
ReLU
()
self
.
reduce_mean
=
P
.
ReduceMean
()
self
.
prod
=
P
.
ReduceProd
()
self
.
pow
=
P
.
Pow
()
self
.
pack
=
P
.
Pack
(
axis
=-
1
)
self
.
concat
=
P
.
Concat
(
axis
=
1
)
# SSIM formula
# luminance * cs
ssim
=
luminance
*
cs
def
construct
(
self
,
img1
,
img2
):
_check_input_4d
(
F
.
shape
(
img1
),
"img1"
,
self
.
cls_name
)
_check_input_4d
(
F
.
shape
(
img2
),
"img2"
,
self
.
cls_name
)
P
.
SameTypeShape
()(
img1
,
img2
)
max_val
=
_convert_img_dtype_to_float32
(
self
.
max_val
,
self
.
max_val
)
img1
=
_convert_img_dtype_to_float32
(
img1
,
self
.
max_val
)
img2
=
_convert_img_dtype_to_float32
(
img2
,
self
.
max_val
)
mean_ssim
=
P
.
ReduceMean
()(
ssim
,
(
-
3
,
-
2
,
-
1
))
c1
=
(
self
.
k1
*
max_val
)
**
2
c2
=
(
self
.
k2
*
max_val
)
**
2
return
mean_ssim
sim
=
()
mcs
=
()
def
_fspecial_gauss
(
self
,
filter_size
,
filter_sigma
):
"""get gauss kernel"""
filter_size
,
g
=
_gauss_kernel_helper
(
filter_size
)
for
i
in
range
(
self
.
level
):
sim
,
cs
=
_compute_multi_channel_loss
(
c1
,
c2
,
img1
,
img2
,
self
.
multi_convs_list
[
i
],
self
.
concat
,
self
.
reduce_mean
)
mcs
+=
(
self
.
relu
(
cs
),)
img1
,
img2
=
_downsample
(
img1
,
img2
,
self
.
avg_pool
)
square_sigma_scale
=
-
0.5
/
(
filter_sigma
*
filter_sigma
)
g
=
g
*
square_sigma_scale
g
=
F
.
reshape
(
g
,
(
1
,
-
1
))
+
F
.
reshape
(
g
,
(
-
1
,
1
))
g
=
F
.
reshape
(
g
,
(
1
,
-
1
))
g
=
P
.
Softmax
()(
g
)
ret
=
F
.
reshape
(
g
,
(
1
,
1
,
filter_size
,
filter_size
))
return
ret
mcs
=
mcs
[
0
:
-
1
:
1
]
mcs_and_ssim
=
self
.
pack
(
mcs
+
(
self
.
relu
(
sim
),))
mcs_and_ssim
=
self
.
pow
(
mcs_and_ssim
,
self
.
weight_tensor
)
ms_ssim
=
self
.
prod
(
mcs_and_ssim
,
-
1
)
loss
=
self
.
reduce_mean
(
ms_ssim
,
-
1
)
return
loss
class
PSNR
(
Cell
):
r
"""
...
...
tests/ut/python/nn/test_msssim.py
0 → 100644
浏览文件 @
9b21420b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test msssim
"""
import
numpy
as
np
import
pytest
import
mindspore.common.dtype
as
mstype
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.common.api
import
_executor
_MSSSIM_WEIGHTS
=
(
0.0448
,
0.2856
,
0.3001
,
0.2363
,
0.1333
)
class
MSSSIMNet
(
nn
.
Cell
):
def
__init__
(
self
,
max_val
=
1.0
,
power_factors
=
_MSSSIM_WEIGHTS
,
filter_size
=
11
,
filter_sigma
=
1.5
,
k1
=
0.01
,
k2
=
0.03
):
super
(
MSSSIMNet
,
self
).
__init__
()
self
.
net
=
nn
.
MSSSIM
(
max_val
,
power_factors
,
filter_size
,
filter_sigma
,
k1
,
k2
)
def
construct
(
self
,
img1
,
img2
):
return
self
.
net
(
img1
,
img2
)
def
test_compile
():
factors
=
(
0.033
,
0.033
,
0.033
)
net
=
MSSSIMNet
(
power_factors
=
factors
)
img1
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
128
,
128
)))
img2
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
128
,
128
)))
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_compile_grayscale
():
max_val
=
255
factors
=
(
0.033
,
0.033
,
0.033
)
net
=
MSSSIMNet
(
max_val
=
max_val
,
power_factors
=
factors
)
img1
=
Tensor
(
np
.
random
.
randint
(
0
,
256
,
(
8
,
3
,
128
,
128
),
np
.
uint8
))
img2
=
Tensor
(
np
.
random
.
randint
(
0
,
256
,
(
8
,
3
,
128
,
128
),
np
.
uint8
))
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_msssim_max_val_negative
():
max_val
=
-
1
with
pytest
.
raises
(
ValueError
):
_
=
MSSSIMNet
(
max_val
)
def
test_msssim_max_val_bool
():
max_val
=
True
with
pytest
.
raises
(
TypeError
):
_
=
MSSSIMNet
(
max_val
)
def
test_msssim_max_val_zero
():
max_val
=
0
with
pytest
.
raises
(
ValueError
):
_
=
MSSSIMNet
(
max_val
)
def
test_msssim_power_factors_set
():
with
pytest
.
raises
(
TypeError
):
_
=
MSSSIMNet
(
power_factors
=
{
0.033
,
0.033
,
0.033
})
def
test_msssim_filter_size_float
():
with
pytest
.
raises
(
TypeError
):
_
=
MSSSIMNet
(
filter_size
=
1.1
)
def
test_msssim_filter_size_zero
():
with
pytest
.
raises
(
ValueError
):
_
=
MSSSIMNet
(
filter_size
=
0
)
def
test_msssim_filter_sigma_zero
():
with
pytest
.
raises
(
ValueError
):
_
=
MSSSIMNet
(
filter_sigma
=
0.0
)
def
test_msssim_filter_sigma_negative
():
with
pytest
.
raises
(
ValueError
):
_
=
MSSSIMNet
(
filter_sigma
=-
0.1
)
def
test_msssim_different_shape
():
shape_1
=
(
8
,
3
,
128
,
128
)
shape_2
=
(
8
,
3
,
256
,
256
)
factors
=
(
0.033
,
0.033
,
0.033
)
img1
=
Tensor
(
np
.
random
.
random
(
shape_1
))
img2
=
Tensor
(
np
.
random
.
random
(
shape_2
))
net
=
MSSSIMNet
(
power_factors
=
factors
)
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_msssim_different_dtype
():
dtype_1
=
mstype
.
float32
dtype_2
=
mstype
.
float16
factors
=
(
0.033
,
0.033
,
0.033
)
img1
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
128
,
128
)),
dtype
=
dtype_1
)
img2
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
128
,
128
)),
dtype
=
dtype_2
)
net
=
MSSSIMNet
(
power_factors
=
factors
)
with
pytest
.
raises
(
TypeError
):
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_msssim_invalid_5d_input
():
shape_1
=
(
8
,
3
,
128
,
128
)
shape_2
=
(
8
,
3
,
256
,
256
)
invalid_shape
=
(
8
,
3
,
128
,
128
,
1
)
factors
=
(
0.033
,
0.033
,
0.033
)
img1
=
Tensor
(
np
.
random
.
random
(
shape_1
))
invalid_img1
=
Tensor
(
np
.
random
.
random
(
invalid_shape
))
img2
=
Tensor
(
np
.
random
.
random
(
shape_2
))
invalid_img2
=
Tensor
(
np
.
random
.
random
(
invalid_shape
))
net
=
MSSSIMNet
(
power_factors
=
factors
)
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
invalid_img1
,
img2
)
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
img1
,
invalid_img2
)
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
invalid_img1
,
invalid_img2
)
tests/ut/python/nn/test_ssim.py
浏览文件 @
9b21420b
...
...
@@ -78,26 +78,6 @@ def test_ssim_filter_sigma_negative():
_
=
SSIMNet
(
filter_sigma
=-
0.1
)
def
test_ssim_k1_k2_wrong_value
():
with
pytest
.
raises
(
ValueError
):
_
=
SSIMNet
(
k1
=
1.1
)
with
pytest
.
raises
(
ValueError
):
_
=
SSIMNet
(
k1
=
1.0
)
with
pytest
.
raises
(
ValueError
):
_
=
SSIMNet
(
k1
=
0.0
)
with
pytest
.
raises
(
ValueError
):
_
=
SSIMNet
(
k1
=-
1.0
)
with
pytest
.
raises
(
ValueError
):
_
=
SSIMNet
(
k2
=
1.1
)
with
pytest
.
raises
(
ValueError
):
_
=
SSIMNet
(
k2
=
1.0
)
with
pytest
.
raises
(
ValueError
):
_
=
SSIMNet
(
k2
=
0.0
)
with
pytest
.
raises
(
ValueError
):
_
=
SSIMNet
(
k2
=-
1.0
)
def
test_ssim_different_shape
():
shape_1
=
(
8
,
3
,
16
,
16
)
shape_2
=
(
8
,
3
,
8
,
8
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录