Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
aa8fbcc0
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
aa8fbcc0
编写于
4月 15, 2020
作者:
Z
zhaozhenlong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add cell psnr
上级
c0c0b098
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
224 addition
and
18 deletion
+224
-18
mindspore/nn/layer/__init__.py
mindspore/nn/layer/__init__.py
+2
-2
mindspore/nn/layer/image.py
mindspore/nn/layer/image.py
+66
-16
tests/ut/python/nn/test_psnr.py
tests/ut/python/nn/test_psnr.py
+61
-0
tests/ut/python/nn/test_ssim.py
tests/ut/python/nn/test_ssim.py
+95
-0
未找到文件。
mindspore/nn/layer/__init__.py
浏览文件 @
aa8fbcc0
...
...
@@ -25,7 +25,7 @@ from .lstm import LSTM
from
.basic
import
Dropout
,
Flatten
,
Dense
,
ClipByNorm
,
Norm
,
OneHot
,
Pad
,
Unfold
from
.embedding
import
Embedding
from
.pooling
import
AvgPool2d
,
MaxPool2d
from
.image
import
ImageGradients
,
SSIM
from
.image
import
ImageGradients
,
SSIM
,
PSNR
__all__
=
[
'Softmax'
,
'LogSoftmax'
,
'ReLU'
,
'ReLU6'
,
'Tanh'
,
'GELU'
,
'Sigmoid'
,
'PReLU'
,
'get_activation'
,
'LeakyReLU'
,
'HSigmoid'
,
'HSwish'
,
'ELU'
,
...
...
@@ -36,5 +36,5 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'Dropout'
,
'Flatten'
,
'Dense'
,
'ClipByNorm'
,
'Norm'
,
'OneHot'
,
'Embedding'
,
'AvgPool2d'
,
'MaxPool2d'
,
'Pad'
,
'Unfold'
,
'ImageGradients'
,
'SSIM'
,
'ImageGradients'
,
'SSIM'
,
'PSNR'
,
]
mindspore/nn/layer/image.py
浏览文件 @
aa8fbcc0
...
...
@@ -69,6 +69,18 @@ class ImageGradients(Cell):
return
dy
,
dx
def
_convert_img_dtype_to_float32
(
img
,
max_val
):
"""convert img dtype to float32"""
# Ususally max_val is 1.0 or 255, we will do the scaling if max_val > 1.
# We will scale img pixel value if max_val > 1. and just cast otherwise.
ret
=
F
.
cast
(
img
,
mstype
.
float32
)
max_val
=
F
.
scalar_cast
(
max_val
,
mstype
.
float32
)
if
max_val
>
1.
:
scale
=
1.
/
max_val
ret
=
ret
*
scale
return
ret
@
constexpr
def
_gauss_kernel_helper
(
filter_size
):
"""gauss kernel helper"""
...
...
@@ -134,9 +146,9 @@ class SSIM(Cell):
self
.
mean
=
P
.
DepthwiseConv2dNative
(
channel_multiplier
=
1
,
kernel_size
=
filter_size
)
def
construct
(
self
,
img1
,
img2
):
max_val
=
self
.
_convert_img_dtype_to_float32
(
self
.
max_val
,
self
.
max_val
)
img1
=
self
.
_convert_img_dtype_to_float32
(
img1
,
self
.
max_val
)
img2
=
self
.
_convert_img_dtype_to_float32
(
img2
,
self
.
max_val
)
max_val
=
_convert_img_dtype_to_float32
(
self
.
max_val
,
self
.
max_val
)
img1
=
_convert_img_dtype_to_float32
(
img1
,
self
.
max_val
)
img2
=
_convert_img_dtype_to_float32
(
img2
,
self
.
max_val
)
kernel
=
self
.
_fspecial_gauss
(
self
.
filter_size
,
self
.
filter_sigma
)
kernel
=
P
.
Tile
()(
kernel
,
(
1
,
P
.
Shape
()(
img1
)[
1
],
1
,
1
))
...
...
@@ -145,21 +157,10 @@ class SSIM(Cell):
return
mean_ssim
def
_convert_img_dtype_to_float32
(
self
,
img
,
max_val
):
"""convert img dtype to float32"""
# Ususally max_val is 1.0 or 255, we will do the scaling if max_val > 1.
# We will scale img pixel value if max_val > 1. and just cast otherwise.
ret
=
P
.
Cast
()(
img
,
mstype
.
float32
)
max_val
=
F
.
scalar_cast
(
max_val
,
mstype
.
float32
)
if
max_val
>
1.
:
scale
=
1.
/
max_val
ret
=
ret
*
scale
return
ret
def
_calculate_mean_ssim
(
self
,
x
,
y
,
kernel
,
max_val
,
k1
,
k2
):
"""calculate mean ssim"""
c1
=
(
k1
*
max_val
)
*
(
k1
*
max_val
)
c2
=
(
k2
*
max_val
)
*
(
k2
*
max_val
)
c1
=
(
k1
*
max_val
)
*
(
k1
*
max_val
)
c2
=
(
k2
*
max_val
)
*
(
k2
*
max_val
)
# SSIM luminance formula
# (2 * mean_{x} * mean_{y} + c1) / (mean_{x}**2 + mean_{y}**2 + c1)
...
...
@@ -195,3 +196,52 @@ class SSIM(Cell):
g
=
P
.
Softmax
()(
g
)
ret
=
F
.
reshape
(
g
,
(
1
,
1
,
filter_size
,
filter_size
))
return
ret
class
PSNR
(
Cell
):
r
"""
Returns Peak Signal-to-Noise Ratio of two image batches.
It produces a PSNR value for each image in batch.
Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`.
:math:`MAX` represents the dynamic range of pixel values.
.. math::
MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\
PSNR&=10*log_{10}(\frac{MAX^2}{MSE})
Args:
max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
Default: 1.0.
Inputs:
- **img1** (Tensor) - The first image batch with format 'NCHW'. It should be the same shape and dtype as img2.
- **img2** (Tensor) - The second image batch with format 'NCHW'. It should be the same shape and dtype as img1.
Outputs:
Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.
Examples:
>>> net = nn.PSNR()
>>> img1 = Tensor(np.random.random((1,3,16,16)))
>>> img2 = Tensor(np.random.random((1,3,16,16)))
>>> psnr = net(img1, img2)
"""
def
__init__
(
self
,
max_val
=
1.0
):
super
(
PSNR
,
self
).
__init__
()
validator
.
check_type
(
'max_val'
,
max_val
,
[
int
,
float
])
validator
.
check
(
'max_val'
,
max_val
,
''
,
0.0
,
Rel
.
GT
)
self
.
max_val
=
max_val
def
construct
(
self
,
img1
,
img2
):
max_val
=
_convert_img_dtype_to_float32
(
self
.
max_val
,
self
.
max_val
)
img1
=
_convert_img_dtype_to_float32
(
img1
,
self
.
max_val
)
img2
=
_convert_img_dtype_to_float32
(
img2
,
self
.
max_val
)
mse
=
P
.
ReduceMean
()(
F
.
square
(
img1
-
img2
),
(
-
3
,
-
2
,
-
1
))
# 10*log_10(max_val^2/MSE)
psnr
=
10
*
P
.
Log
()(
F
.
square
(
max_val
)
/
mse
)
/
F
.
scalar_log
(
10.0
)
return
psnr
tests/ut/python/nn/test_psnr.py
0 → 100644
浏览文件 @
aa8fbcc0
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test psnr
"""
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore.common.api
import
_executor
from
mindspore
import
Tensor
class
PSNRNet
(
nn
.
Cell
):
def
__init__
(
self
,
max_val
=
1.0
):
super
(
PSNRNet
,
self
).
__init__
()
self
.
net
=
nn
.
PSNR
(
max_val
)
def
construct
(
self
,
img1
,
img2
):
return
self
.
net
(
img1
,
img2
)
def
test_compile_psnr
():
max_val
=
1.0
net
=
PSNRNet
(
max_val
)
img1
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
16
,
16
)))
img2
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
16
,
16
)))
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_compile_psnr_grayscale
():
max_val
=
255
net
=
PSNRNet
(
max_val
)
img1
=
Tensor
(
np
.
random
.
randint
(
0
,
256
,
(
8
,
1
,
16
,
16
),
np
.
uint8
))
img2
=
Tensor
(
np
.
random
.
randint
(
0
,
256
,
(
8
,
1
,
16
,
16
),
np
.
uint8
))
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_psnr_max_val_negative
():
max_val
=
-
1
with
pytest
.
raises
(
ValueError
):
net
=
PSNRNet
(
max_val
)
def
test_psnr_max_val_bool
():
max_val
=
True
with
pytest
.
raises
(
ValueError
):
net
=
PSNRNet
(
max_val
)
def
test_psnr_max_val_zero
():
max_val
=
0
with
pytest
.
raises
(
ValueError
):
net
=
PSNRNet
(
max_val
)
tests/ut/python/nn/test_ssim.py
0 → 100644
浏览文件 @
aa8fbcc0
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test ssim
"""
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore.common.api
import
_executor
from
mindspore
import
Tensor
class
SSIMNet
(
nn
.
Cell
):
def
__init__
(
self
,
max_val
=
1.0
,
filter_size
=
11
,
filter_sigma
=
1.5
,
k1
=
0.01
,
k2
=
0.03
):
super
(
SSIMNet
,
self
).
__init__
()
self
.
net
=
nn
.
SSIM
(
max_val
,
filter_size
,
filter_sigma
,
k1
,
k2
)
def
construct
(
self
,
img1
,
img2
):
return
self
.
net
(
img1
,
img2
)
def
test_compile
():
net
=
SSIMNet
()
img1
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
16
,
16
)))
img2
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
16
,
16
)))
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_compile_grayscale
():
max_val
=
255
net
=
SSIMNet
(
max_val
=
max_val
)
img1
=
Tensor
(
np
.
random
.
randint
(
0
,
256
,
(
8
,
1
,
16
,
16
),
np
.
uint8
))
img2
=
Tensor
(
np
.
random
.
randint
(
0
,
256
,
(
8
,
1
,
16
,
16
),
np
.
uint8
))
_executor
.
compile
(
net
,
img1
,
img2
)
def
test_ssim_max_val_negative
():
max_val
=
-
1
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
max_val
)
def
test_ssim_max_val_bool
():
max_val
=
True
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
max_val
)
def
test_ssim_max_val_zero
():
max_val
=
0
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
max_val
)
def
test_ssim_filter_size_float
():
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
filter_size
=
1.1
)
def
test_ssim_filter_size_zero
():
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
filter_size
=
0
)
def
test_ssim_filter_sigma_zero
():
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
filter_sigma
=
0.0
)
def
test_ssim_filter_sigma_negative
():
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
filter_sigma
=-
0.1
)
def
test_ssim_k1_k2_wrong_value
():
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k1
=
1.1
)
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k1
=
1.0
)
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k1
=
0.0
)
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k1
=-
1.0
)
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k2
=
1.1
)
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k2
=
1.0
)
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k2
=
0.0
)
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k2
=-
1.0
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录