Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5c8acc96
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5c8acc96
编写于
6月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2165 Add image.CentralCrop.
Merge pull request !2165 from liuxiao/central_crop
上级
f10e2974
aa73abc2
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
143 addition
and
1 deletion
+143
-1
mindspore/nn/layer/image.py
mindspore/nn/layer/image.py
+69
-1
tests/ut/python/nn/test_central_crop.py
tests/ut/python/nn/test_central_crop.py
+74
-0
未找到文件。
mindspore/nn/layer/image.py
浏览文件 @
5c8acc96
...
...
@@ -23,7 +23,7 @@ from mindspore._checkparam import Validator as validator
from
mindspore._checkparam
import
Rel
from
..cell
import
Cell
__all__
=
[
'ImageGradients'
,
'SSIM'
,
'PSNR'
]
__all__
=
[
'ImageGradients'
,
'SSIM'
,
'PSNR'
,
'CentralCrop'
]
class
ImageGradients
(
Cell
):
r
"""
...
...
@@ -259,3 +259,71 @@ class PSNR(Cell):
psnr
=
10
*
P
.
Log
()(
F
.
square
(
max_val
)
/
mse
)
/
F
.
scalar_log
(
10.0
)
return
psnr
@
constexpr
def
_check_input_3d_or_4d
(
input_shape
,
param_name
,
func_name
):
"""check input 3d or 4d"""
if
len
(
input_shape
)
!=
3
and
len
(
input_shape
)
!=
4
:
raise
ValueError
(
f
"
{
func_name
}
{
param_name
}
should be 3d or 4d, but got shape
{
input_shape
}
"
)
return
True
@
constexpr
def
_get_bbox
(
rank
,
shape
,
central_fraction
):
"""get bbox start and size for slice"""
if
rank
==
3
:
c
,
h
,
w
=
shape
else
:
n
,
c
,
h
,
w
=
shape
bbox_h_start
=
int
((
float
(
h
)
-
float
(
h
)
*
central_fraction
)
/
2
)
bbox_w_start
=
int
((
float
(
w
)
-
float
(
w
)
*
central_fraction
)
/
2
)
bbox_h_size
=
h
-
bbox_h_start
*
2
bbox_w_size
=
w
-
bbox_w_start
*
2
if
rank
==
3
:
bbox_begin
=
(
0
,
bbox_h_start
,
bbox_w_start
)
bbox_size
=
(
c
,
bbox_h_size
,
bbox_w_size
)
else
:
bbox_begin
=
(
0
,
0
,
bbox_h_start
,
bbox_w_start
)
bbox_size
=
(
n
,
c
,
bbox_h_size
,
bbox_w_size
)
return
bbox_begin
,
bbox_size
class
CentralCrop
(
Cell
):
"""
Crop the centeral region of the images with the central_fraction.
Args:
central_fraction (float): Fraction of size to crop. It must be float and in range (0.0, 1.0].
Inputs:
- **image** (Tensor) - A 3-D tensor of shape [C, H, W], or a 4-D tensor of shape [N, C, H, W].
Outputs:
Tensor, 3-D or 4-D float tensor, according to the input.
Examples:
>>> net = nn.CentralCrop(central_fraction=0.5)
>>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
>>> output = net(image)
"""
def
__init__
(
self
,
central_fraction
):
super
(
CentralCrop
,
self
).
__init__
()
validator
.
check_value_type
(
"central_fraction"
,
central_fraction
,
[
float
],
self
.
cls_name
)
self
.
central_fraction
=
validator
.
check_number_range
(
'central_fraction'
,
central_fraction
,
0.0
,
1.0
,
Rel
.
INC_RIGHT
,
self
.
cls_name
)
self
.
slice
=
P
.
Slice
()
def
construct
(
self
,
image
):
image_shape
=
F
.
shape
(
image
)
rank
=
len
(
image_shape
)
_check_input_3d_or_4d
(
image_shape
,
"image"
,
self
.
cls_name
)
if
self
.
central_fraction
==
1.0
:
return
image
bbox_begin
,
bbox_size
=
_get_bbox
(
rank
,
image_shape
,
self
.
central_fraction
)
image
=
self
.
slice
(
image
,
bbox_begin
,
bbox_size
)
return
image
tests/ut/python/nn/test_central_crop.py
0 → 100644
浏览文件 @
5c8acc96
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test CentralCrop
"""
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common.api
import
_executor
class
CentralCropNet
(
nn
.
Cell
):
def
__init__
(
self
,
central_fraction
):
super
(
CentralCropNet
,
self
).
__init__
()
self
.
net
=
nn
.
CentralCrop
(
central_fraction
)
def
construct
(
self
,
image
):
return
self
.
net
(
image
)
def
test_compile_3d_central_crop
():
central_fraction
=
0.2
net
=
CentralCropNet
(
central_fraction
)
image
=
Tensor
(
np
.
random
.
random
((
3
,
16
,
16
)),
mstype
.
float32
)
_executor
.
compile
(
net
,
image
)
def
test_compile_4d_central_crop
():
central_fraction
=
0.5
net
=
CentralCropNet
(
central_fraction
)
image
=
Tensor
(
np
.
random
.
random
((
8
,
3
,
16
,
16
)),
mstype
.
float32
)
_executor
.
compile
(
net
,
image
)
def
test_central_fraction_bool
():
central_fraction
=
True
with
pytest
.
raises
(
TypeError
):
_
=
CentralCropNet
(
central_fraction
)
def
test_central_crop_central_fraction_negative
():
central_fraction
=
-
1.0
with
pytest
.
raises
(
ValueError
):
_
=
CentralCropNet
(
central_fraction
)
def
test_central_fraction_zero
():
central_fraction
=
0.0
with
pytest
.
raises
(
ValueError
):
_
=
CentralCropNet
(
central_fraction
)
def
test_central_crop_invalid_5d_input
():
invalid_shape
=
(
8
,
3
,
16
,
16
,
1
)
invalid_image
=
Tensor
(
np
.
random
.
random
(
invalid_shape
))
net
=
CentralCropNet
(
central_fraction
=
0.5
)
with
pytest
.
raises
(
ValueError
):
_executor
.
compile
(
net
,
invalid_image
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录