Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b6e77e51
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b6e77e51
编写于
4月 21, 2020
作者:
L
liuxiao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ReluV2/ReluGradV2/ConfusionMulGrad for VM
上级
4e85ca68
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
337 addition
and
3 deletion
+337
-3
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
+1
-0
mindspore/ops/_grad/grad_nn_ops.py
mindspore/ops/_grad/grad_nn_ops.py
+12
-0
mindspore/ops/_op_impl/tbe/__init__.py
mindspore/ops/_op_impl/tbe/__init__.py
+3
-0
mindspore/ops/_op_impl/tbe/confusion_mul_grad.py
mindspore/ops/_op_impl/tbe/confusion_mul_grad.py
+38
-0
mindspore/ops/_op_impl/tbe/relu_grad_v2.py
mindspore/ops/_op_impl/tbe/relu_grad_v2.py
+40
-0
mindspore/ops/_op_impl/tbe/relu_v2.py
mindspore/ops/_op_impl/tbe/relu_v2.py
+40
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+4
-2
mindspore/ops/operations/_grad_ops.py
mindspore/ops/operations/_grad_ops.py
+21
-0
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+1
-1
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+105
-0
tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py
tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py
+53
-0
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+19
-0
未找到文件。
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
浏览文件 @
b6e77e51
...
...
@@ -33,6 +33,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{
"re_lu6"
,
"relu6"
},
{
"re_lu6_grad"
,
"relu6_grad"
},
{
"re_lu"
,
"relu"
},
{
"re_luv2"
,
"relu_v2"
},
{
"tensor_add"
,
"add"
},
{
"reduce_mean"
,
"reduce_mean_d"
},
{
"reduce_max"
,
"reduce_max_d"
},
...
...
mindspore/ops/_grad/grad_nn_ops.py
浏览文件 @
b6e77e51
...
...
@@ -227,6 +227,18 @@ def get_bprop_relu6(self):
return
bprop
@
bprop_getters
.
register
(
P
.
ReLUV2
)
def
get_bprop_relu_v2
(
self
):
"""Grad definition for `ReLUV2` operation."""
input_grad
=
G
.
ReluGradV2
()
def
bprop
(
x
,
out
,
dout
):
mask
=
out
[
1
]
dx
=
input_grad
(
dout
[
0
],
mask
)
return
(
dx
,)
return
bprop
@
bprop_getters
.
register
(
P
.
HSwish
)
def
get_bprop_hswish
(
self
):
"""Grad definition for `HSwish` operation."""
...
...
mindspore/ops/_op_impl/tbe/__init__.py
浏览文件 @
b6e77e51
...
...
@@ -33,6 +33,7 @@ from .cast import _cast_tbe
from
.conv2d
import
_conv2d_tbe
from
.conv2d_backprop_filter
import
_conv2d_backprop_filter_tbe
from
.conv2d_backprop_input
import
_conv2d_backprop_input_tbe
from
.confusion_mul_grad
import
_confusion_mul_grad_tbe
from
.dropout_do_mask
import
_dropout_do_mask_tbe
from
.gelu
import
_gelu_tbe
from
.gelu_grad
import
_gelu_grad_tbe
...
...
@@ -46,6 +47,8 @@ from .relu import _relu_tbe
from
.relu_grad
import
_relu_grad_tbe
from
.relu6
import
_relu6_tbe
from
.relu6_grad
import
_relu6_grad_tbe
from
.relu_v2
import
_relu_v2_tbe
from
.relu_grad_v2
import
_relu_grad_v2_tbe
from
.softmax_cross_entropy_with_logits
import
_softmax_cross_entropy_with_logits_tbe
from
.sigmoid_cross_entropy_with_logits
import
_sigmoid_cross_entropy_with_logits_tbe
from
.sigmoid_cross_entropy_with_logits_grad
import
_sigmoid_cross_entropy_with_logits_grad_tbe
...
...
mindspore/ops/_op_impl/tbe/confusion_mul_grad.py
0 → 100644
浏览文件 @
b6e77e51
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ConfusionMulGrad op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
confusion_mul_grad_op_info
=
TBERegOp
(
"ConfusionMulGrad"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
attr
(
"axis"
,
"required"
,
"listInt"
,
"all"
)
\
.
attr
(
"keep_dims"
,
"required"
,
"bool"
,
"all"
)
\
.
input
(
0
,
"input0"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"input1"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"input2"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"output0"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"output1"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
@
op_info_register
(
confusion_mul_grad_op_info
)
def
_confusion_mul_grad_tbe
():
"""ConfusionMulGrad TBE register"""
return
mindspore/ops/_op_impl/tbe/relu_grad_v2.py
0 → 100644
浏览文件 @
b6e77e51
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReluGradV2 op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
relu_grad_v2_op_info
=
TBERegOp
(
"ReluGradV2"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"relu_grad_v2.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"relu_grad_v2"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"gradients"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"mask"
,
False
,
"rerequired"
,
"all"
)
\
.
output
(
0
,
"backprops"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
U8_Default
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
U8_Default
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
U8_Default
,
DataType
.
I32_5HD
)
\
.
dtype_format
(
DataType
.
I8_5HD
,
DataType
.
U8_Default
,
DataType
.
I8_5HD
)
\
.
dtype_format
(
DataType
.
U8_5HD
,
DataType
.
U8_Default
,
DataType
.
U8_5HD
)
\
.
get_op_info
()
@
op_info_register
(
relu_grad_v2_op_info
)
def
_relu_grad_v2_tbe
():
"""ReluGradV2 TBE register"""
return
mindspore/ops/_op_impl/tbe/relu_v2.py
0 → 100644
浏览文件 @
b6e77e51
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReluV2 op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
relu_v2_op_info
=
TBERegOp
(
"ReLUV2"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"relu_v2.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"relu_v2"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"mask"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
I8_5HD
,
DataType
.
I8_5HD
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
U8_5HD
,
DataType
.
U8_5HD
,
DataType
.
U8_Default
)
\
.
get_op_info
()
@
op_info_register
(
relu_v2_op_info
)
def
_relu_v2_tbe
():
"""ReluV2 TBE register"""
return
mindspore/ops/operations/__init__.py
浏览文件 @
b6e77e51
...
...
@@ -58,8 +58,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
GetNext
,
L2Normalize
,
LayerNorm
,
L2Loss
,
LogSoftmax
,
MaxPool
,
ExtractImagePatches
,
AvgPool
,
Conv2DBackpropInput
,
MaxPoolWithArgmax
,
OneHot
,
Pad
,
MirrorPad
,
PReLU
,
ReLU
,
ReLU6
,
HSwish
,
HSigmoid
,
AvgPool
,
Conv2DBackpropInput
,
ConfusionMulGrad
,
MaxPoolWithArgmax
,
OneHot
,
Pad
,
MirrorPad
,
PReLU
,
ReLU
,
ReLU6
,
ReLUV2
,
HSwish
,
HSigmoid
,
ResizeBilinear
,
Sigmoid
,
SigmoidCrossEntropyWithLogits
,
SmoothL1Loss
,
Softmax
,
...
...
@@ -101,6 +101,7 @@ __all__ = [
'LogSoftmax'
,
'SoftmaxCrossEntropyWithLogits'
,
'ROIAlign'
,
'ConfusionMulGrad'
,
'SparseSoftmaxCrossEntropyWithLogits'
,
'SGD'
,
'ApplyMomentum'
,
...
...
@@ -138,6 +139,7 @@ __all__ = [
'Split'
,
'ReLU'
,
'ReLU6'
,
'ReLUV2'
,
'Elu'
,
'Erf'
,
'Sigmoid'
,
...
...
mindspore/ops/operations/_grad_ops.py
浏览文件 @
b6e77e51
...
...
@@ -730,6 +730,27 @@ class ReLU6Grad(PrimitiveWithInfer):
return
x_dtype
class
ReluGradV2
(
PrimitiveWithInfer
):
"""Performs grad of ReLUV2 operation."""
@
prim_attr_register
def
__init__
(
self
):
self
.
init_prim_io_names
(
inputs
=
[
'gradients'
,
'mask'
],
outputs
=
[
'output'
])
def
__call__
(
self
,
gradients
,
mask
):
raise
NotImplementedError
def
infer_shape
(
self
,
gradients_shape
,
mask_shape
):
return
gradients_shape
def
infer_dtype
(
self
,
gradients_dtype
,
mask_dtype
):
args_type
=
{
'gradients'
:
gradients_dtype
,
'mask'
:
mask_dtype
}
validator
.
check_args_tensor
(
args_type
)
validator
.
check_typename
(
"gradients_dtype"
,
gradients_dtype
,
mstype
.
number_type
)
validator
.
check_typename
(
"mask_dtype"
,
mask_dtype
,
(
mstype
.
uint8
,))
return
gradients_dtype
class
EluGrad
(
PrimitiveWithInfer
):
"""Performs grad of Elu operation."""
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
b6e77e51
...
...
@@ -1329,7 +1329,7 @@ class Concat(PrimitiveWithInfer):
def
_get_pack_shape
(
x_shape
,
x_type
,
axis
):
"""for pack output shape"""
validator
.
check_type
(
"shape"
,
x_shape
,
[
tuple
])
validator
.
check_type
(
"shape"
,
x_shape
,
[
tuple
,
list
])
validator
.
check_integer
(
"len of input_x shape"
,
len
(
x_shape
),
0
,
Rel
.
GT
)
validator
.
check_subclass
(
"shape0"
,
x_type
[
0
],
mstype
.
tensor
)
validator
.
check_integer
(
"len of input_x0 shape"
,
len
(
x_shape
[
0
]),
0
,
Rel
.
GT
)
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
b6e77e51
...
...
@@ -28,6 +28,7 @@ from ..._checkparam import Validator as validator
from
..._checkparam
import
Rel
from
...common
import
dtype
as
mstype
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
from
..operations.math_ops
import
_infer_shape_reduce
def
_check_positive_int_or_tuple
(
arg_name
,
arg_value
,
prim_name
,
allow_four
=
False
,
ret_four
=
False
):
...
...
@@ -233,6 +234,62 @@ class ReLU6(PrimitiveWithInfer):
return
input_x
class
ReLUV2
(
PrimitiveWithInfer
):
r
"""
Computes ReLU(Rectified Linear Unit) of input tensor element-wise.
It returns :math:`\max(x,\ 0)` element-wise.
Inputs:
- **input_x** (Tensor) - The input tensor should be a 4-D tensor.
Outputs:
- **output** (Tensor) - Has the same type and shape as the `input_x`.
- **mask** (Tensor) - A tensor whose data type must be uint8.
Examples:
>>> input_x = Tensor(np.array([[[[1, -2], [-3, 4]], [[-5, 6], [7, -8]]]]), mindspore.float32)
>>> relu_v2 = P.ReLUV2()
>>> output = relu_v2(input_x)
([[[[1., 0.], [0., 4.]], [[0., 6.], [7., 0.]]]],
[[[[1, 0], [2, 0]], [[2, 0], [1, 0]]]])
"""
@
prim_attr_register
def
__init__
(
self
):
"""init ReLUV2"""
self
.
init_prim_io_names
(
inputs
=
[
'x'
],
outputs
=
[
'output'
,
'mask'
])
def
__infer__
(
self
,
input_x
):
input_shape
=
list
(
input_x
[
'shape'
])
input_dtype
=
input_x
[
'dtype'
]
mask_shape
=
[]
if
len
(
input_shape
)
!=
4
:
raise
ValueError
(
"The `input_x` should be a 4-D tensor, "
f
"but got a
{
len
(
input_shape
)
}
-D tensor whose shape is
{
input_shape
}
"
)
for
i
in
enumerate
(
input_shape
):
if
i
[
0
]
==
1
:
if
input_dtype
==
mstype
.
uint8
and
input_dtype
==
mstype
.
int8
:
mask_shape
.
append
((
input_shape
[
1
]
+
31
)
//
32
)
else
:
mask_shape
.
append
((
input_shape
[
1
]
+
15
)
//
16
)
else
:
mask_shape
.
append
(
i
[
1
])
if
input_dtype
==
mstype
.
uint8
and
input_dtype
==
mstype
.
int8
:
mask_shape
.
append
(
4
)
else
:
mask_shape
.
append
(
2
)
output_shape
=
(
input_x
[
'shape'
],
mask_shape
)
validator
.
check_subclass
(
"input_x"
,
input_dtype
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_tensor_type_same
({
'input_x'
:
input_dtype
},
mstype
.
number_type
,
self
.
name
)
mask_dtype
=
mstype
.
uint8
output_dtype
=
(
input_dtype
,
mask_dtype
)
return
{
'shape'
:
output_shape
,
'dtype'
:
output_dtype
,
'value'
:
None
}
class
Elu
(
PrimitiveWithInfer
):
r
"""
Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise.
...
...
@@ -2580,3 +2637,51 @@ class ExtractImagePatches(PrimitiveWithInfer):
def
infer_dtype
(
self
,
input_x
):
validator
.
check_tensor_type_same
({
"input_x"
:
input_x
},
(
mstype
.
int8
,
mstype
.
float16
,
mstype
.
float32
),
self
.
name
)
return
input_x
class
ConfusionMulGrad
(
PrimitiveWithInfer
):
"""
`output0` is the result of which input0 dot multily input1.
`output1` is the result of which input0 dot multily input1, then reducesum it.
Args:
axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
Default:(), reduce all dimensions. Only constant value is allowed.
keep_dims (bool):
- If true, keep these reduced dimensions and the length is 1.
- If false, don't keep these dimensions. Default:False.
Inputs:
- **input_0** (Tensor) - The input Tensor.
- **input_1** (Tensor) - The input Tensor.
- **input_2** (Tensor) - The input Tensor.
outputs:
- **output_0** (Tensor) - The same shape with `input0`.
- **output_1** (Tensor)
- If axis is (), and keep_dims is false, the output is a 0-D array representing
the sum of all elements in the input array.
- If axis is int, set as 2, and keep_dims is false,
the shape of output is :math:`(x_1,x_3,...,x_R)`.
- If axis is tuple(int), set as (2,3), and keep_dims is false,
the shape of output is :math:`(x_1,x_4,...x_R)`.
"""
@
prim_attr_register
def
__init__
(
self
,
axis
=
(),
keep_dims
=
False
):
self
.
init_prim_io_names
(
inputs
=
[
"input0"
,
"input1"
,
"input2"
],
outputs
=
[
"output0"
,
"output1"
])
self
.
axis_
=
validator
.
check_value_type
(
"axis"
,
axis
,
[
int
,
tuple
,
list
],
self
.
name
)
self
.
keep_dims_
=
validator
.
check_value_type
(
"keep_dims"
,
keep_dims
,
[
bool
],
self
.
name
)
def
infer_shape
(
self
,
input0_shape
,
input1_shape
,
input2_shape
):
outshape0
=
input0_shape
outshape1
=
_infer_shape_reduce
(
input1_shape
,
self
.
axis_
,
self
.
keep_dims_
,
self
.
name
)
return
outshape0
,
outshape1
def
infer_dtype
(
self
,
input0_dtype
,
input1_dtype
,
input2_dtype
):
validator
.
check_subclass
(
"input0_dtype"
,
input0_dtype
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"input1_dtype"
,
input1_dtype
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"input2_dtype"
,
input2_dtype
,
mstype
.
tensor
,
self
.
name
)
return
input0_dtype
,
input1_dtype
tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py
0 → 100644
浏览文件 @
b6e77e51
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
from
mindspore.common.api
import
ms_function
import
numpy
as
np
import
mindspore.context
as
context
from
mindspore.common.initializer
import
initializer
from
mindspore.common.parameter
import
Parameter
from
mindspore.ops.composite
import
GradOperation
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
class
Grad
(
nn
.
Cell
):
def
__init__
(
self
,
network
):
super
(
Grad
,
self
).
__init__
()
self
.
grad
=
GradOperation
(
name
=
"get_all"
,
get_all
=
True
)
self
.
network
=
network
@
ms_function
def
construct
(
self
,
input
):
return
self
.
grad
(
self
.
network
)(
input
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
relu_v2
=
P
.
ReLUV2
()
def
construct
(
self
,
x
):
return
self
.
relu_v2
(
x
)
def
test_net
():
x
=
Tensor
(
np
.
ones
((
2
,
3
,
3
,
4
)).
astype
(
np
.
float32
))
relu_net
=
Net
()
relu_output
=
relu_net
(
x
)
net
=
Grad
(
Net
())
output_grad
=
net
(
x
)
print
(
relu_output
[
0
].
asnumpy
())
print
(
relu_output
[
1
].
asnumpy
())
print
(
len
(
output_grad
))
print
(
output_grad
[
0
].
asnumpy
())
tests/ut/python/ops/test_ops.py
浏览文件 @
b6e77e51
...
...
@@ -582,6 +582,10 @@ test_case_nn_ops = [
'block'
:
P
.
ReLU6
(),
'desc_inputs'
:
[[
1
,
3
,
4
,
4
]],
'desc_bprop'
:
[[
1
,
3
,
4
,
4
]]}),
(
'ReLUV2'
,
{
'block'
:
P
.
ReLUV2
(),
'desc_inputs'
:
[[
1
,
3
,
4
,
4
]],
'desc_bprop'
:
[[
1
,
3
,
4
,
4
],
[
1
,
3
,
4
,
4
]]}),
(
'ReLUGrad'
,
{
'block'
:
G
.
ReluGrad
(),
'desc_inputs'
:
[[
1
,
3
,
4
,
4
],
[
1
,
3
,
4
,
4
]],
...
...
@@ -1134,6 +1138,21 @@ test_case_other_ops = [
'desc_inputs'
:
[
Tensor
(
np
.
array
([
1.1
]).
astype
(
np
.
float32
)),
Tensor
(
np
.
array
([
1.2
]).
astype
(
np
.
float32
))],
'skip'
:
[
'backward'
]}),
(
'ConfusionMulGrad_1'
,
{
'block'
:
P
.
ConfusionMulGrad
(
axis
=
[
0
],
keep_dims
=
False
),
'desc_inputs'
:
[[
3
,
2
],
[
3
,
2
],
[
3
,
2
]],
'desc_bprop'
:
[[
3
,
2
],
[
2
]],
'skip'
:
[
'backward'
]}),
(
'ConfusionMulGrad_2'
,
{
'block'
:
P
.
ConfusionMulGrad
(
axis
=
[
0
],
keep_dims
=
True
),
'desc_inputs'
:
[[
3
,
2
],
[
3
,
2
],
[
3
,
2
]],
'desc_bprop'
:
[[
3
,
2
],
[
1
,
2
]],
'skip'
:
[
'backward'
]}),
(
'ConfusionMulGrad_3'
,
{
'block'
:
P
.
ConfusionMulGrad
(
axis
=
(),
keep_dims
=
True
),
'desc_inputs'
:
[[
2
,
3
,
4
],
[
2
,
3
,
4
],
[
2
,
3
,
4
]],
'desc_bprop'
:
[[
2
,
3
,
4
],
[
1
,
1
,
1
]],
'skip'
:
[
'backward'
]}),
(
'HistogramSummary'
,
{
'block'
:
HistogramSummaryNet
(),
'desc_inputs'
:
[
Tensor
(
np
.
array
([
1.1
]).
astype
(
np
.
float32
)),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录