Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7ffb8bb1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7ffb8bb1
编写于
4月 14, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 14, 2020
浏览文件
操作
浏览文件
下载
差异文件
!250 Add nn.pad to support three modes
Merge pull request !250 from casgj/gaojing_new4
上级
60958d6b
2db3e64f
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
254 addition
and
3 deletion
+254
-3
mindspore/ccsrc/transform/convert.cc
mindspore/ccsrc/transform/convert.cc
+4
-0
mindspore/ccsrc/transform/op_declare.cc
mindspore/ccsrc/transform/op_declare.cc
+10
-0
mindspore/ccsrc/transform/op_declare.h
mindspore/ccsrc/transform/op_declare.h
+4
-0
mindspore/nn/layer/__init__.py
mindspore/nn/layer/__init__.py
+2
-2
mindspore/nn/layer/basic.py
mindspore/nn/layer/basic.py
+69
-0
mindspore/ops/_grad/grad_nn_ops.py
mindspore/ops/_grad/grad_nn_ops.py
+11
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+2
-1
mindspore/ops/operations/_grad_ops.py
mindspore/ops/operations/_grad_ops.py
+18
-0
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+70
-0
tests/ut/python/nn/test_nn_pad.py
tests/ut/python/nn/test_nn_pad.py
+64
-0
未找到文件。
mindspore/ccsrc/transform/convert.cc
浏览文件 @
7ffb8bb1
...
@@ -110,6 +110,8 @@ const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits
...
@@ -110,6 +110,8 @@ const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits
const
char
kNameSigmoidCrossEntropyWithLogitsGrad
[]
=
"SigmoidCrossEntropyWithLogitsGrad"
;
const
char
kNameSigmoidCrossEntropyWithLogitsGrad
[]
=
"SigmoidCrossEntropyWithLogitsGrad"
;
const
char
kNameScatterNdD
[]
=
"ScatterNd"
;
const
char
kNameScatterNdD
[]
=
"ScatterNd"
;
const
char
kNamePadD
[]
=
"Pad"
;
const
char
kNamePadD
[]
=
"Pad"
;
const
char
kNameMirrorPad
[]
=
"MirrorPad"
;
const
char
kNameMirrorPadGrad
[]
=
"MirrorPadGrad"
;
const
char
kNameGatherNd
[]
=
"GatherNd"
;
const
char
kNameGatherNd
[]
=
"GatherNd"
;
const
char
kNameArgmax
[]
=
"Argmax"
;
const
char
kNameArgmax
[]
=
"Argmax"
;
const
char
kNameArgmin
[]
=
"Argmin"
;
const
char
kNameArgmin
[]
=
"Argmin"
;
...
@@ -256,6 +258,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
...
@@ -256,6 +258,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{
string
(
kNameSigmoidCrossEntropyWithLogitsGrad
),
ADPT_DESC
(
SigmoidCrossEntropyWithLogitsGrad
)},
{
string
(
kNameSigmoidCrossEntropyWithLogitsGrad
),
ADPT_DESC
(
SigmoidCrossEntropyWithLogitsGrad
)},
{
string
(
kNameScatterNdD
),
ADPT_DESC
(
ScatterNdD
)},
{
string
(
kNameScatterNdD
),
ADPT_DESC
(
ScatterNdD
)},
{
string
(
kNamePadD
),
ADPT_DESC
(
PadD
)},
{
string
(
kNamePadD
),
ADPT_DESC
(
PadD
)},
{
string
(
kNameMirrorPad
),
ADPT_DESC
(
MirrorPad
)},
{
string
(
kNameMirrorPadGrad
),
ADPT_DESC
(
MirrorPadGrad
)},
{
string
(
kNameGatherNd
),
ADPT_DESC
(
GatherNd
)},
{
string
(
kNameGatherNd
),
ADPT_DESC
(
GatherNd
)},
{
string
(
kNameArgmax
),
ADPT_DESC
(
ArgMaxD
)},
{
string
(
kNameArgmax
),
ADPT_DESC
(
ArgMaxD
)},
{
string
(
kNameArgmin
),
ADPT_DESC
(
ArgMinD
)},
{
string
(
kNameArgmin
),
ADPT_DESC
(
ArgMinD
)},
...
...
mindspore/ccsrc/transform/op_declare.cc
浏览文件 @
7ffb8bb1
...
@@ -596,6 +596,16 @@ INPUT_MAP(PadD) = {{1, INPUT_DESC(x)}};
...
@@ -596,6 +596,16 @@ INPUT_MAP(PadD) = {{1, INPUT_DESC(x)}};
ATTR_MAP
(
PadD
)
=
{{
"paddings"
,
ATTR_DESC
(
paddings
,
AnyTraits
<
std
::
vector
<
std
::
vector
<
int64_t
>>>
())}};
ATTR_MAP
(
PadD
)
=
{{
"paddings"
,
ATTR_DESC
(
paddings
,
AnyTraits
<
std
::
vector
<
std
::
vector
<
int64_t
>>>
())}};
OUTPUT_MAP
(
PadD
)
=
{{
0
,
OUTPUT_DESC
(
y
)}};
OUTPUT_MAP
(
PadD
)
=
{{
0
,
OUTPUT_DESC
(
y
)}};
// MirrorPad
INPUT_MAP
(
MirrorPad
)
=
{{
1
,
INPUT_DESC
(
x
)},
{
2
,
INPUT_DESC
(
paddings
)}};
ATTR_MAP
(
MirrorPad
)
=
{{
"mode"
,
ATTR_DESC
(
mode
,
AnyTraits
<
std
::
string
>
())}};
OUTPUT_MAP
(
MirrorPad
)
=
{{
0
,
OUTPUT_DESC
(
y
)}};
// MirrorPadGrad
INPUT_MAP
(
MirrorPadGrad
)
=
{{
1
,
INPUT_DESC
(
x
)},
{
2
,
INPUT_DESC
(
paddings
)}};
ATTR_MAP
(
MirrorPadGrad
)
=
{{
"mode"
,
ATTR_DESC
(
mode
,
AnyTraits
<
std
::
string
>
())}};
OUTPUT_MAP
(
MirrorPadGrad
)
=
{{
0
,
OUTPUT_DESC
(
y
)}};
// GatherNd
// GatherNd
INPUT_MAP
(
GatherNd
)
=
{{
1
,
INPUT_DESC
(
x1
)},
{
2
,
INPUT_DESC
(
x2
)}};
INPUT_MAP
(
GatherNd
)
=
{{
1
,
INPUT_DESC
(
x1
)},
{
2
,
INPUT_DESC
(
x2
)}};
ATTR_MAP
(
GatherNd
)
=
EMPTY_ATTR_MAP
;
ATTR_MAP
(
GatherNd
)
=
EMPTY_ATTR_MAP
;
...
...
mindspore/ccsrc/transform/op_declare.h
浏览文件 @
7ffb8bb1
...
@@ -155,6 +155,10 @@ DECLARE_OP_USE_INPUT_ATTR(ScatterNdD)
...
@@ -155,6 +155,10 @@ DECLARE_OP_USE_INPUT_ATTR(ScatterNdD)
DECLARE_OP_USE_OUTPUT
(
ScatterNdD
)
DECLARE_OP_USE_OUTPUT
(
ScatterNdD
)
DECLARE_OP_ADAPTER
(
PadD
)
DECLARE_OP_ADAPTER
(
PadD
)
DECLARE_OP_USE_OUTPUT
(
PadD
)
DECLARE_OP_USE_OUTPUT
(
PadD
)
DECLARE_OP_ADAPTER
(
MirrorPad
)
DECLARE_OP_USE_OUTPUT
(
MirrorPad
)
DECLARE_OP_ADAPTER
(
MirrorPadGrad
)
DECLARE_OP_USE_OUTPUT
(
MirrorPadGrad
)
DECLARE_OP_ADAPTER
(
BoundingBoxEncode
)
DECLARE_OP_ADAPTER
(
BoundingBoxEncode
)
DECLARE_OP_USE_OUTPUT
(
BoundingBoxEncode
)
DECLARE_OP_USE_OUTPUT
(
BoundingBoxEncode
)
DECLARE_OP_ADAPTER
(
BoundingBoxDecode
)
DECLARE_OP_ADAPTER
(
BoundingBoxDecode
)
...
...
mindspore/nn/layer/__init__.py
浏览文件 @
7ffb8bb1
...
@@ -22,7 +22,7 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
...
@@ -22,7 +22,7 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
from
.container
import
SequentialCell
,
CellList
from
.container
import
SequentialCell
,
CellList
from
.conv
import
Conv2d
,
Conv2dTranspose
from
.conv
import
Conv2d
,
Conv2dTranspose
from
.lstm
import
LSTM
from
.lstm
import
LSTM
from
.basic
import
Dropout
,
Flatten
,
Dense
,
ClipByNorm
,
Norm
,
OneHot
,
ImageGradients
from
.basic
import
Dropout
,
Flatten
,
Dense
,
ClipByNorm
,
Norm
,
OneHot
,
ImageGradients
,
Pad
from
.embedding
import
Embedding
from
.embedding
import
Embedding
from
.pooling
import
AvgPool2d
,
MaxPool2d
from
.pooling
import
AvgPool2d
,
MaxPool2d
...
@@ -34,5 +34,5 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
...
@@ -34,5 +34,5 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'LSTM'
,
'LSTM'
,
'Dropout'
,
'Flatten'
,
'Dense'
,
'ClipByNorm'
,
'Norm'
,
'OneHot'
,
'ImageGradients'
,
'Dropout'
,
'Flatten'
,
'Dense'
,
'ClipByNorm'
,
'Norm'
,
'OneHot'
,
'ImageGradients'
,
'Embedding'
,
'Embedding'
,
'AvgPool2d'
,
'MaxPool2d'
,
'AvgPool2d'
,
'MaxPool2d'
,
'Pad'
,
]
]
mindspore/nn/layer/basic.py
浏览文件 @
7ffb8bb1
...
@@ -415,3 +415,72 @@ class ImageGradients(Cell):
...
@@ -415,3 +415,72 @@ class ImageGradients(Cell):
dx_last
=
P
.
Fill
()(
P
.
DType
()(
images
),
(
batch_size
,
depth
,
height
,
1
),
0
)
dx_last
=
P
.
Fill
()(
P
.
DType
()(
images
),
(
batch_size
,
depth
,
height
,
1
),
0
)
dx
=
P
.
Concat
(
3
)((
dx
,
dx_last
))
dx
=
P
.
Concat
(
3
)((
dx
,
dx_last
))
return
dy
,
dx
return
dy
,
dx
class
Pad
(
Cell
):
"""
Pads the input tensor according to the paddings and mode.
Args:
paddings (tuple): The shape of parameter `paddings` is (N, 2). N is the rank of input data. All elements of
paddings are int type. For `D` th dimension of input, paddings[D, 0] indicates how many sizes to be
extended ahead of the `D` th dimension of the input tensor, and paddings[D, 1] indicates how many sizes to
be extended behind of the `D` th dimension of the input tensor.
mode (string): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
Default: "CONSTANT".
Inputs:
- ** input_x** (Tensor) - The input tensor.
Outputs:
Tensor, the tensor after padding.
- If `mode` is "CONSTANT", it fill the edge with 0, regardless of the values of the `input_x`.
If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the
Outputs is [[0,0,0,0,0,0,0],[0,0,1,2,3,0,0],[0,0,4,5,6,0,0],[0,0,7,8,9,0,0],[0,0,0,0,0,0,0]].
- If 'mode` is "REFLECT", it uses a way of symmetrical copying throught the axis of symmetry to fill in,
symmetry. If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the
Outputs is [[6,5,4,5,6,5,4],[3,2,1,2,3,2,1],[6,5,4,5,6,5,4],[9,8,7,8,9,8,7],[6,5,4,5,6,5,4]].
- If 'mode' is "SYMMETRIC", the filling method is similar to the "REFLECT". It is also copied
according to the symmetry axis, except that it includes the symmetry axis. If the `input_x`
is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the Outputs is
[[2,1,1,2,3,3,2],[2,1,1,2,3,3,2],[5,4,4,5,6,6,5],[8,7,7,8,9,9,8],[8,7,7,8,9,9,8]].
Examples:
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> import mindspore.nn as nn
>>> import numpy as np
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.pad = nn.Pad(paddings=((1,1),(2,2)), mode="CONSTANT")
>>> def construct(self, x):
>>> return self.pad(x)
>>> x = np.random.random(size=(2, 3)).astype(np.float32)
>>> pad = Net()
>>> ms_output = pad(Tensor(x))
"""
def
__init__
(
self
,
paddings
,
mode
=
"CONSTANT"
):
super
(
Pad
,
self
).
__init__
()
self
.
mode
=
mode
self
.
paddings
=
paddings
validator
.
check_string
(
'mode'
,
self
.
mode
,
[
"CONSTANT"
,
"REFLECT"
,
"SYMMETRIC"
])
if
not
isinstance
(
paddings
,
tuple
):
raise
TypeError
(
'Paddings must be tuple type.'
)
for
item
in
paddings
:
if
len
(
item
)
!=
2
:
raise
ValueError
(
'The shape of paddings must be (n, 2).'
)
if
mode
==
"CONSTANT"
:
self
.
pad
=
P
.
Pad
(
self
.
paddings
)
else
:
self
.
paddings
=
Tensor
(
np
.
array
(
self
.
paddings
))
self
.
pad
=
P
.
MirrorPad
(
mode
=
mode
)
def
construct
(
self
,
x
):
if
self
.
mode
==
"CONSTANT"
:
x
=
self
.
pad
(
x
)
else
:
x
=
self
.
pad
(
x
,
self
.
paddings
)
return
x
mindspore/ops/_grad/grad_nn_ops.py
浏览文件 @
7ffb8bb1
...
@@ -470,6 +470,17 @@ def get_bprop_pad(self):
...
@@ -470,6 +470,17 @@ def get_bprop_pad(self):
return
bprop
return
bprop
@
bprop_getters
.
register
(
P
.
MirrorPad
)
def
get_bprop_mirror_pad
(
self
):
"""Grad definition for `MirrorPad` operation."""
mirror_pad_grad
=
G
.
MirrorPadGrad
(
self
.
mode
)
def
bprop
(
x
,
paddings
,
out
,
dout
):
dx
=
mirror_pad_grad
(
dout
,
paddings
,
x
)
return
(
dx
,
zeros_like
(
paddings
))
return
bprop
@
bprop_getters
.
register
(
P
.
ROIAlign
)
@
bprop_getters
.
register
(
P
.
ROIAlign
)
def
get_bprop_roi_align
(
self
):
def
get_bprop_roi_align
(
self
):
"""Grad definition for `ROIAlign` operation."""
"""Grad definition for `ROIAlign` operation."""
...
...
mindspore/ops/operations/__init__.py
浏览文件 @
7ffb8bb1
...
@@ -59,7 +59,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
...
@@ -59,7 +59,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
LogSoftmax
,
LogSoftmax
,
MaxPool
,
MaxPool
,
AvgPool
,
Conv2DBackpropInput
,
AvgPool
,
Conv2DBackpropInput
,
MaxPoolWithArgmax
,
OneHot
,
Pad
,
PReLU
,
ReLU
,
ReLU6
,
HSwish
,
HSigmoid
,
MaxPoolWithArgmax
,
OneHot
,
Pad
,
MirrorPad
,
PReLU
,
ReLU
,
ReLU6
,
HSwish
,
HSigmoid
,
ResizeBilinear
,
Sigmoid
,
ResizeBilinear
,
Sigmoid
,
SigmoidCrossEntropyWithLogits
,
SigmoidCrossEntropyWithLogits
,
SmoothL1Loss
,
Softmax
,
SmoothL1Loss
,
Softmax
,
...
@@ -180,6 +180,7 @@ __all__ = [
...
@@ -180,6 +180,7 @@ __all__ = [
'ScatterNd'
,
'ScatterNd'
,
'ResizeNearestNeighbor'
,
'ResizeNearestNeighbor'
,
'Pad'
,
'Pad'
,
'MirrorPad'
,
'GatherNd'
,
'GatherNd'
,
'ScatterNdUpdate'
,
'ScatterNdUpdate'
,
'Floor'
,
'Floor'
,
...
...
mindspore/ops/operations/_grad_ops.py
浏览文件 @
7ffb8bb1
...
@@ -947,6 +947,24 @@ class TanhGrad(PrimitiveWithInfer):
...
@@ -947,6 +947,24 @@ class TanhGrad(PrimitiveWithInfer):
return
out
return
out
class
MirrorPadGrad
(
PrimitiveWithInfer
):
"""Gradients of MirrorPad operation."""
@
prim_attr_register
def
__init__
(
self
,
mode
=
"REFLECT"
):
"""init MirrorPad"""
validator
.
check_string
(
'mode'
,
mode
,
[
'REFLECT'
,
'SYMMETRIC'
])
self
.
mode
=
mode
def
__infer__
(
self
,
dout
,
paddings
,
x
):
validator
.
check_subclass
(
"dout"
,
dout
[
'dtype'
],
mstype
.
tensor
)
validator
.
check_subclass
(
"paddings"
,
paddings
[
'dtype'
],
mstype
.
tensor
)
validator
.
check_subclass
(
"input_x"
,
x
[
'dtype'
],
mstype
.
tensor
)
return
{
'shape'
:
x
[
'shape'
],
'dtype'
:
dout
[
'dtype'
],
'value'
:
None
}
class
RefToEmbed
(
Primitive
):
class
RefToEmbed
(
Primitive
):
r
"""
r
"""
Make a key from Ref.
Make a key from Ref.
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
7ffb8bb1
...
@@ -2096,6 +2096,7 @@ class Pad(PrimitiveWithInfer):
...
@@ -2096,6 +2096,7 @@ class Pad(PrimitiveWithInfer):
for
item
in
paddings
:
for
item
in
paddings
:
if
len
(
item
)
!=
2
:
if
len
(
item
)
!=
2
:
raise
ValueError
(
'The shape of paddings must be (n, 2).'
)
raise
ValueError
(
'The shape of paddings must be (n, 2).'
)
self
.
paddings
=
paddings
def
infer_shape
(
self
,
x
):
def
infer_shape
(
self
,
x
):
paddings
=
np
.
array
(
self
.
paddings
)
paddings
=
np
.
array
(
self
.
paddings
)
...
@@ -2108,9 +2109,78 @@ class Pad(PrimitiveWithInfer):
...
@@ -2108,9 +2109,78 @@ class Pad(PrimitiveWithInfer):
return
y_shape
return
y_shape
def
infer_dtype
(
self
,
x
):
def
infer_dtype
(
self
,
x
):
validator
.
check_subclass
(
"input_x"
,
x
,
mstype
.
tensor
)
return
x
return
x
class
MirrorPad
(
PrimitiveWithInfer
):
"""
Pads the input tensor according to the paddings and mode.
Args:
mode (string): Specifies padding mode. The optional values are "REFLECT", "SYMMETRIC".
Default: "REFLECT".
Inputs:
- **input_x** (Tensor) - The input tensor.
- **paddings** (Tensor) - The paddings tensor. The value of `paddings` is a matrix(list),
and its shape is (N, 2). N is the rank of input data. All elements of paddings
are int type. For `D` th dimension of input, paddings[D, 0] indicates how many sizes to be
extended ahead of the `D` th dimension of the input tensor, and paddings[D, 1] indicates
how many sizes to be extended behind of the `D` th dimension of the input tensor.
Outputs:
Tensor, the tensor after padding.
- If 'mode` is "REFLECT", it uses a way of symmetrical copying throught the axis of symmetry to fill in,
symmetry. If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the
Outputs is [[6,5,4,5,6,5,4],[3,2,1,2,3,2,1],[6,5,4,5,6,5,4],[9,8,7,8,9,8,7],[6,5,4,5,6,5,4]].
- If 'mode' is "SYMMETRIC", the filling method is similar to the "REFLECT". It is also copied
according to the symmetry axis, except that it includes the symmetry axis. If the `input_x`
is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the Outputs is
[[2,1,1,2,3,3,2],[2,1,1,2,3,3,2],[5,4,4,5,6,6,5],[8,7,7,8,9,9,8],[8,7,7,8,9,9,8]].
Examples:
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> import mindspore.nn as nn
>>> import numpy as np
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.pad = P.MirrorPad(mode="REFLECT")
>>> def construct(self, x, paddings):
>>> return self.pad(x, paddings)
>>> x = np.random.random(size=(2, 3)).astype(np.float32)
>>> paddings = Tensor([[1,1],[2,2]])
>>> pad = Net()
>>> ms_output = pad(Tensor(x), paddings)
"""
@
prim_attr_register
def
__init__
(
self
,
mode
=
'REFLECT'
):
"""Init Pad"""
validator
.
check_string
(
'mode'
,
mode
,
[
'REFLECT'
,
'SYMMETRIC'
])
self
.
mode
=
mode
def
__infer__
(
self
,
input_x
,
paddings
):
validator
.
check_subclass
(
"input_x"
,
input_x
[
'dtype'
],
mstype
.
tensor
)
validator
.
check_subclass
(
"paddings"
,
paddings
[
'dtype'
],
mstype
.
tensor
)
x_shape
=
list
(
input_x
[
'shape'
])
paddings_value
=
paddings
[
'value'
].
asnumpy
()
paddings_size
=
paddings_value
.
size
validator
.
check_integer
(
'paddings.shape'
,
paddings_size
,
len
(
x_shape
)
*
2
,
Rel
.
EQ
)
if
not
np
.
all
(
paddings_size
>=
0
):
raise
ValueError
(
'All elements of paddings must be >= 0.'
)
y_shape
=
()
for
i
in
range
(
0
,
int
(
paddings_size
/
2
)):
y_shape
+=
((
x_shape
[
i
]
+
paddings_value
[
i
,
0
]
+
paddings_value
[
i
,
1
]),)
return
{
'shape'
:
y_shape
,
'dtype'
:
input_x
[
'dtype'
],
'value'
:
None
}
class
ROIAlign
(
PrimitiveWithInfer
):
class
ROIAlign
(
PrimitiveWithInfer
):
"""
"""
Computes Region of Interest (RoI) Align operator.
Computes Region of Interest (RoI) Align operator.
...
...
tests/ut/python/nn/test_nn_pad.py
0 → 100644
浏览文件 @
7ffb8bb1
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn pad """
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
from
mindspore.ops.composite
import
GradOperation
from
mindspore.common.api
import
ms_function
import
numpy
as
np
import
mindspore.context
as
context
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
raw_paddings
,
mode
):
super
(
Net
,
self
).
__init__
()
self
.
pad
=
nn
.
Pad
(
raw_paddings
,
mode
=
mode
)
@
ms_function
def
construct
(
self
,
x
):
return
self
.
pad
(
x
)
class
Grad
(
nn
.
Cell
):
def
__init__
(
self
,
network
):
super
(
Grad
,
self
).
__init__
()
self
.
grad
=
GradOperation
(
name
=
"get_all"
,
get_all
=
True
,
sens_param
=
True
)
self
.
network
=
network
@
ms_function
def
construct
(
self
,
x
,
grads
):
return
self
.
grad
(
self
.
network
)(
x
,
grads
)
def
test_pad_train
():
mode
=
'CONSTANT'
x
=
np
.
random
.
random
(
size
=
(
2
,
3
)).
astype
(
np
.
float32
)
raw_paddings
=
((
1
,
1
),
(
2
,
2
))
grads
=
np
.
random
.
random
(
size
=
(
4
,
7
)).
astype
(
np
.
float32
)
grad
=
Grad
(
Net
(
raw_paddings
,
mode
))
output
=
grad
(
Tensor
(
x
),
Tensor
(
grads
))
print
(
"=================output===================="
)
print
(
output
)
def
test_pad_infer
():
mode
=
'CONSTANT'
x
=
np
.
random
.
random
(
size
=
(
2
,
3
)).
astype
(
np
.
float32
)
raw_paddings
=
((
1
,
1
),
(
2
,
2
))
net
=
Net
(
raw_paddings
,
mode
)
output
=
net
(
Tensor
(
x
))
print
(
"=================output===================="
)
print
(
output
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录