Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ea681bfb
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ea681bfb
编写于
6月 16, 2020
作者:
Z
zhaozhenlong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix ssim filter size check
上级
f48ba776
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
11 addition
and
2 deletion
+11
-2
mindspore/nn/layer/image.py
mindspore/nn/layer/image.py
+7
-2
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+4
-0
未找到文件。
mindspore/nn/layer/image.py
浏览文件 @
ea681bfb
...
...
@@ -104,6 +104,12 @@ def _check_input_4d(input_shape, param_name, func_name):
raise
ValueError
(
f
"
{
func_name
}
{
param_name
}
should be 4d, but got shape
{
input_shape
}
"
)
return
True
@
constexpr
def
_check_input_filter_size
(
input_shape
,
param_name
,
filter_size
,
func_name
):
_check_input_4d
(
input_shape
,
param_name
,
func_name
)
validator
.
check
(
param_name
+
" shape[2]"
,
input_shape
[
2
],
"filter_size"
,
filter_size
,
Rel
.
GE
,
func_name
)
validator
.
check
(
param_name
+
" shape[3]"
,
input_shape
[
3
],
"filter_size"
,
filter_size
,
Rel
.
GE
,
func_name
)
class
SSIM
(
Cell
):
r
"""
Returns SSIM index between img1 and img2.
...
...
@@ -154,8 +160,7 @@ class SSIM(Cell):
self
.
mean
=
P
.
DepthwiseConv2dNative
(
channel_multiplier
=
1
,
kernel_size
=
filter_size
)
def
construct
(
self
,
img1
,
img2
):
_check_input_4d
(
F
.
shape
(
img1
),
"img1"
,
self
.
cls_name
)
_check_input_4d
(
F
.
shape
(
img2
),
"img2"
,
self
.
cls_name
)
_check_input_filter_size
(
F
.
shape
(
img1
),
"img1"
,
self
.
filter_size
,
self
.
cls_name
)
P
.
SameTypeShape
()(
img1
,
img2
)
max_val
=
_convert_img_dtype_to_float32
(
self
.
max_val
,
self
.
max_val
)
img1
=
_convert_img_dtype_to_float32
(
img1
,
self
.
max_val
)
...
...
tests/ut/python/ops/test_ops.py
浏览文件 @
ea681bfb
...
...
@@ -1754,6 +1754,10 @@ raise_set = [
'block'
:
(
P
.
PReLU
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[[
2
],
[
1
]],
'desc_bprop'
:
[[
1
]]}),
(
'SSIM'
,
{
'block'
:
(
nn
.
SSIM
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
((
1
,
3
,
8
,
8
)),
mstype
.
float32
),
Tensor
(
np
.
ones
((
1
,
3
,
8
,
8
)),
mstype
.
float32
)]}),
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录