Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0db45147
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0db45147
编写于
4月 09, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dock relu6 for open source process and fix pow bprop
上级
734f8a7f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
90 addition
and
3 deletion
+90
-3
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
+2
-0
mindspore/ops/_grad/grad_math_ops.py
mindspore/ops/_grad/grad_math_ops.py
+3
-3
mindspore/ops/_op_impl/tbe/__init__.py
mindspore/ops/_op_impl/tbe/__init__.py
+2
-0
mindspore/ops/_op_impl/tbe/relu6.py
mindspore/ops/_op_impl/tbe/relu6.py
+40
-0
mindspore/ops/_op_impl/tbe/relu6_grad.py
mindspore/ops/_op_impl/tbe/relu6_grad.py
+43
-0
未找到文件。
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
浏览文件 @
0db45147
...
...
@@ -30,6 +30,8 @@ namespace mindspore {
namespace
kernel
{
namespace
tbe
{
static
std
::
map
<
string
,
string
>
tbe_func_adapter_map
=
{
{
"re_lu6"
,
"relu6"
},
{
"re_lu6_grad"
,
"relu6_grad"
},
{
"re_lu"
,
"relu"
},
{
"tensor_add"
,
"add"
},
{
"reduce_mean"
,
"reduce_mean_d"
},
...
...
mindspore/ops/_grad/grad_math_ops.py
浏览文件 @
0db45147
...
...
@@ -340,9 +340,9 @@ def get_bprop_pow(self):
ln
=
P
.
Log
()
def
bprop
(
x
,
power
,
out
,
dout
):
dx
=
power
*
pow_op
(
x
,
power
-
1.0
)
*
dout
dpower
=
pow_op
(
x
,
power
)
*
ln
(
x
)
*
dout
return
dx
,
dpower
bc_
dx
=
power
*
pow_op
(
x
,
power
-
1.0
)
*
dout
bc_dpower
=
out
*
ln
(
x
)
*
dout
return
binop_grad_common
(
x
,
power
,
bc_dx
,
bc_dpower
)
return
bprop
...
...
mindspore/ops/_op_impl/tbe/__init__.py
浏览文件 @
0db45147
...
...
@@ -42,6 +42,8 @@ from .mul import _mul_tbe
from
.real_div
import
_real_div_tbe
from
.relu
import
_relu_tbe
from
.relu_grad
import
_relu_grad_tbe
from
.relu6
import
_relu6_tbe
from
.relu6_grad
import
_relu6_grad_tbe
from
.softmax_cross_entropy_with_logits
import
_softmax_cross_entropy_with_logits_tbe
from
.sigmoid_cross_entropy_with_logits
import
_sigmoid_cross_entropy_with_logits_tbe
from
.sigmoid_cross_entropy_with_logits_grad
import
_sigmoid_cross_entropy_with_logits_grad_tbe
...
...
mindspore/ops/_op_impl/tbe/relu6.py
0 → 100644
浏览文件 @
0db45147
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReLU6 op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
relu6_op_info
=
TBERegOp
(
"ReLU6"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"relu6.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"relu6"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"features"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"activations"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
)
\
.
get_op_info
()
@
op_info_register
(
relu6_op_info
)
def
_relu6_tbe
():
"""Relu6 TBE register"""
return
mindspore/ops/_op_impl/tbe/relu6_grad.py
0 → 100644
浏览文件 @
0db45147
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReLU6Grad op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
relu6_grad_op_info
=
TBERegOp
(
"ReLU6Grad"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"relu6_grad.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"relu6_grad"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"gradients"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"features"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"backprops"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
dtype_format
(
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
)
\
.
dtype_format
(
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
)
\
.
get_op_info
()
@
op_info_register
(
relu6_grad_op_info
)
def
_relu6_grad_tbe
():
"""Relu6Grad TBE register"""
return
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录