Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
94872b76
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
94872b76
编写于
5月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1570 Check the size of topk input names before converting input to attr
Merge pull request !1570 from YuJianfeng/r0.3
上级
2f936166
1fb2cce2
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
59 addition
and
9 deletion
+59
-9
mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc
mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc
+6
-0
tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc
.../ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc
+53
-9
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc
浏览文件 @
94872b76
...
...
@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fission/topk_split.h"
#include <string>
#include <vector>
#include <memory>
#include <unordered_set>
...
...
@@ -102,6 +103,11 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
// set value node as topk's input
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
input_names_vec
=
AnfAlgo
::
GetNodeAttr
<
std
::
vector
<
std
::
string
>>
(
cnode
,
kAttrInputNames
);
if
(
input_names_vec
.
size
()
<
kTopkIndexK
+
1
)
{
MS_LOG
(
INFO
)
<<
"The input k of topk has been converted to attr"
;
return
nullptr
;
}
// Copy a new node to check supported.
std
::
vector
<
AnfNodePtr
>
new_inputs
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
kTopKOpName
))};
new_inputs
.
insert
(
new_inputs
.
end
(),
cnode
->
inputs
().
begin
()
+
1
,
cnode
->
inputs
().
end
());
...
...
tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc
浏览文件 @
94872b76
...
...
@@ -19,6 +19,7 @@
#include "device/kernel_info.h"
#include "pre_activate/pass/convert_const_input_to_attr.h"
#include "debug/anf_ir_dump.h"
#include "session/anf_runtime_algorithm.h"
#define private public
#define protected public
#include "pre_activate/ascend/ir_fission/topk_split.h"
...
...
@@ -32,6 +33,21 @@ class TestHWTopKSplit : public BackendCommon {
TestHWTopKSplit
()
:
get_py_fun_
(
"gtest_input.pre_activate.topk_split_test"
,
true
)
{}
~
TestHWTopKSplit
()
override
=
default
;
CNodePtr
GetTopkCNodeFromKernelGraph
(
const
FuncGraphPtr
&
func_graph
)
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
auto
ret
=
func_graph
->
get_return
();
MS_EXCEPTION_IF_NULL
(
ret
);
auto
make_tuple
=
ret
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
make_tuple
);
auto
tuple_getitem
=
make_tuple
->
cast
<
CNodePtr
>
()
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
tuple_getitem
);
auto
topk
=
tuple_getitem
->
cast
<
CNodePtr
>
()
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
topk
);
auto
topk_cnode
=
topk
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
topk_cnode
);
return
topk_cnode
;
}
UT
::
PyFuncGraphFetcher
get_py_fun_
;
};
...
...
@@ -39,7 +55,8 @@ class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker
()
=
default
;
~
MockSupportedChecker
()
override
=
default
;
bool
CheckAiCoreSupported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
override
{
bool
CheckAiCoreSupported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
override
{
return
true
;
}
};
// namespace opt
...
...
@@ -66,14 +83,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
kernel_graph
);
auto
ret
=
new_graph
->
get_return
();
EXPECT_NE
(
ret
,
nullptr
);
auto
make_tuple
=
ret
->
input
(
1
);
EXPECT_NE
(
make_tuple
,
nullptr
);
auto
tuple_getitem
=
make_tuple
->
cast
<
CNodePtr
>
()
->
input
(
1
);
EXPECT_NE
(
tuple_getitem
,
nullptr
);
auto
topk
=
tuple_getitem
->
cast
<
CNodePtr
>
()
->
input
(
1
);
auto
topk_cnode
=
topk
->
cast
<
CNodePtr
>
();
auto
topk_cnode
=
GetTopkCNodeFromKernelGraph
(
new_graph
);
EXPECT_EQ
(
topk_cnode
->
inputs
().
size
(),
3
);
EXPECT_TRUE
(
topk_cnode
->
input
(
2
)
->
isa
<
ValueNode
>
());
auto
value_node
=
topk_cnode
->
input
(
2
)
->
cast
<
ValueNodePtr
>
();
...
...
@@ -82,5 +92,39 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
EXPECT_EQ
(
tensor
->
shape
().
size
(),
1
);
EXPECT_EQ
(
tensor
->
shape
()[
0
],
4
);
}
TEST_F
(
TestHWTopKSplit
,
test_topk_no_split
)
{
/*
* def before(input):
* topk = TopKSplit(input)
* output = tuple_getitem(topk, 0)
* return output
*/
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_topk_split"
,
"before"
);
std
::
vector
<
int
>
shp
{
4
,
4
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
AbstractBasePtrList
args_spec_list
{
x_abstract
};
auto
kernel_graph
=
GetKernelGraph
(
g
,
args_spec_list
);
CNodePtr
topk_cnode
=
GetTopkCNodeFromKernelGraph
(
kernel_graph
);
EXPECT_EQ
(
topk_cnode
->
inputs
().
size
(),
3
);
auto
input_names_vec
=
AnfAlgo
::
GetNodeAttr
<
std
::
vector
<
std
::
string
>>
(
topk_cnode
,
kAttrInputNames
);
EXPECT_EQ
(
input_names_vec
.
size
(),
2
);
std
::
unordered_set
<
size_t
>
attr_index
{
1
};
ConstInputToAttr
(
topk_cnode
,
attr_index
);
EXPECT_EQ
(
topk_cnode
->
inputs
().
size
(),
2
);
input_names_vec
=
AnfAlgo
::
GetNodeAttr
<
std
::
vector
<
std
::
string
>>
(
topk_cnode
,
kAttrInputNames
);
EXPECT_EQ
(
input_names_vec
.
size
(),
1
);
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ConvertConstInputToAttr
>
());
auto
topk_split
=
std
::
make_shared
<
opt
::
TopKSplit
>
();
topk_split
->
supported_checker_
=
std
::
make_shared
<
MockSupportedChecker
>
();
pm
->
AddPass
(
topk_split
);
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
kernel_graph
);
EXPECT_EQ
(
topk_cnode
,
GetTopkCNodeFromKernelGraph
(
new_graph
));
}
}
// namespace opt
}
// namespace mindspore
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录