magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 79d1e465
Authored Apr 26, 2020 by mindspore-ci-bot; committed by Gitee, Apr 26, 2020
!702 add buffer fusion bnupdate eltwise pass
Merge pull request !702 from Etone.Chan/r0.2
Parents: a04e8486, 4e39354d
Showing 9 changed files with 202 additions and 1409 deletions (+202 −1409).
mindspore/_akg/__init__.py                                               +2    -3
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc                           +1    -2
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc     +178   -85
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h        +3    -3
mindspore/nn/optim/ftrl.py                                               +3    -2
mindspore/ops/operations/nn_ops.py                                       +3    -3
mindspore/train/callback.py                                              +1    -2
mindspore/train/model.py                                                +11   -11
tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc     +0 -1298
mindspore/_akg/__init__.py

@@ -16,6 +16,8 @@
 from __future__ import absolute_import as _abs
 import sys
 import os
+from .op_build import op_build
+from .message import compilewithjson


 def AKGAddPath():
     """_akg add path."""
@@ -58,6 +60,3 @@ class AKGMetaPathLoader:
 sys.meta_path.insert(0, AKGMetaPathFinder())
-
-from .op_build import op_build
-from .message import compilewithjson
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc

@@ -722,8 +722,7 @@ bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode
                                                 std::vector<nlohmann::json> *output_desc_list) {
   auto output_size = AnfAlgo::GetOutputTensorNum(cnode);
   if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) {
-    // wait anther pr: auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrOutputUsedNum);
-    auto output_used_nums = {SizeToInt(AnfAlgo::GetNodeAttr<std::size_t>(cnode, kAttrOutputUsedNum))};
+    auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrOutputUsedNum);
     MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope();
     if (output_used_nums.size() != output_size) {
       MS_LOG(INFO) << "Fusion error: output tenor num(" << output_size << ")"
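The deleted lines show the interim workaround being removed: kAttrOutputUsedNum had been stored as a single std::size_t and wrapped into a one-element list, which cannot describe per-output reuse on multi-output nodes. A minimal standalone sketch of the difference, with GetScalarAttr/GetVectorAttr as hypothetical stand-ins for the two AnfAlgo::GetNodeAttr forms:

#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for AnfAlgo::GetNodeAttr<T>(cnode, kAttrOutputUsedNum):
// the old code read a single std::size_t; the new code reads the whole
// std::vector<int> that the matchers in buffer_fusion.cc now store.
std::size_t GetScalarAttr() { return 2; }
std::vector<int> GetVectorAttr() { return {2, 0, 1}; }

int main() {
  // Old form: brace-init around one scalar yields a one-element list, so a
  // three-output node still looks like it has a single reused output.
  auto old_used_nums = {static_cast<int>(GetScalarAttr())};  // SizeToInt in the real code
  // New form: one reuse count per output index.
  auto new_used_nums = GetVectorAttr();
  std::cout << old_used_nums.size() << " vs " << new_used_nums.size() << '\n';  // 1 vs 3
  return 0;
}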
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc

@@ -17,6 +17,7 @@
 #include <vector>
 #include <tuple>
 #include <utility>
+#include <unordered_set>
 #include <unordered_map>
 #include <deque>

@@ -282,11 +283,17 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr
   // outputs format and data type
   std::vector<std::string> outputs_format;
   std::vector<TypeId> outputs_data_type;
-  for (size_t index = 0; index < outputs_list.size(); ++index) {
-    for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(outputs_list[index]); ++idx) {
-      auto kernel_with_index = AnfAlgo::VisitKernel(outputs_list[index], idx);
-      outputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
-      outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
-    }
-  }
+  for (const auto &output : outputs_list) {
+    if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
+      auto tuple_getitem = output->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(tuple_getitem);
+      outputs_format.push_back(AnfAlgo::GetOutputFormat(
+        tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
+      outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(
+        tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
+    } else {
+      outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0));
+      outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0));
+    }
+  }
   builder.SetInputsFormat(inputs_format);
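In the reworked output loop, each entry of outputs_list is either a TupleGetItem, whose real producer is input(1) and whose output index comes from the constant in input(2), or a plain single-output node queried at index 0. A simplified sketch of that branching, assuming (name, index) pairs in place of AnfNodePtr and a placeholder GetOutputFormat lookup:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Placeholder for AnfAlgo::GetOutputFormat(node, idx); the real call consults
// the node's selected kernel build info.
std::string GetOutputFormat(const std::string &node, int idx) {
  return node + "#out" + std::to_string(idx);
}

int main() {
  // Hypothetical fused-scope outputs: index >= 0 models a TupleGetItem on a
  // multi-output node; -1 models a plain single-output node.
  std::vector<std::pair<std::string, int>> outputs_list = {{"bnupdate", 2}, {"relu", -1}};
  std::vector<std::string> outputs_format;
  for (const auto &output : outputs_list) {
    if (output.second >= 0) {
      // TupleGetItem case: format of the producer at the extracted index.
      outputs_format.push_back(GetOutputFormat(output.first, output.second));
    } else {
      // Plain node: its only output, index 0.
      outputs_format.push_back(GetOutputFormat(output.first, 0));
    }
  }
  for (const auto &fmt : outputs_format) std::cout << fmt << '\n';  // bnupdate#out2, relu#out0
  return 0;
}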
@@ -320,32 +327,35 @@ AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::K
   return tuple_item;
 }

-void ReplaceOldNode(const std::vector<AnfNodePtr> &outputs_list, const AnfNodePtr &buffer_fusion_kernel,
-                    session::KernelGraph *kernel_graph) {
+void ReplaceInputNodeInOtherFusionScope(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos,
+                                        int32_t fusion_id, const AnfNodePtr &output_item,
+                                        const AnfNodePtr &replace_item) {
+  for (int32_t id = fusion_id + 1; id <= SizeToInt(buffer_fusion_infos->size()); ++id) {
+    auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(),
+                         (*buffer_fusion_infos)[id].inputs_list.end(), output_item);
+    if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) {
+      MS_LOG(DEBUG) << "replace input of other pattern, id = " << id;
+      *itr = replace_item;
+    }
+  }
+}
+
+void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, int32_t fusion_id,
+                    const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto manager = kernel_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
-  if (outputs_list.size() == 1) {  // single output
-    (void)manager->Replace(outputs_list[0], buffer_fusion_kernel);
+  auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
+  if (buffer_fusion_info.outputs_list.size() == 1) {  // single output
+    (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
+    ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0],
+                                       buffer_fusion_kernel);
   } else {  // multiple output
-    size_t real_idx = 0;
-    for (size_t index = 0; index < outputs_list.size(); ++index) {
-      if (AnfAlgo::GetOutputTensorNum(outputs_list[index]) == 1) {
-        auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, real_idx++);
-        (void)manager->Replace(outputs_list[index], tuple_item);
-      } else {
-        std::vector<AnfNodePtr> make_tuple_inputs;
-        AbstractBasePtrList abstract_list;
-        make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
-        for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(outputs_list[index]); ++idx) {
-          auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, real_idx++);
-          abstract_list.push_back(tuple_item->abstract());
-          make_tuple_inputs.push_back(tuple_item);
-        }
-        AnfNodePtr make_tuple = kernel_graph->NewCNode(make_tuple_inputs);
-        make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
-        (void)manager->Replace(outputs_list[index], make_tuple);
-      }
+    for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) {
+      auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index);
+      (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item);
+      ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index],
+                                         tuple_item);
     }
   }
 }
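ReplaceInputNodeInOtherFusionScope is the new piece here: once scope fusion_id has been collapsed into a fused kernel, any later scope whose inputs_list still references one of the replaced outputs must be rewired to the fused node (or its getitem). A self-contained sketch of that find-and-replace, with std::string standing in for AnfNodePtr and assuming fusion ids run contiguously from 1, as the loop bound in the pass implies:

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Reduced stand-in for BufferFusionInfo_t: only inputs_list matters here.
struct FusionInfo {
  std::vector<std::string> inputs_list;
};

// Mirrors ReplaceInputNodeInOtherFusionScope: scan every later scope and
// rewrite occurrences of output_item to replace_item in place.
void ReplaceInLaterScopes(std::unordered_map<int, FusionInfo> *infos, int fusion_id,
                          const std::string &output_item, const std::string &replace_item) {
  for (int id = fusion_id + 1; id <= static_cast<int>(infos->size()); ++id) {
    auto &inputs = (*infos)[id].inputs_list;
    auto itr = std::find(inputs.begin(), inputs.end(), output_item);
    if (itr != inputs.end()) {
      *itr = replace_item;  // the later scope now consumes the fused kernel
    }
  }
}

int main() {
  std::unordered_map<int, FusionInfo> infos;
  infos[1] = {{"x"}};
  infos[2] = {{"conv_out", "y"}};  // scope 2 consumed scope 1's output
  ReplaceInLaterScopes(&infos, 1, "conv_out", "fusion_op_1");
  std::cout << infos[2].inputs_list[0] << '\n';  // fusion_op_1
  return 0;
}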
@@ -406,38 +416,67 @@ void CheckCurrentNodeIsInput(const CNodePtr &node, const int32_t &cur_fusion_id,
   }
 }

-void InsertNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *list) {
-  MS_EXCEPTION_IF_NULL(list);
-  if (std::find(list->begin(), list->end(), node) == list->end()) {
-    (void)list->insert(list->end(), node);
-  }
-}
-
-void CheckCurrentNodeIsOutput(const CNodePtr &node, const int32_t &cur_fusion_id,
-                              std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
-  MS_EXCEPTION_IF_NULL(node);
-  for (auto &input : node->inputs()) {
-    MS_EXCEPTION_IF_NULL(input);
-    if (AnfAlgo::IsRealCNodeKernel(input) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, input)) {
-      auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input, kOpAttrFusionId);
-      if (buffer_fusion_infos->find(fusion_id) == buffer_fusion_infos->end()) {
-        BufferFusionInfo_t buffer_fusion_info;
-        (*buffer_fusion_infos)[fusion_id] = buffer_fusion_info;
-      }
-      if (fusion_id != cur_fusion_id) {
-        InsertNode(input, &((*buffer_fusion_infos)[fusion_id].outputs_list));
-      }
-    } else if (input->isa<CNode>()) {
-      for (auto &input_in : input->cast<CNodePtr>()->inputs()) {
-        if (AnfAlgo::IsRealCNodeKernel(input_in) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, input_in)) {
-          auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input_in, kOpAttrFusionId);
-          if (buffer_fusion_infos->find(fusion_id) == buffer_fusion_infos->end()) {
-            BufferFusionInfo_t buffer_fusion_info;
-            (*buffer_fusion_infos)[fusion_id] = buffer_fusion_info;
-          }
-          if (fusion_id != cur_fusion_id) {
-            InsertNode(input_in, &((*buffer_fusion_infos)[fusion_id].outputs_list));
-          }
-        }
-      }
-    }
+void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
+                                   std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
+  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
+  auto nodes = TopoSort(kernel_graph->get_return());
+  for (auto &node : nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (AnfAlgo::IsRealCNodeKernel(node) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, node)) {
+      auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(node, kOpAttrFusionId);
+      (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(node);
+    }
+  }
+}
+
+void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
+                                  std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
+  auto manager = kernel_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  for (auto &buffer_fusion_info : *buffer_fusion_infos) {
+    auto fusion_id = buffer_fusion_info.first;
+    auto fusion_info = buffer_fusion_info.second;
+    for (const auto &node : fusion_info.anf_nodes) {
+      if (AnfAlgo::GetOutputTensorNum(node) == 1) {
+        for (auto use_node : manager->node_users()[node]) {
+          if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) ==
+              fusion_info.anf_nodes.end()) {
+            (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node);
+            break;
+          }
+        }
+      } else {
+        int prev_idx = 0;
+        std::vector<AnfNodePtr> tuple_getitem_nodes;
+        std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(),
+                       std::back_inserter(tuple_getitem_nodes),
+                       [](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
+        std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(),
+                  [](const AnfNodePtr &node1, const AnfNodePtr &node2) {
+                    auto getitem1 = node1->cast<CNodePtr>();
+                    auto getitem2 = node2->cast<CNodePtr>();
+                    auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2)));
+                    auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2)));
+                    return output_idx1 < output_idx2;
+                  });
+        for (auto getitem : tuple_getitem_nodes) {
+          auto getitem_ptr = getitem->cast<CNodePtr>();
+          auto input2 = getitem_ptr->input(2);
+          auto output_idx = GetValue<int>(GetValueNode(input2));
+          for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) {
+            auto stub_node = CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx));
+            (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node);
+          }
+          prev_idx = output_idx + 1;
+          for (auto item_use_node : manager->node_users()[getitem]) {
+            if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) ==
+                fusion_info.anf_nodes.end()) {
+              (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem);
+              break;
+            }
+          }
+        }
+      }
     }
   }
 }
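The multi-output branch of GetFusionScopeOutputNodeList sorts the node's TupleGetItem consumers by their output index, then fabricates a stub getitem for every index no consumer references, so the fused op's output list stays densely indexed. A minimal sketch of that gap-filling walk, with plain ints in place of getitem nodes:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical output indices actually consumed via TupleGetItem.
  std::vector<int> used_indices = {3, 0};
  std::sort(used_indices.begin(), used_indices.end());  // the pass sorts by input(2)
  int prev_idx = 0;
  for (int output_idx : used_indices) {
    // Fill the gap [prev_idx, output_idx) with stubs, as the pass does with
    // CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx)).
    for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) {
      std::cout << "stub getitem for index " << stub_idx << '\n';  // 1, 2
    }
    std::cout << "real getitem for index " << output_idx << '\n';  // 0, then 3
    prev_idx = output_idx + 1;
  }
  return 0;
}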
@@ -445,15 +484,72 @@ void CheckCurrentNodeIsOutput(const CNodePtr &node, const int32_t &cur_fusion_id
   }
 }

-void GetFusionScopeNodeList(const session::KernelGraph &kernel_graph,
-                            std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
-  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
-  auto nodes = TopoSort(kernel_graph.get_return());
-  for (auto &node : nodes) {
-    MS_EXCEPTION_IF_NULL(node);
-    if (AnfAlgo::IsRealCNodeKernel(node) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, node)) {
-      auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(node, kOpAttrFusionId);
-      (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(node);
-    }
+void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
+                       std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(fused_set);
+  MS_EXCEPTION_IF_NULL(candidate_fusion);
+  auto manager = kernel_graph.manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto conv = cnode->input(1);
+  if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
+    std::vector<int> output_used_num{SizeToInt(manager->node_users()[conv].size())};
+    AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);
+    std::unordered_set<AnfNodePtr> record{cnode, conv};
+    candidate_fusion->push_back(record);
+    fused_set->insert(record.begin(), record.end());
+  }
+}
+
+void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
+                       std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(fused_set);
+  MS_EXCEPTION_IF_NULL(candidate_fusion);
+  auto manager = kernel_graph.manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto getitem = relu_input->cast<CNodePtr>();
+  auto bnupdate = getitem->input(1);
+  if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
+    std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
+    for (auto out_getitem : manager->node_users()[bnupdate]) {
+      auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
+      auto input2 = out_getitem_ptr->input(2);
+      auto output_idx = GetValue<int>(GetValueNode(input2));
+      output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());
+    }
+    AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
+    std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
+    candidate_fusion->push_back(record);
+    fused_set->insert(record.begin(), record.end());
+  }
+}
+
+void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
+                          const session::KernelGraph &kernel_graph, std::unordered_set<AnfNodePtr> *fused_set,
+                          FusedNodeRecord *candidate_fusion) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(fused_set);
+  MS_EXCEPTION_IF_NULL(candidate_fusion);
+  auto manager = kernel_graph.manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto add = relu_input->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(add);
+  auto tuple_getitem = add->input(1);
+  if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
+    auto getitem = tuple_getitem->cast<CNodePtr>();
+    auto bnupdate = getitem->input(1);
+    if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
+      std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
+      for (auto out_getitem : manager->node_users()[bnupdate]) {
+        auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
+        auto input2 = out_getitem_ptr->input(2);
+        auto output_idx = GetValue<int>(GetValueNode(input2));
+        output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());
+      }
+      AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
+      std::unordered_set<AnfNodePtr> record{cnode, relu_input, bnupdate};
+      candidate_fusion->push_back(record);
+      fused_set->insert(record.begin(), record.end());
+    }
   }
 }
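All three matchers annotate the producer (conv or bnupdate) with kAttrOutputUsedNum: one reuse count per output index, taken from the user count of each getitem that extracts that index. A simplified sketch of the counting, with ints in place of node pointers:

#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Hypothetical bnupdate with five outputs; only outputs 0 and 4 are consumed.
  std::size_t bnupdate_output_num = 5;
  // (output index extracted by a getitem) -> (number of users of that getitem),
  // the same information MatchBnupdateRelu reads from manager->node_users().
  std::map<int, std::size_t> getitem_users = {{0, 2}, {4, 1}};
  std::vector<int> output_used_num(bnupdate_output_num, 0);
  for (const auto &entry : getitem_users) {
    output_used_num[entry.first] = static_cast<int>(entry.second);  // SizeToInt in the pass
  }
  for (int n : output_used_num) std::cout << n << ' ';  // 2 0 0 0 1
  std::cout << '\n';
  return 0;
}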
@@ -470,15 +566,14 @@ void MatchOpNamePattern(const session::KernelGraph &kernel_graph, std::unordered
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) {
-      auto conv = cnode->input(1);
-      if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
-        auto manager = kernel_graph.manager();
-        MS_EXCEPTION_IF_NULL(manager);
-        auto &users = manager->node_users();
-        AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(users[conv].size()), conv);
-        std::unordered_set<AnfNodePtr> record({cnode, conv});
-        candidate_fusion->push_back(record);
-        fused_set->insert(record.begin(), record.end());
-      }
+      MatchConvBnreduce(cnode, kernel_graph, fused_set, candidate_fusion);
+    } else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName ||
+               AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) {
+      auto relu_input = cnode->input(1);
+      if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) {
+        MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion);
+      } else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) {
+        MatchBnupdateRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion);
+      }
     }
@@ -536,27 +631,23 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
   }
 }
 }  // namespace

-void BufferFusion::GetBufferFusionInfo(const session::KernelGraph &kernel_graph,
+void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
                                        std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
   MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
+  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
   for (auto &node : node_list) {
     if (!AnfAlgo::IsRealCNodeKernel(node)) {
       continue;
     }

-    int32_t cur_fusion_id = -1;
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) {
-      cur_fusion_id = AnfAlgo::GetNodeAttr<int32_t>(cnode, kOpAttrFusionId);
+      auto cur_fusion_id = AnfAlgo::GetNodeAttr<int32_t>(cnode, kOpAttrFusionId);
       CheckCurrentNodeIsInput(cnode, cur_fusion_id, buffer_fusion_infos);
     }
-    // Check if current node is output
-    CheckCurrentNodeIsOutput(cnode, cur_fusion_id, buffer_fusion_infos);
   }
-  GetFusionScopeNodeList(kernel_graph, buffer_fusion_infos);
+  GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
+  GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);

   for (auto &buffer_fusion_info : *buffer_fusion_infos) {
     buffer_fusion_info.second.kernel_build_info =
       CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list_in, buffer_fusion_info.second.inputs_list,
@@ -569,7 +660,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c
   bool change = false;
   std::unordered_map<int32_t, BufferFusionInfo_t> buffer_fusion_infos;
   buffer_fusion_infos.clear();
-  GetBufferFusionInfo(*kernel_graph, &buffer_fusion_infos);
+  GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos);

   std::vector<mindspore::kernel::FusionScopeInfo> fusion_scope_infos;
   for (auto &buffer_fusion_info : buffer_fusion_infos) {
@@ -600,7 +691,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c
       MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed";
       continue;
     }
-    change = ReplaceFusionOp(buffer_fusion_infos[fusion_id], kernel_mods[fusion_id], kernel_graph);
+    change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
   }
   MS_LOG(DEBUG) << "End Buffer Fusion";
   return change;
@@ -630,8 +721,10 @@ bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_g
   return true;
 }

-bool BufferFusion::ReplaceFusionOp(const BufferFusionInfo_t &buffer_fusion_info,
-                                   const kernel::KernelModPtr &kernel_ptr,
+bool BufferFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos,
+                                   int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr,
                                    session::KernelGraph *kernel_graph) const {
+  auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
   auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list,
                                       buffer_fusion_info.anf_nodes, kernel_graph);
   AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get());

@@ -651,7 +744,7 @@ bool BufferFusion::ReplaceFusionOp(const BufferFusionInfo_t &buffer_fusion_info,
   AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get());
   AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get());
   // replace node
-  ReplaceOldNode(buffer_fusion_info.outputs_list, buffer_fusion, kernel_graph);
+  ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph);
   return true;
 }
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h

@@ -44,10 +44,10 @@ class BufferFusion : public Pass {
   bool Run(const FuncGraphPtr &graph) override;

  private:
-  void GetBufferFusionInfo(const session::KernelGraph &kernel_graph,
+  void GetBufferFusionInfo(session::KernelGraph *kernel_graph,
                            std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const;
-  bool ReplaceFusionOp(const BufferFusionInfo_t &buffer_fusion_info, const kernel::KernelModPtr &kernel_ptr,
-                       session::KernelGraph *kernel_graph) const;
+  bool ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, int32_t fusion_id,
+                       const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const;
   bool MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) const;
   bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const;
 };
mindspore/nn/optim/ftrl.py

@@ -104,7 +104,7 @@ class FTRL(Optimizer):
         self.lr_power = lr_power
         self.reciprocal_scale = 1.0 / loss_scale
         self.weight_decay = weight_decay
-        self.decay_tf = tuple((lambda: True)() for x in self.parameters)
+        self.decay_tf = tuple((lambda: True)() for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyFtrl(use_locking=use_locking)
         self.one = Tensor(1, mstype.int32)
@@ -118,5 +118,6 @@ class FTRL(Optimizer):
         if self.reciprocal_scale != 1.0:
             grads = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), grads)
         lr = self.learning_rate
-        success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power), linear, grads, params, moments)
+        success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power),
+                                 linear, grads, params, moments)
         return success
mindspore/ops/operations/nn_ops.py

@@ -2063,7 +2063,7 @@ class LSTM(PrimitiveWithInfer):
         return (y_shape, h_shape, c_shape, reserved_shape, state_shape)

     def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype):
-        args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
+        args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
         return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)

@@ -2670,8 +2670,8 @@ class ConfusionMulGrad(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, axis=(), keep_dims=False):
-        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
+    def __init__(self, axis=(), keep_dims=False):
+        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
         self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
         self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
mindspore/train/callback.py

@@ -689,7 +689,7 @@ class TimeMonitor(Callback):
     def epoch_begin(self, run_context):
         self.epoch_time = time.time()

     def epoch_end(self, run_context):
         epoch_mseconds = (time.time() - self.epoch_time) * 1000
         per_step_mseconds = epoch_mseconds / self.data_size

@@ -701,4 +701,3 @@ class TimeMonitor(Callback):
     def step_end(self, run_context):
         step_mseconds = (time.time() - self.step_time) * 1000
         print('step time', step_mseconds, flush=True)
mindspore/train/model.py

@@ -130,17 +130,17 @@ class Model:
         if self._optimizer:
             if self._loss_scale_manager_set:
                 network = amp.build_train_network(network,
-                                                  self._optimizer,
-                                                  self._loss_fn,
-                                                  level=self._amp_level,
-                                                  loss_scale_manager=self._loss_scale_manager,
-                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
+                                                  self._optimizer,
+                                                  self._loss_fn,
+                                                  level=self._amp_level,
+                                                  loss_scale_manager=self._loss_scale_manager,
+                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
             else:
                 network = amp.build_train_network(network,
-                                                  self._optimizer,
-                                                  self._loss_fn,
-                                                  level=self._amp_level,
-                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
+                                                  self._optimizer,
+                                                  self._loss_fn,
+                                                  level=self._amp_level,
+                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None

@@ -280,7 +280,7 @@ class Model:
         # remove later to deal with loop sink
         if need_wrap:
             self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
-                                                 train_dataset.__ME_INITED__)
+                                                 train_dataset.__ME_INITED__)
             cb_params.train_network = self._train_network
         self._train_network.set_train()

@@ -449,7 +449,7 @@ class Model:
         # remove later to deal with loop sink
         if need_wrap:
             self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
-                                                 valid_dataset.__ME_INITED__)
+                                                 valid_dataset.__ME_INITED__)
         self._eval_network.set_train(mode=False)
         self._eval_network.phase = 'eval'
tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc

Deleted, 100644 → 0 (1298 deletions; diff collapsed).