Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
兔爷不爱我
mindspore
提交
cd6e8d65
M
mindspore
项目概览
兔爷不爱我
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cd6e8d65
编写于
4月 22, 2020
作者:
H
huanghui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix ReluV2's mask shape in derelu fusion pass
上级
b48d663c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed files
with
34 additions
and
5 deletions
+34
-5
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
.../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+5
-1
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc
...re_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc
+4
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc
...pore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc
+18
-2
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-1
tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc
...tivate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc
+5
-0
tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py
...pp/python_input/gtest_input/pre_activate/derelu_fusion.py
+1
-1
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
浏览文件 @
cd6e8d65
...
...
@@ -46,6 +46,8 @@
#include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h"
#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h"
#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h"
#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
#include "pre_activate/ascend/ir_fusion/derelu_fusion.h"
#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include "pre_activate/pass/getitem_tuple.h"
#include "pre_activate/pass/optimize_dependence.h"
...
...
@@ -94,8 +96,10 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
MulAddNFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
MatmulBiasaddFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AddnFission
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
DereluFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
ConfusionMulGradFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
TransposeTransDataFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
}
}
// namespace
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc
浏览文件 @
cd6e8d65
...
...
@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <algorithm>
#include <string>
#include "session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
...
...
@@ -89,6 +90,9 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
auto
reduce_sum
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
reduce_sum
);
auto
mul1
=
reduce_sum
->
input
(
1
);
if
(
mul1
->
fullname_with_scope
().
find
(
"bert/encoder"
)
==
std
::
string
::
npos
)
{
return
nullptr
;
}
if
(
IsUsedByOthers
(
graph
,
mul1
))
{
MS_LOG
(
INFO
)
<<
"Mul1 is used by others, quit fusion!"
;
return
nullptr
;
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc
浏览文件 @
cd6e8d65
...
...
@@ -50,9 +50,22 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
MS_EXCEPTION_IF_NULL
(
new_node
);
new_node
->
set_scope
(
relu
->
scope
());
// ReluV2's 2nd output is mask whose data type is uint8 and value is 0 or 1, so shape is an empty vector
// ReluV2's 2nd output is mask whose data type is uint8
TypeId
mask_dtype
=
kNumberTypeUInt8
;
std
::
vector
<
size_t
>
mask_shape
;
std
::
vector
<
size_t
>
mask_shape
=
AnfAlgo
::
GetOutputInferShape
(
relu
,
0
);
if
(
mask_shape
.
size
()
!=
4
)
{
MS_LOG
(
WARNING
)
<<
"relu's infer shape size not equal 4"
;
return
nullptr
;
}
auto
input_dtype
=
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
relu
,
0
);
if
(
input_dtype
==
kNumberTypeUInt8
||
input_dtype
==
kNumberTypeInt8
)
{
mask_shape
[
1
]
=
(
mask_shape
[
1
]
+
31
)
/
32
;
mask_shape
.
push_back
(
4
);
}
else
{
mask_shape
[
1
]
=
(
mask_shape
[
1
]
+
15
)
/
16
;
mask_shape
.
push_back
(
2
);
}
auto
types
=
{
AnfAlgo
::
GetOutputInferDataType
(
relu
,
0
),
mask_dtype
};
auto
shapes
=
{
AnfAlgo
::
GetOutputInferShape
(
relu
,
0
),
mask_shape
};
AnfAlgo
::
SetOutputInferTypeAndShape
(
types
,
shapes
,
new_node
.
get
());
...
...
@@ -91,6 +104,9 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
MS_EXCEPTION_IF_NULL
(
relu
);
auto
relu_v2
=
CreateReluV2
(
graph
,
relu
);
if
(
relu_v2
==
nullptr
)
{
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
relu_v2_node_outputs
;
CreateMultipleOutputsOfAnfNode
(
graph
,
relu_v2
,
kReluV2OutputNum
,
&
relu_v2_node_outputs
);
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
cd6e8d65
...
...
@@ -120,7 +120,7 @@ constexpr auto kStreamActiveOpName = "StreamActive";
constexpr
auto
kAssignAddOpName
=
"AssignAdd"
;
constexpr
auto
kSendOpName
=
"Send"
;
constexpr
auto
kRecvOpName
=
"Recv"
;
constexpr
auto
kReluV2OpName
=
"Re
lu
V2"
;
constexpr
auto
kReluV2OpName
=
"Re
LU
V2"
;
constexpr
auto
kReluGradV2OpName
=
"ReluGradV2"
;
// attr key name
...
...
tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc
浏览文件 @
cd6e8d65
...
...
@@ -32,6 +32,11 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon {
TEST_F
(
TestHWOptimizeConfusionMulGradFusion
,
test_fusion
)
{
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_confusion_mul_grad_fusion"
,
"before"
);
EXPECT_NE
(
g
,
nullptr
);
auto
bert_scope
=
std
::
make_shared
<
Scope
>
(
"bert/encoder"
);
for
(
auto
node
:
TopoSort
(
g
->
get_return
()))
{
node
->
set_scope
(
bert_scope
);
}
std
::
vector
<
int
>
shp
{
1
,
1
,
1
,
1
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
AbstractBasePtrList
args_spec_list
;
...
...
tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py
浏览文件 @
cd6e8d65
...
...
@@ -17,7 +17,7 @@ from mindspore.ops import Primitive
relu
=
P
.
ReLU
()
relu_grad
=
Primitive
(
'ReluGrad'
)
relu_v2
=
Primitive
(
'Re
lu
V2'
)
relu_v2
=
Primitive
(
'Re
LU
V2'
)
relu_grad_v2
=
Primitive
(
'ReluGradV2'
)
make_tuple
=
Primitive
(
'make_tuple'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录