Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6721541c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6721541c
编写于
4月 15, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 15, 2020
浏览文件
操作
浏览文件
下载
差异文件
!25 Develop Cell unfold,and Op ExtractImagePatches.
Merge pull request !25 from zhangbuxue/unfold-develop
上级
9c9c7091
62807da0
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
310 addition
and
42 deletion
+310
-42
mindspore/ccsrc/kernel/tbe/tbe_adapter.h
mindspore/ccsrc/kernel/tbe/tbe_adapter.h
+0
-6
mindspore/ccsrc/transform/convert.cc
mindspore/ccsrc/transform/convert.cc
+2
-0
mindspore/ccsrc/transform/op_adapter.h
mindspore/ccsrc/transform/op_adapter.h
+12
-14
mindspore/ccsrc/transform/op_declare.cc
mindspore/ccsrc/transform/op_declare.cc
+9
-5
mindspore/ccsrc/transform/op_declare.h
mindspore/ccsrc/transform/op_declare.h
+2
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-0
mindspore/nn/layer/__init__.py
mindspore/nn/layer/__init__.py
+2
-2
mindspore/nn/layer/basic.py
mindspore/nn/layer/basic.py
+48
-0
mindspore/ops/_grad/grad_nn_ops.py
mindspore/ops/_grad/grad_nn_ops.py
+56
-1
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+2
-1
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+1
-1
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+96
-4
tests/ut/python/ops/test_math_ops.py
tests/ut/python/ops/test_math_ops.py
+24
-8
tests/ut/python/ops/test_nn_ops.py
tests/ut/python/ops/test_nn_ops.py
+55
-0
未找到文件。
mindspore/ccsrc/kernel/tbe/tbe_adapter.h
浏览文件 @
6721541c
...
...
@@ -45,12 +45,6 @@ class TbeAdapter {
std
::
vector
<
nlohmann
::
json
>
*
input_list
,
kCreaterType
creater_type
);
private:
static
void
MaxPoolWithArgmaxAttrJsonPass
(
const
AnfNodePtr
&
anf_node
,
const
std
::
vector
<
std
::
shared_ptr
<
OpAttr
>>
&
op_info_attrs
,
nlohmann
::
json
*
attrs_json
);
static
void
MaxPoolGradWithArgmaxAttrJsonPass
(
const
AnfNodePtr
&
anf_node
,
const
std
::
vector
<
std
::
shared_ptr
<
OpAttr
>>
&
op_info_attrs
,
nlohmann
::
json
*
attrs_json
);
static
void
Conv2DAttrJsonPass
(
const
AnfNodePtr
&
anf_node
,
const
std
::
vector
<
std
::
shared_ptr
<
OpAttr
>>
&
op_info_attrs
,
nlohmann
::
json
*
attrs_json
);
static
void
Conv2DBackpropFilterAttrJsonPass
(
const
AnfNodePtr
&
anf_node
,
...
...
mindspore/ccsrc/transform/convert.cc
浏览文件 @
6721541c
...
...
@@ -96,6 +96,7 @@ const char kNameConfusionMatrix[] = "ConfusionMatrix";
const
char
kNameResizeNearestNeighborD
[]
=
"ResizeNearestNeighbor"
;
const
char
kNameResizeNearestNeighborGrad
[]
=
"ResizeNearestNeighborGrad"
;
const
char
kNameApplyAdam
[]
=
"Adam"
;
const
char
kNameExtractImagePatches
[]
=
"ExtractImagePatches"
;
const
char
kNameReLU6
[]
=
"ReLU6"
;
const
char
kNameReLU6Grad
[]
=
"ReLU6Grad"
;
const
char
kNameElu
[]
=
"Elu"
;
...
...
@@ -214,6 +215,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{
string
(
kNameMaxPoolGrad
),
ADPT_DESC
(
MaxPoolGrad
)},
{
string
(
kNameAvgPoolGrad
),
ADPT_DESC
(
AvgPoolGrad
)},
{
string
(
kNameMaxPoolGradWithArgmax
),
ADPT_DESC
(
MaxPoolGradWithArgmax
)},
{
string
(
kNameExtractImagePatches
),
ADPT_DESC
(
ExtractImagePatches
)},
{
prim
::
kPrimAssign
->
name
(),
ADPT_DESC
(
Assign
)},
{
prim
::
kPrimStateSetItem
->
name
(),
ADPT_DESC
(
Assign
)},
{
prim
::
kPrimReluGrad
->
name
(),
ADPT_DESC
(
ReluGrad
)},
...
...
mindspore/ccsrc/transform/op_adapter.h
浏览文件 @
6721541c
...
...
@@ -322,18 +322,12 @@ class OpAdapter : public BaseOpAdapter {
Status
UpdateSingleOutputDesc
(
const
OperatorPtr
&
op
,
const
abstract
::
BaseShapePtr
&
shp
,
const
TypePtr
&
type
)
{
MS_EXCEPTION_IF_NULL
(
type
);
TypeId
me_type
=
type
->
type_id
();
if
(
kObjectTypeTensorType
==
me_type
)
{
me_type
=
dyn_cast
<
TensorType
>
(
type
)
->
element
()
->
type_id
();
}
std
::
vector
<
int
>
shape
;
auto
normal_shape_ptr
=
dyn_cast
<
abstract
::
Shape
>
(
shp
);
if
(
nullptr
!=
normal_shape_ptr
)
{
shape
=
normal_shape_ptr
->
shape
();
std
::
string
format
=
"NCHW"
;
if
(
op
->
GetOpType
()
==
kExtractImagePatchesOpName
)
{
format
=
"NHWC"
;
}
auto
desc
=
TransformUtil
::
GetGeTensorDesc
(
shape
,
me_type
,
"NCHW"
);
auto
desc
=
CreateOutputDesc
(
dyn_cast
<
abstract
::
Shape
>
(
shp
),
type
,
format
);
if
(
desc
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Update output descriptor failed!"
;
return
FAILED
;
...
...
@@ -410,14 +404,15 @@ class OpAdapter : public BaseOpAdapter {
MS_LOG
(
ERROR
)
<<
"output_map is not equal tuple_shape size"
;
return
FAILED
;
}
std
::
string
format
=
"NCHW"
;
if
(
op
->
GetOpType
()
==
kTopKOpName
)
{
format
=
"NHWC"
;
}
for
(
size_t
i
=
0
;
i
<
tuple_shp
->
shape
().
size
();
++
i
)
{
auto
tuple_type
=
dyn_cast
<
Tuple
>
(
type
);
MS_EXCEPTION_IF_NULL
(
tuple_type
);
TypePtr
type_elem
=
tuple_type
->
elements
()[
i
];
std
::
string
format
=
"NCHW"
;
if
(
op
->
GetOpType
()
==
kTopKOpName
)
{
format
=
"NHWC"
;
}
auto
desc
=
CreateOutputDesc
(
dyn_cast
<
abstract
::
Shape
>
(
tuple_shp
->
shape
()[
i
]),
type_elem
,
format
);
if
(
desc
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create output descriptor failed!"
;
...
...
@@ -476,6 +471,9 @@ class OpAdapter : public BaseOpAdapter {
if
(
desc
==
nullptr
)
{
continue
;
}
if
(
op
->
GetOpType
()
==
kExtractImagePatchesOpName
)
{
desc
->
SetFormat
(
ge
::
Format
::
FORMAT_NHWC
);
}
it
->
second
.
update_input_desc
(
op
,
*
desc
);
}
}
...
...
mindspore/ccsrc/transform/op_declare.cc
浏览文件 @
6721541c
...
...
@@ -751,16 +751,20 @@ ATTR_MAP(MaxPoolWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyT
OUTPUT_MAP
(
MaxPoolWithArgmax
)
=
{{
0
,
OUTPUT_DESC
(
y
)},
{
1
,
OUTPUT_DESC
(
argmax
)}};
// MaxPoolGradWithArgmax
INPUT_MAP
(
MaxPoolGradWithArgmax
)
=
{
{
1
,
INPUT_DESC
(
x
)},
{
2
,
INPUT_DESC
(
grad
)},
{
3
,
INPUT_DESC
(
argmax
)},
};
INPUT_MAP
(
MaxPoolGradWithArgmax
)
=
{{
1
,
INPUT_DESC
(
x
)},
{
2
,
INPUT_DESC
(
grad
)},
{
3
,
INPUT_DESC
(
argmax
)}};
ATTR_MAP
(
MaxPoolGradWithArgmax
)
=
{{
"ksize"
,
ATTR_DESC
(
ksize
,
AnyTraits
<
int
>
(),
AnyTraits
<
std
::
vector
<
int64_t
>>
())},
{
"strides"
,
ATTR_DESC
(
strides
,
AnyTraits
<
int
>
(),
AnyTraits
<
std
::
vector
<
int64_t
>>
())},
{
"padding"
,
ATTR_DESC
(
padding
,
AnyTraits
<
std
::
string
>
())}};
OUTPUT_MAP
(
MaxPoolGradWithArgmax
)
=
{{
0
,
OUTPUT_DESC
(
y
)}};
// ExtractImagePatches
INPUT_MAP
(
ExtractImagePatches
)
=
{{
1
,
INPUT_DESC
(
images
)}};
ATTR_MAP
(
ExtractImagePatches
)
=
{{
"ksizes"
,
ATTR_DESC
(
ksizes
,
AnyTraits
<
int
>
(),
AnyTraits
<
std
::
vector
<
int64_t
>>
())},
{
"strides"
,
ATTR_DESC
(
strides
,
AnyTraits
<
int
>
(),
AnyTraits
<
std
::
vector
<
int64_t
>>
())},
{
"rates"
,
ATTR_DESC
(
rates
,
AnyTraits
<
int
>
(),
AnyTraits
<
std
::
vector
<
int64_t
>>
())},
{
"padding"
,
ATTR_DESC
(
padding
,
AnyTraits
<
std
::
string
>
())}};
OUTPUT_MAP
(
ExtractImagePatches
)
=
{{
0
,
OUTPUT_DESC
(
y
)}};
// Conv2D
INPUT_MAP
(
Conv2D
)
=
{{
1
,
INPUT_DESC
(
x
)},
{
2
,
INPUT_DESC
(
filter
)}};
ATTR_MAP
(
Conv2D
)
=
{
...
...
mindspore/ccsrc/transform/op_declare.h
浏览文件 @
6721541c
...
...
@@ -95,6 +95,8 @@ DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax)
DECLARE_OP_ADAPTER
(
Conv2D
)
DECLARE_OP_USE_ENUM
(
Conv2D
)
DECLARE_OP_USE_OUTPUT
(
Conv2D
)
DECLARE_OP_ADAPTER
(
ExtractImagePatches
)
DECLARE_OP_USE_OUTPUT
(
ExtractImagePatches
)
DECLARE_OP_ADAPTER
(
Conv2DBackpropInputD
)
DECLARE_OP_USE_ENUM
(
Conv2DBackpropInputD
)
DECLARE_OP_USE_INPUT_ATTR
(
Conv2DBackpropInputD
)
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
6721541c
...
...
@@ -49,6 +49,7 @@ constexpr auto kBroadcastOpName = "Broadcast";
constexpr
auto
kReduceScatterOpName
=
"ReduceScatter"
;
constexpr
auto
kMemCpyAsyncOpName
=
"memcpy_async"
;
constexpr
auto
kTopKOpName
=
"TopK"
;
constexpr
auto
kExtractImagePatchesOpName
=
"ExtractImagePatches"
;
constexpr
auto
kBNTrainingReduceOpName
=
"BNTrainingReduce"
;
constexpr
auto
kBNTrainingUpdateOpName
=
"BNTrainingUpdate"
;
constexpr
auto
kSimpleMeanGradOpName
=
"SimpleMeanGrad"
;
...
...
mindspore/nn/layer/__init__.py
浏览文件 @
6721541c
...
...
@@ -22,7 +22,7 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
from
.container
import
SequentialCell
,
CellList
from
.conv
import
Conv2d
,
Conv2dTranspose
from
.lstm
import
LSTM
from
.basic
import
Dropout
,
Flatten
,
Dense
,
ClipByNorm
,
Norm
,
OneHot
,
Pad
from
.basic
import
Dropout
,
Flatten
,
Dense
,
ClipByNorm
,
Norm
,
OneHot
,
Pad
,
Unfold
from
.embedding
import
Embedding
from
.pooling
import
AvgPool2d
,
MaxPool2d
from
.image
import
ImageGradients
,
SSIM
...
...
@@ -35,6 +35,6 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'LSTM'
,
'Dropout'
,
'Flatten'
,
'Dense'
,
'ClipByNorm'
,
'Norm'
,
'OneHot'
,
'Embedding'
,
'AvgPool2d'
,
'MaxPool2d'
,
'Pad'
,
'AvgPool2d'
,
'MaxPool2d'
,
'Pad'
,
'Unfold'
,
'ImageGradients'
,
'SSIM'
,
]
mindspore/nn/layer/basic.py
浏览文件 @
6721541c
...
...
@@ -439,3 +439,51 @@ class Pad(Cell):
else
:
x
=
self
.
pad
(
x
,
self
.
paddings
)
return
x
class
Unfold
(
Cell
):
"""
Extract patches from images.
The input tensor must be a 4-D tensor and the data format is NCHW.
Args:
ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
and the format is [1, ksize_row, ksize_col, 1].
strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
not case sensitive. Default: "valid".
- same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
- valid: Means that the patch area taken must be completely contained in the original image.
Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
data type is int8, float16, uint8.
Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',
and the shape is [out_batch, out_depth, out_row, out_col], the out_batch is same as the in_batch.
Examples:
>>> net = Unfold(ksizes=[1, 2, 2, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1])
>>> image = Tensor(np.ones([1, 1, 3, 3]), dtype=mstype.float16)
>>> net(image)
Tensor ([[[[1, 1] [1, 1]] [[1, 1], [1, 1]] [[1, 1] [1, 1]], [[1, 1], [1, 1]]]],
shape=(1, 4, 2, 2), dtype=mstype.float16)
"""
def
__init__
(
self
,
ksizes
,
strides
,
rates
,
padding
=
"valid"
):
super
(
Unfold
,
self
).
__init__
()
self
.
extract_image_patches
=
P
.
ExtractImagePatches
(
ksizes
,
strides
,
rates
,
padding
)
self
.
transpose
=
P
.
Transpose
()
self
.
format_NHWC
=
(
0
,
2
,
3
,
1
)
self
.
format_NCHW
=
(
0
,
3
,
1
,
2
)
def
construct
(
self
,
input_x
):
x_transpose
=
self
.
transpose
(
input_x
,
self
.
format_NHWC
)
ret
=
self
.
extract_image_patches
(
x_transpose
)
ret_transpose
=
self
.
transpose
(
ret
,
self
.
format_NCHW
)
return
ret_transpose
mindspore/ops/_grad/grad_nn_ops.py
浏览文件 @
6721541c
...
...
@@ -14,7 +14,7 @@
# ============================================================================
"""Define the grad rules of neural network related operations."""
from
mindspore.common
import
dtype
as
mstype
from
..
import
functional
as
F
from
..
import
operations
as
P
from
..operations
import
_grad_ops
as
G
...
...
@@ -52,6 +52,61 @@ def get_bprop_conv2d(self):
return
bprop
@
bprop_getters
.
register
(
P
.
ExtractImagePatches
)
def
get_bprop_extract_image_patches
(
self
):
"""Grad definition for `ExtractImagePatches` operation."""
get_shape
=
P
.
Shape
()
reshape
=
P
.
Reshape
()
extract_image_patches
=
P
.
ExtractImagePatches
(
ksizes
=
self
.
ksizes
,
strides
=
self
.
strides
,
rates
=
self
.
rates
,
padding
=
self
.
padding
)
concat
=
P
.
Concat
(
axis
=-
1
)
expand_dims
=
P
.
ExpandDims
()
scatter_nd
=
P
.
ScatterNd
()
dtype
=
P
.
DType
()
fill
=
P
.
Fill
()
slice_op
=
P
.
Slice
()
transpose
=
P
.
Transpose
()
matmul
=
P
.
MatMul
()
cast
=
P
.
Cast
()
_
,
ksizes_row
,
ksizes_col
,
_
=
self
.
ksizes
def
bprop
(
x
,
out
,
dout
):
x_shape
=
get_shape
(
x
)
x_batch
,
x_row
,
x_col
,
x_depth
=
x_shape
x_indices_num
=
x_row
*
x_col
+
1
x_idx
=
F
.
tuple_to_array
(
range
(
1
,
x_indices_num
))
x_idx
=
reshape
(
x_idx
,
(
1
,
x_row
,
x_col
,
1
))
x_idx
=
cast
(
x_idx
,
mstype
.
float16
)
x_idx_patch
=
extract_image_patches
(
x_idx
)
x_idx_patch
=
transpose
(
x_idx_patch
,
(
0
,
3
,
1
,
2
))
x_idx_patch
=
cast
(
x_idx_patch
,
mstype
.
int32
)
out_shape
=
get_shape
(
out
)
_
,
out_row
,
out_col
,
_
=
out_shape
out_indices_num
=
out_row
*
out_col
*
ksizes_row
*
ksizes_col
out_idx
=
F
.
tuple_to_array
(
range
(
out_indices_num
))
out_idx
=
reshape
(
out_idx
,
(
1
,
ksizes_row
*
ksizes_col
,
out_row
,
out_col
))
idx_tensor
=
concat
((
expand_dims
(
x_idx_patch
,
-
1
),
expand_dims
(
out_idx
,
-
1
)))
idx_tensor
=
reshape
(
idx_tensor
,
(
-
1
,
2
))
sp_shape
=
(
x_indices_num
,
out_indices_num
)
sp_tensor
=
scatter_nd
(
idx_tensor
,
fill
(
dtype
(
dout
),
(
out_indices_num
,),
1
),
sp_shape
)
sp_tensor
=
slice_op
(
sp_tensor
,
(
1
,
0
),
(
x_indices_num
-
1
,
out_indices_num
))
grad
=
reshape
(
dout
,
(
x_batch
,
out_row
,
out_col
,
ksizes_row
,
ksizes_col
,
x_depth
))
grad
=
transpose
(
grad
,
(
1
,
2
,
3
,
4
,
0
,
5
))
grad
=
reshape
(
grad
,
(
-
1
,
x_batch
*
x_depth
))
jac
=
matmul
(
sp_tensor
,
grad
)
dx
=
reshape
(
jac
,
(
x_row
,
x_col
,
x_batch
,
x_depth
))
dx
=
transpose
(
dx
,
(
2
,
0
,
1
,
3
))
return
(
dx
,)
return
bprop
@
bprop_getters
.
register
(
P
.
DepthwiseConv2dNative
)
def
get_bprop_depthwise_conv2d_native
(
self
):
"""Grad definition for `DepthwiseConv2dNative` operation."""
...
...
mindspore/ops/operations/__init__.py
浏览文件 @
6721541c
...
...
@@ -57,7 +57,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
Gelu
,
Elu
,
GetNext
,
L2Normalize
,
LayerNorm
,
LogSoftmax
,
MaxPool
,
MaxPool
,
ExtractImagePatches
,
AvgPool
,
Conv2DBackpropInput
,
MaxPoolWithArgmax
,
OneHot
,
Pad
,
MirrorPad
,
PReLU
,
ReLU
,
ReLU6
,
HSwish
,
HSigmoid
,
ResizeBilinear
,
Sigmoid
,
...
...
@@ -89,6 +89,7 @@ __all__ = [
'Sqrt'
,
'Square'
,
'Conv2D'
,
'ExtractImagePatches'
,
'Flatten'
,
'MaxPoolWithArgmax'
,
'FusedBatchNorm'
,
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
6721541c
...
...
@@ -1475,7 +1475,7 @@ class LogicalNot(PrimitiveWithInfer):
Computes the "logical NOT" of a tensor element-wise.
Inputs:
- **input_x** (Tensor) - The input tensor whose dtype is bool
- **input_x** (Tensor) - The input tensor whose dtype is bool
.
Outputs:
Tensor, the shape is same as the `input_x`, and the dtype is bool.
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
6721541c
...
...
@@ -2550,6 +2550,7 @@ class ApplyFtrl(PrimitiveWithInfer):
Outputs:
Tensor, representing the updated var.
"""
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
False
):
self
.
init_prim_io_names
(
inputs
=
[
'var'
,
'accum'
,
'linear'
,
'grad'
,
'lr'
,
'l1'
,
'l2'
,
'lr_power'
],
...
...
@@ -2570,8 +2571,99 @@ class ApplyFtrl(PrimitiveWithInfer):
args
=
{
'var_type'
:
var_type
,
'accum_type'
:
accum_type
,
'linear_type'
:
linear_type
,
'grad_type'
:
grad_type
}
validator
.
check_type_same
(
args
,
(
mstype
.
float32
,
mstype
.
float16
))
validator
.
check_typename
(
"lr"
,
lr_type
,[
mstype
.
float16
,
mstype
.
float32
])
validator
.
check_typename
(
"l1"
,
l1_type
,[
mstype
.
float16
,
mstype
.
float32
])
validator
.
check_typename
(
"l2"
,
l2_type
,[
mstype
.
float16
,
mstype
.
float32
])
validator
.
check_typename
(
"lr_power"
,
lr_power_type
,[
mstype
.
float16
,
mstype
.
float32
])
validator
.
check_typename
(
"lr"
,
lr_type
,
[
mstype
.
float16
,
mstype
.
float32
])
validator
.
check_typename
(
"l1"
,
l1_type
,
[
mstype
.
float16
,
mstype
.
float32
])
validator
.
check_typename
(
"l2"
,
l2_type
,
[
mstype
.
float16
,
mstype
.
float32
])
validator
.
check_typename
(
"lr_power"
,
lr_power_type
,
[
mstype
.
float16
,
mstype
.
float32
])
return
var_type
class
ExtractImagePatches
(
PrimitiveWithInfer
):
"""
Extract patches from images.
The input tensor must be a 4-D tensor and the data format is NHWC.
Args:
ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
and the format is [1, ksize_row, ksize_col, 1].
strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
not case sensitive. Default: "valid".
- same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
- valid: Means that the patch area taken must be completely contained in the original image.
Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
data type is int8, float16, uint8.
Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',
and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
"""
@
prim_attr_register
def
__init__
(
self
,
ksizes
,
strides
,
rates
,
padding
=
"valid"
):
"""init"""
validator
.
check_type
(
"ksizes"
,
ksizes
,
[
tuple
,
list
])
validator
.
check_type
(
"strides"
,
strides
,
[
tuple
,
list
])
validator
.
check_type
(
"rates"
,
rates
,
[
tuple
,
list
])
self
.
padding
=
validator
.
check_string
(
'padding'
,
padding
.
upper
(),
[
'VALID'
,
'SAME'
])
self
.
add_prim_attr
(
"padding"
,
self
.
padding
)
if
len
(
ksizes
)
!=
4
or
ksizes
[
0
]
!=
1
or
ksizes
[
3
]
!=
1
:
raise
ValueError
(
"The format of ksizes should be [1, ksize_row, ksize_col, 1], "
f
"but got
{
ksizes
}
."
)
if
not
isinstance
(
ksizes
[
1
],
int
)
or
not
isinstance
(
ksizes
[
2
],
int
)
or
\
ksizes
[
1
]
<
1
or
ksizes
[
2
]
<
1
:
raise
ValueError
(
"The ksize_row and ksize_col in ksizes should be an positive integer number, "
f
"but got ksize_row is
{
ksizes
[
1
]
}
, ksize_col is
{
ksizes
[
2
]
}
"
)
if
len
(
strides
)
!=
4
or
strides
[
0
]
!=
1
or
strides
[
3
]
!=
1
:
raise
ValueError
(
"The format of strides should be [1, stride_row, stride_col, 1], "
f
"but got
{
strides
}
."
)
if
not
isinstance
(
strides
[
1
],
int
)
or
not
isinstance
(
strides
[
2
],
int
)
or
\
strides
[
1
]
<
1
or
strides
[
2
]
<
1
:
raise
ValueError
(
"The stride_row and stride_col in strides should be an positive integer number, "
f
"but got stride_row is
{
strides
[
1
]
}
, stride_col is
{
strides
[
2
]
}
"
)
if
len
(
rates
)
!=
4
or
rates
[
0
]
!=
1
or
rates
[
3
]
!=
1
:
raise
ValueError
(
"The format of rates should be [1, rate_row, rate_col, 1], "
f
"but got
{
rates
}
."
)
if
not
isinstance
(
rates
[
1
],
int
)
or
not
isinstance
(
rates
[
2
],
int
)
or
\
rates
[
1
]
<
1
or
rates
[
2
]
<
1
:
raise
ValueError
(
"The rate_row and rate_col in rates should be an positive integer number, "
f
"but got rate_row is
{
rates
[
1
]
}
, rate_col is
{
rates
[
2
]
}
"
)
def
infer_shape
(
self
,
input_x
):
in_batch
,
in_row
,
in_col
,
in_depth
=
input_x
_
,
ksize_row
,
ksize_col
,
_
=
self
.
ksizes
_
,
stride_row
,
stride_col
,
_
=
self
.
strides
_
,
rate_row
,
rate_col
,
_
=
self
.
rates
if
len
(
input_x
)
!=
4
:
raise
ValueError
(
"The `input_x` should be a 4-D tensor, "
f
"but got a
{
len
(
input_x
)
}
-D tensor whose shape is
{
input_x
}
"
)
out_batch
=
in_batch
out_depth
=
ksize_row
*
ksize_col
*
in_depth
if
self
.
padding
==
"VALID"
:
out_row
=
\
(
in_row
-
(
ksize_row
+
(
ksize_row
-
1
)
*
(
rate_row
-
1
)))
//
stride_row
+
1
out_col
=
\
(
in_col
-
(
ksize_col
+
(
ksize_col
-
1
)
*
(
rate_col
-
1
)))
//
stride_col
+
1
else
:
out_row
=
(
in_row
-
1
)
//
stride_row
+
1
out_col
=
(
in_col
-
1
)
//
stride_col
+
1
out_shape
=
[
out_batch
,
out_row
,
out_col
,
out_depth
]
return
out_shape
def
infer_dtype
(
self
,
input_x
):
validator
.
check_subclass
(
"input_x"
,
input_x
,
mstype
.
tensor
)
validator
.
check_typename
(
"input_x_dtype"
,
input_x
,
(
mstype
.
int8
,
mstype
.
float16
,
mstype
.
float32
))
return
input_x
tests/ut/python/ops/test_math_ops.py
浏览文件 @
6721541c
...
...
@@ -30,6 +30,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from
....mindspore_test_framework.pipeline.forward.verify_exception
\
import
pipeline_for_verify_exception_for_case_by_case_config
# pylint: disable=W0613
# pylint: disable=W0231
# W0613: unused-argument
...
...
@@ -106,7 +108,7 @@ def test_realdiv():
result
=
div
(
x
,
y
)
x
=
x
.
asnumpy
()
y
=
y
.
asnumpy
()
expect
=
x
/
y
expect
=
x
/
y
assert
np
.
all
(
result
.
asnumpy
()
==
expect
)
...
...
@@ -122,6 +124,7 @@ def test_eye():
class
VirtualLossGrad
(
PrimitiveWithInfer
):
""" VirtualLossGrad definition """
@
prim_attr_register
def
__init__
(
self
):
"""init VirtualLossGrad"""
...
...
@@ -138,6 +141,7 @@ class VirtualLossGrad(PrimitiveWithInfer):
class
VirtualLoss
(
PrimitiveWithInfer
):
""" VirtualLoss definition """
@
prim_attr_register
def
__init__
(
self
):
"""init VirtualLoss"""
...
...
@@ -151,6 +155,7 @@ class VirtualLoss(PrimitiveWithInfer):
def
bprop
(
x
,
out
,
dout
):
dx
=
loss_grad
(
x
,
out
,
dout
)
return
(
dx
,)
return
bprop
def
infer_shape
(
self
,
x_shape
):
...
...
@@ -162,6 +167,7 @@ class VirtualLoss(PrimitiveWithInfer):
class
NetWithLoss
(
nn
.
Cell
):
""" NetWithLoss definition """
def
__init__
(
self
,
network
):
super
(
NetWithLoss
,
self
).
__init__
()
self
.
loss
=
VirtualLoss
()
...
...
@@ -174,6 +180,7 @@ class NetWithLoss(nn.Cell):
class
GradWrap
(
nn
.
Cell
):
""" GradWrap definition """
def
__init__
(
self
,
network
):
super
(
GradWrap
,
self
).
__init__
()
self
.
network
=
network
...
...
@@ -184,6 +191,7 @@ class GradWrap(nn.Cell):
class
MatMulNet
(
nn
.
Cell
):
""" MatMulNet definition """
def
__init__
(
self
):
super
(
MatMulNet
,
self
).
__init__
()
self
.
matmul
=
P
.
MatMul
()
...
...
@@ -195,6 +203,7 @@ class MatMulNet(nn.Cell):
class
NetWithLossSub
(
nn
.
Cell
):
""" NetWithLossSub definition """
def
__init__
(
self
,
network
):
super
(
NetWithLossSub
,
self
).
__init__
()
self
.
loss
=
VirtualLoss
()
...
...
@@ -207,6 +216,7 @@ class NetWithLossSub(nn.Cell):
class
GradWrapSub
(
nn
.
Cell
):
""" GradWrapSub definition """
def
__init__
(
self
,
network
):
super
(
GradWrapSub
,
self
).
__init__
()
self
.
network
=
network
...
...
@@ -217,6 +227,7 @@ class GradWrapSub(nn.Cell):
class
SubNet
(
nn
.
Cell
):
""" SubNet definition """
def
__init__
(
self
):
super
(
SubNet
,
self
).
__init__
()
self
.
sub
=
P
.
Sub
()
...
...
@@ -227,6 +238,7 @@ class SubNet(nn.Cell):
class
NpuFloatNet
(
nn
.
Cell
):
""" NpuFloat definition """
def
__init__
(
self
):
super
(
NpuFloatNet
,
self
).
__init__
()
self
.
mul
=
P
.
Mul
()
...
...
@@ -258,6 +270,7 @@ class NpuFloatNet(nn.Cell):
class
DiagNet
(
nn
.
Cell
):
""" DiagNet definition """
def
__init__
(
self
):
super
(
DiagNet
,
self
).
__init__
()
self
.
fill
=
P
.
Fill
()
...
...
@@ -269,6 +282,7 @@ class DiagNet(nn.Cell):
class
NetWithLossCumSum
(
nn
.
Cell
):
""" NetWithLossCumSum definition """
def
__init__
(
self
,
network
):
super
(
NetWithLossCumSum
,
self
).
__init__
()
self
.
loss
=
VirtualLoss
()
...
...
@@ -281,6 +295,7 @@ class NetWithLossCumSum(nn.Cell):
class
GradWrapCumSum
(
nn
.
Cell
):
""" GradWrap definition """
def
__init__
(
self
,
network
):
super
(
GradWrapCumSum
,
self
).
__init__
()
self
.
network
=
network
...
...
@@ -291,6 +306,7 @@ class GradWrapCumSum(nn.Cell):
class
NetCumSum
(
nn
.
Cell
):
""" NetCumSum definition """
def
__init__
(
self
):
super
(
NetCumSum
,
self
).
__init__
()
self
.
cumsum
=
P
.
CumSum
()
...
...
@@ -321,8 +337,8 @@ test_case_math_ops = [
'skip'
:
[
'backward'
]}),
(
'CumSumGrad'
,
{
'block'
:
GradWrapCumSum
(
NetWithLossCumSum
(
NetCumSum
())),
'desc_inputs'
:
[
Tensor
(
np
.
array
([[
3
,
4
,
6
,
10
],
[
1
,
6
,
7
,
9
],[
4
,
3
,
8
,
7
],
[
1
,
3
,
7
,
9
]]).
astype
(
np
.
float16
))],
'desc_bprop'
:
[
Tensor
(
np
.
array
([[
3
,
4
,
6
,
10
],
[
1
,
6
,
7
,
9
],[
4
,
3
,
8
,
7
],
[
1
,
3
,
7
,
9
]]).
astype
(
np
.
float16
))],
'desc_inputs'
:
[
Tensor
(
np
.
array
([[
3
,
4
,
6
,
10
],
[
1
,
6
,
7
,
9
],
[
4
,
3
,
8
,
7
],
[
1
,
3
,
7
,
9
]]).
astype
(
np
.
float16
))],
'desc_bprop'
:
[
Tensor
(
np
.
array
([[
3
,
4
,
6
,
10
],
[
1
,
6
,
7
,
9
],
[
4
,
3
,
8
,
7
],
[
1
,
3
,
7
,
9
]]).
astype
(
np
.
float16
))],
'skip'
:
[
'backward'
]}),
(
'Diag'
,
{
'block'
:
DiagNet
(),
...
...
@@ -351,7 +367,6 @@ test_case_math_ops = [
'skip'
:
[
'backward'
]}),
]
test_case_lists
=
[
test_case_math_ops
]
test_exec_case
=
functools
.
reduce
(
lambda
x
,
y
:
x
+
y
,
test_case_lists
)
# use -k to select certain testcast
...
...
@@ -360,6 +375,7 @@ test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
import
mindspore.context
as
context
@
non_graph_engine
@
mindspore_test
(
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
)
def
test_exec
():
...
...
@@ -369,16 +385,16 @@ def test_exec():
raise_set
=
[
(
'StridedSlice_1_Error'
,
{
'block'
:
(
lambda
x
:
P
.
StridedSlice
(
begin_mask
=
"1"
),
{
'exception'
:
ValueError
}),
'block'
:
(
lambda
x
:
P
.
StridedSlice
(
begin_mask
=
"1"
),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
0
]}),
(
'StridedSlice_2_Error'
,
{
'block'
:
(
lambda
x
:
P
.
StridedSlice
(
end_mask
=
"1"
),
{
'exception'
:
ValueError
}),
'block'
:
(
lambda
x
:
P
.
StridedSlice
(
end_mask
=
"1"
),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
0
]}),
(
'StridedSlice_3_Error'
,
{
'block'
:
(
lambda
x
:
P
.
StridedSlice
(
ellipsis_mask
=
1.1
),
{
'exception'
:
ValueError
}),
'block'
:
(
lambda
x
:
P
.
StridedSlice
(
ellipsis_mask
=
1.1
),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
0
]}),
(
'StridedSlice_4_Error'
,
{
'block'
:
(
lambda
x
:
P
.
StridedSlice
(
new_axis_mask
=
"1.1"
),
{
'exception'
:
ValueError
}),
'block'
:
(
lambda
x
:
P
.
StridedSlice
(
new_axis_mask
=
"1.1"
),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
0
]}),
]
...
...
tests/ut/python/ops/test_nn_ops.py
浏览文件 @
6721541c
...
...
@@ -382,6 +382,46 @@ def test_max_pool_with_arg_max():
print
(
ret
)
class
GradWrapUnfold
(
nn
.
Cell
):
""" GradWrapUnfold definition """
def
__init__
(
self
,
network
):
super
(
GradWrapUnfold
,
self
).
__init__
()
self
.
network
=
network
self
.
sens
=
Tensor
(
np
.
ones
([
1
,
4
,
2
,
2
],
np
.
float32
))
def
construct
(
self
,
x
):
return
C
.
grad_all_with_sens
(
self
.
network
)(
x
,
self
.
sens
)
class
UnfoldNetValid
(
nn
.
Cell
):
""" UnfoldNetValid definition """
def
__init__
(
self
):
super
(
UnfoldNetValid
,
self
).
__init__
()
self
.
unfold
=
nn
.
Unfold
(
ksizes
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
1
,
1
,
1
],
rates
=
[
1
,
1
,
1
,
1
],
padding
=
'VALID'
)
def
construct
(
self
,
x
):
return
self
.
unfold
(
x
)
class
UnfoldNetSame
(
nn
.
Cell
):
""" UnfoldNetSame definition """
def
__init__
(
self
):
super
(
UnfoldNetSame
,
self
).
__init__
()
self
.
unfold
=
nn
.
Unfold
(
ksizes
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
1
,
1
,
1
],
rates
=
[
1
,
1
,
1
,
1
],
padding
=
'SAME'
)
def
construct
(
self
,
x
):
return
self
.
unfold
(
x
)
test_cases
=
[
(
'SoftMaxGrad'
,
{
'block'
:
SoftMaxGrad
(
VirtualNetWithLoss
(
P
.
Softmax
())),
...
...
@@ -440,6 +480,21 @@ test_cases = [
'block'
:
ComparisonNet
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
9
,
10
],
np
.
int32
)),
Tensor
(
np
.
ones
([
6
,
9
,
10
],
np
.
int32
))],
}),
(
'UnfoldValid'
,
{
'block'
:
UnfoldNetValid
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
1
,
1
,
3
,
3
],
np
.
float32
))],
'desc_bprop'
:
[
Tensor
(
np
.
ones
([
1
,
4
,
2
,
2
],
np
.
float32
))],
'skip'
:
[
'backward'
]}),
(
'UnfoldSame'
,
{
'block'
:
UnfoldNetSame
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
1
,
1
,
3
,
3
],
np
.
float32
))],
'desc_bprop'
:
[
Tensor
(
np
.
ones
([
1
,
4
,
3
,
3
],
np
.
float32
))],
'skip'
:
[
'backward'
]}),
(
'UnfoldGrad'
,
{
'block'
:
GradWrapUnfold
(
UnfoldNetValid
()),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
1
,
1
,
3
,
3
],
np
.
float32
))],
'desc_bprop'
:
[
Tensor
(
np
.
ones
([
1
,
4
,
2
,
2
],
np
.
float32
))],
'skip'
:
[
'backward'
]}),
]
test_cases_for_verify_exception
=
[
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录