Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
b8d7f6d7
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b8d7f6d7
编写于
6月 23, 2020
作者:
H
huanghui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add UnsortedSegmentSum fission pass
上级
874972ca
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
341 addition
and
0 deletion
+341
-0
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
...c/backend/optimizer/ascend/ascend_backend_optimization.cc
+2
-0
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc
...timizer/ascend/ir_fission/unsorted_segment_sum_fission.cc
+118
-0
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h
...ptimizer/ascend/ir_fission/unsorted_segment_sum_fission.h
+37
-0
mindspore/ccsrc/backend/optimizer/common/helper.h
mindspore/ccsrc/backend/optimizer/common/helper.h
+1
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+5
-0
tests/st/fusion/test_unsorted_segment_sum_fission.py
tests/st/fusion/test_unsorted_segment_sum_fission.py
+47
-0
tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc
...te/ascend/ir_fission/unsorted_segment_sum_fission_test.cc
+68
-0
tests/ut/cpp/python_input/gtest_input/pre_activate/unsorted_segment_sum_fission.py
.../gtest_input/pre_activate/unsorted_segment_sum_fission.py
+63
-0
未找到文件。
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
浏览文件 @
b8d7f6d7
...
...
@@ -26,6 +26,7 @@
#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h"
#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h"
#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h"
#include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h"
#include "backend/optimizer/pass/communication_op_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
...
...
@@ -172,6 +173,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
PackFission
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
ConcatFission
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
ReduceMinFission
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
UnsortSegmentSumFission
>
());
}
}
// namespace
...
...
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc
0 → 100644
浏览文件 @
b8d7f6d7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h"
#include <memory>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
namespace
mindspore
{
namespace
opt
{
namespace
{
CNodePtr
CreatePadding
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
origin_node
,
const
size_t
&
pad_dim_size
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
origin_node
);
std
::
vector
<
AnfNodePtr
>
padding_inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
kPaddingOpName
)),
origin_node
->
input
(
1
)};
auto
padding
=
graph
->
NewCNode
(
padding_inputs
);
MS_EXCEPTION_IF_NULL
(
padding
);
padding
->
set_scope
(
origin_node
->
scope
());
auto
shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
origin_node
,
0
);
shape
[
shape
.
size
()
-
1
]
=
pad_dim_size
;
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
origin_node
,
0
)},
{
shape
},
padding
.
get
());
AnfAlgo
::
SetNodeAttr
(
kAttrPadDimSize
,
MakeValue
(
SizeToInt
(
pad_dim_size
)),
padding
);
return
padding
;
}
CNodePtr
CreateUnsortedSegmentSum
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
origin_node
,
const
CNodePtr
&
padding
,
const
size_t
&
pad_dim_size
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
origin_node
);
MS_EXCEPTION_IF_NULL
(
padding
);
std
::
vector
<
AnfNodePtr
>
unsorted_segment_sum8_inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimUnsortedSegmentSum
->
name
())),
padding
,
origin_node
->
input
(
2
)};
auto
unsorted_segment_sum
=
graph
->
NewCNode
(
unsorted_segment_sum8_inputs
);
MS_EXCEPTION_IF_NULL
(
unsorted_segment_sum
);
unsorted_segment_sum
->
set_scope
(
origin_node
->
scope
());
auto
shape
=
AnfAlgo
::
GetOutputInferShape
(
origin_node
,
0
);
shape
[
shape
.
size
()
-
1
]
=
pad_dim_size
;
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetOutputInferDataType
(
origin_node
,
0
)},
{
shape
},
unsorted_segment_sum
.
get
());
AnfAlgo
::
SetNodeAttr
(
kAttrNumSegments
,
MakeValue
(
SizeToInt
(
shape
[
0
])),
unsorted_segment_sum
);
return
unsorted_segment_sum
;
}
CNodePtr
CreateSlice
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
unsort_segment_sum
,
const
CNodePtr
&
unsorted_segment_sum8
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
unsort_segment_sum
);
MS_EXCEPTION_IF_NULL
(
unsorted_segment_sum8
);
std
::
vector
<
AnfNodePtr
>
slice_inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
kSliceOpName
)),
unsorted_segment_sum8
};
auto
slice
=
graph
->
NewCNode
(
slice_inputs
);
MS_EXCEPTION_IF_NULL
(
slice
);
slice
->
set_scope
(
unsort_segment_sum
->
scope
());
slice
->
set_abstract
(
unsort_segment_sum
->
abstract
());
auto
unsort_segment_sum_shape
=
AnfAlgo
::
GetOutputInferShape
(
unsort_segment_sum
,
0
);
std
::
vector
<
size_t
>
offsets
(
unsort_segment_sum_shape
.
size
(),
0
);
AnfAlgo
::
SetNodeAttr
(
kAttrBegin
,
MakeValue
(
Convert2Int
(
offsets
)),
slice
);
AnfAlgo
::
SetNodeAttr
(
kAttrSize
,
MakeValue
(
Convert2Int
(
unsort_segment_sum_shape
)),
slice
);
return
slice
;
}
}
// namespace
const
BaseRef
UnsortSegmentSumFission
::
DefinePattern
()
const
{
VarPtr
Xs
=
std
::
make_shared
<
SeqVar
>
();
VectorRef
pattern
({
prim
::
kPrimUnsortedSegmentSum
,
Xs
});
return
pattern
;
}
const
AnfNodePtr
UnsortSegmentSumFission
::
Process
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
node
,
const
EquivPtr
&
)
const
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
node
);
auto
origin_node
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
origin_node
);
if
(
origin_node
->
size
()
!=
kUnsortedSegmentSumInputNum
+
1
)
{
MS_LOG
(
INFO
)
<<
"UnsortedSegmentSum has wrong inputs num, not equal "
<<
kUnsortedSegmentSumInputNum
<<
". CNode= "
<<
origin_node
->
DebugString
();
return
nullptr
;
}
auto
input0_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
origin_node
,
0
);
if
(
input0_shape
[
input0_shape
.
size
()
-
1
]
!=
1
)
{
MS_LOG
(
INFO
)
<<
"UnsortedSegmentSum is not need fission. The last value of input0's shape is "
<<
input0_shape
[
input0_shape
.
size
()
-
1
];
return
nullptr
;
}
size_t
pad_dim_size
;
auto
input_dtype
=
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
origin_node
,
0
);
if
(
input_dtype
==
kNumberTypeFloat32
)
{
pad_dim_size
=
8
;
}
else
if
(
input_dtype
==
kNumberTypeFloat16
)
{
pad_dim_size
=
16
;
}
else
{
MS_LOG
(
INFO
)
<<
"UnsortedSegmentSum data type not in (float21, float16), no need change"
;
return
nullptr
;
}
auto
padding
=
CreatePadding
(
graph
,
origin_node
,
pad_dim_size
);
auto
unsorted_segment_sum8
=
CreateUnsortedSegmentSum
(
graph
,
origin_node
,
padding
,
pad_dim_size
);
return
CreateSlice
(
graph
,
origin_node
,
unsorted_segment_sum8
);
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h
0 → 100644
浏览文件 @
b8d7f6d7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_FISSION_H_
#include <vector>
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace
mindspore
{
namespace
opt
{
class
UnsortSegmentSumFission
:
public
PatternProcessPass
{
public:
explicit
UnsortSegmentSumFission
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"unsorted_segment_sum_fission"
,
multigraph
)
{}
~
UnsortSegmentSumFission
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_FISSION_H_
mindspore/ccsrc/backend/optimizer/common/helper.h
浏览文件 @
b8d7f6d7
...
...
@@ -98,6 +98,7 @@ constexpr size_t kTopkInputNum = 3;
constexpr
size_t
kLarsV2InputNum
=
5
;
constexpr
size_t
kFusedMulApplyMomentumOutputNum
=
2
;
constexpr
size_t
kSplitInputNum
=
2
;
constexpr
size_t
kUnsortedSegmentSumInputNum
=
2
;
enum
FusedBatchNormInput
{
kX
=
1
,
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
b8d7f6d7
...
...
@@ -182,6 +182,7 @@ constexpr auto kPushOpName = "Push";
constexpr
auto
kPullOpName
=
"Pull"
;
constexpr
auto
kEmbeddingLookupOpName
=
"EmbeddingLookup"
;
constexpr
auto
kEmbeddingLookupProxyOpName
=
"EmbeddingLookupProxy"
;
constexpr
auto
kPaddingOpName
=
"Padding"
;
// attr key name
constexpr
auto
kAttrInputNames
=
"input_names"
;
...
...
@@ -253,6 +254,10 @@ constexpr auto kAttrInputNums = "inputNums";
constexpr
auto
kAttrT
=
"T"
;
constexpr
auto
kAttrNum
=
"num"
;
constexpr
auto
kAttrRankSize
=
"rank_size"
;
constexpr
auto
kAttrPadDimSize
=
"pad_dim_size"
;
constexpr
auto
kAttrNumSegments
=
"num_segments"
;
constexpr
auto
kAttrBegin
=
"begin"
;
constexpr
auto
kAttrSize
=
"size"
;
// attr value
constexpr
auto
kValueTargetSwitch
=
"target_switch"
;
...
...
tests/st/fusion/test_unsorted_segment_sum_fission.py
0 → 100644
浏览文件 @
b8d7f6d7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
mindspore
import
mindspore.context
as
context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
save_graphs
=
True
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
unsorted_segment_sum
=
P
.
UnsortedSegmentSum
()
self
.
num_segments
=
3
def
construct
(
self
,
x
,
segment_ids
):
x
=
self
.
unsorted_segment_sum
(
x
,
segment_ids
,
self
.
num_segments
)
return
x
def
test_net
():
input_x
=
np
.
random
.
randn
(
3
,
39
,
1
).
astype
(
np
.
float32
)
segment_ids
=
Tensor
([
0
,
1
,
2
],
mindspore
.
int32
)
net
=
Net
()
output
=
net
(
Tensor
(
input_x
),
segment_ids
)
print
(
"result"
,
output
.
asnumpy
())
if
__name__
==
"__main__"
:
test_net
()
tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc
0 → 100644
浏览文件 @
b8d7f6d7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "debug/anf_ir_dump.h"
namespace
mindspore
{
namespace
opt
{
class
TestHWUnsortedSegmentSumFission
:
public
BackendCommon
{
public:
TestHWUnsortedSegmentSumFission
()
:
get_py_fun_
(
"gtest_input.pre_activate.unsorted_segment_sum_fission"
,
true
)
{}
~
TestHWUnsortedSegmentSumFission
()
override
=
default
;
UT
::
PyFuncGraphFetcher
get_py_fun_
;
};
TEST_F
(
TestHWUnsortedSegmentSumFission
,
test_fission
)
{
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_unsorted_segment_sum_fission"
,
"before1"
);
EXPECT_NE
(
g
,
nullptr
);
std
::
vector
<
int
>
shp_x
{
16
,
1
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp_x
);
AbstractBasePtrList
args_spec_list
{
x_abstract
,
x_abstract
};
auto
kg
=
GetKernelGraph
(
g
,
args_spec_list
);
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
pm
->
AddPass
(
std
::
make_shared
<
opt
::
UnsortSegmentSumFission
>
());
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
kg
);
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_unsorted_segment_sum_fission"
,
"after1"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
TEST_F
(
TestHWUnsortedSegmentSumFission
,
test_no_fission
)
{
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_unsorted_segment_sum_fission"
,
"before2"
);
EXPECT_NE
(
g
,
nullptr
);
std
::
vector
<
int
>
shp_x
{
16
,
2
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp_x
);
AbstractBasePtrList
args_spec_list
{
x_abstract
,
x_abstract
};
auto
kg
=
GetKernelGraph
(
g
,
args_spec_list
);
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
pm
->
AddPass
(
std
::
make_shared
<
opt
::
UnsortSegmentSumFission
>
());
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
kg
);
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_unsorted_segment_sum_fission"
,
"after2"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
}
// namespace opt
}
// namespace mindspore
tests/ut/cpp/python_input/gtest_input/pre_activate/unsorted_segment_sum_fission.py
0 → 100644
浏览文件 @
b8d7f6d7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from
mindspore.ops
import
Primitive
from
mindspore.ops
import
operations
as
P
make_tuple
=
Primitive
(
'make_tuple'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
unsorted_segment_sum
=
P
.
UnsortedSegmentSum
()
num_segments
=
4
padding
=
Primitive
(
'Padding'
)
op_slice
=
Primitive
(
'Slice'
)
op_unsorted_segment_sum
=
Primitive
(
'UnsortedSegmentSum'
)
class
FnDict
:
def
__init__
(
self
):
self
.
fnDict
=
{}
def
__call__
(
self
,
fn
):
self
.
fnDict
[
fn
.
__name__
]
=
fn
def
__getitem__
(
self
,
name
):
return
self
.
fnDict
[
name
]
def
test_unsorted_segment_sum_fission
(
tag
):
fns
=
FnDict
()
@
fns
def
before1
(
input0
,
input1
):
x
=
unsorted_segment_sum
(
input0
,
input1
,
num_segments
)
return
x
@
fns
def
after1
(
input0
,
input1
):
x
=
padding
(
input0
)
x
=
op_unsorted_segment_sum
(
x
,
input1
)
x
=
op_slice
(
x
)
return
make_tuple
(
x
)
@
fns
def
before2
(
input0
,
input1
):
x
=
unsorted_segment_sum
(
input0
,
input1
,
num_segments
)
return
x
@
fns
def
after2
(
input0
,
input1
):
x
=
op_unsorted_segment_sum
(
input0
,
input1
)
return
make_tuple
(
x
)
return
fns
[
tag
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录