Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c88edfb3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c88edfb3
编写于
4月 26, 2020
作者:
Z
zhaozhenlong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
psnr check two input same shape and type
上级
3b625ac9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
84 addition
and
0 deletion
+84
-0
mindspore/nn/layer/image.py
mindspore/nn/layer/image.py
+11
-0
tests/ut/python/nn/test_psnr.py
tests/ut/python/nn/test_psnr.py
+37
-0
tests/ut/python/nn/test_ssim.py
tests/ut/python/nn/test_ssim.py
+36
-0
未找到文件。
mindspore/nn/layer/image.py
浏览文件 @
c88edfb3
...
...
@@ -95,6 +95,11 @@ def _gauss_kernel_helper(filter_size):
g
=
Tensor
(
g
)
return
filter_size
,
g
@
constexpr
def
_check_input_4d
(
input_shape
,
param_name
,
func_name
):
if
len
(
input_shape
)
!=
4
:
raise
ValueError
(
f
"
{
func_name
}
{
param_name
}
should be 4d, but got shape
{
input_shape
}
"
)
return
True
class
SSIM
(
Cell
):
r
"""
...
...
@@ -146,6 +151,9 @@ class SSIM(Cell):
self
.
mean
=
P
.
DepthwiseConv2dNative
(
channel_multiplier
=
1
,
kernel_size
=
filter_size
)
def
construct
(
self
,
img1
,
img2
):
_check_input_4d
(
F
.
shape
(
img1
),
"img1"
,
"SSIM"
)
_check_input_4d
(
F
.
shape
(
img2
),
"img2"
,
"SSIM"
)
P
.
SameTypeShape
()(
img1
,
img2
)
max_val
=
_convert_img_dtype_to_float32
(
self
.
max_val
,
self
.
max_val
)
img1
=
_convert_img_dtype_to_float32
(
img1
,
self
.
max_val
)
img2
=
_convert_img_dtype_to_float32
(
img2
,
self
.
max_val
)
...
...
@@ -236,6 +244,9 @@ class PSNR(Cell):
self
.
max_val
=
max_val
def
construct
(
self
,
img1
,
img2
):
_check_input_4d
(
F
.
shape
(
img1
),
"img1"
,
"PSNR"
)
_check_input_4d
(
F
.
shape
(
img2
),
"img2"
,
"PSNR"
)
P
.
SameTypeShape
()(
img1
,
img2
)
max_val
=
_convert_img_dtype_to_float32
(
self
.
max_val
,
self
.
max_val
)
img1
=
_convert_img_dtype_to_float32
(
img1
,
self
.
max_val
)
img2
=
_convert_img_dtype_to_float32
(
img2
,
self
.
max_val
)
...
...
tests/ut/python/nn/test_psnr.py
浏览文件 @
c88edfb3
...
...
@@ -18,10 +18,12 @@ test psnr
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common.api
import
_executor
from
mindspore
import
Tensor
class
PSNRNet
(
nn
.
Cell
):
def
__init__
(
self
,
max_val
=
1.0
):
super
(
PSNRNet
,
self
).
__init__
()
...
...
@@ -59,3 +61,38 @@ def test_psnr_max_val_zero():
max_val
=
0
with
pytest
.
raises
(
ValueError
):
net
=
PSNRNet
(
max_val
)
def
test_psnr_different_shape
():
shape_1
=
(
8
,
3
,
16
,
16
)
shape_2
=
(
8
,
3
,
8
,
8
)
img1
=
Tensor
(
np
.
random
.
random
(
shape_1
))
img2
=
Tensor
(
np
.
random
.
random
(
shape_2
))
net
=
PSNRNet
()
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_psnr_different_dtype
():
dtype_1
=
mstype
.
float32
dtype_2
=
mstype
.
float16
img1
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
16
,
16
)),
dtype
=
dtype_1
)
img2
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
16
,
16
)),
dtype
=
dtype_2
)
net
=
PSNRNet
()
with
pytest
.
raises
(
TypeError
):
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_psnr_invalid_5d_input
():
shape_1
=
(
8
,
3
,
16
,
16
)
shape_2
=
(
8
,
3
,
8
,
8
)
invalid_shape
=
(
8
,
3
,
16
,
16
,
1
)
img1
=
Tensor
(
np
.
random
.
random
(
shape_1
))
invalid_img1
=
Tensor
(
np
.
random
.
random
(
invalid_shape
))
img2
=
Tensor
(
np
.
random
.
random
(
shape_2
))
invalid_img2
=
Tensor
(
np
.
random
.
random
(
invalid_shape
))
net
=
PSNRNet
()
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
invalid_img1
,
img2
)
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
img1
,
invalid_img2
)
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
invalid_img1
,
invalid_img2
)
tests/ut/python/nn/test_ssim.py
浏览文件 @
c88edfb3
...
...
@@ -18,6 +18,7 @@ test ssim
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
import
mindspore.common.dtype
as
mstype
from
mindspore.common.api
import
_executor
from
mindspore
import
Tensor
...
...
@@ -93,3 +94,38 @@ def test_ssim_k1_k2_wrong_value():
net
=
SSIMNet
(
k2
=
0.0
)
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k2
=-
1.0
)
def
test_ssim_different_shape
():
shape_1
=
(
8
,
3
,
16
,
16
)
shape_2
=
(
8
,
3
,
8
,
8
)
img1
=
Tensor
(
np
.
random
.
random
(
shape_1
))
img2
=
Tensor
(
np
.
random
.
random
(
shape_2
))
net
=
SSIMNet
()
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_ssim_different_dtype
():
dtype_1
=
mstype
.
float32
dtype_2
=
mstype
.
float16
img1
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
16
,
16
)),
dtype
=
dtype_1
)
img2
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
16
,
16
)),
dtype
=
dtype_2
)
net
=
SSIMNet
()
with
pytest
.
raises
(
TypeError
):
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_ssim_invalid_5d_input
():
shape_1
=
(
8
,
3
,
16
,
16
)
shape_2
=
(
8
,
3
,
8
,
8
)
invalid_shape
=
(
8
,
3
,
16
,
16
,
1
)
img1
=
Tensor
(
np
.
random
.
random
(
shape_1
))
invalid_img1
=
Tensor
(
np
.
random
.
random
(
invalid_shape
))
img2
=
Tensor
(
np
.
random
.
random
(
shape_2
))
invalid_img2
=
Tensor
(
np
.
random
.
random
(
invalid_shape
))
net
=
SSIMNet
()
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
invalid_img1
,
img2
)
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
img1
,
invalid_img2
)
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
invalid_img1
,
invalid_img2
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录