Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7cb567eb
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7cb567eb
编写于
7月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3543 Split unsupport transdata
Merge pull request !3543 from lianliguang/unify-primitive
上级
87bf9a48
0179724d
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
210 addition
and
12 deletion
+210
-12
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
+33
-2
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
+30
-3
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
...ernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
+2
-2
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
...c/backend/optimizer/ascend/ascend_backend_optimization.cc
+2
-0
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
+2
-2
mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc
...timizer/ascend/format_type/split_unsupported_transdata.cc
+65
-0
mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h
...ptimizer/ascend/format_type/split_unsupported_transdata.h
+37
-0
tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc
...p/pre_activate/ascend/format_type/insert_trans_op_test.cc
+6
-0
tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
...ctivate/ascend/format_type/remove_internal_output_test.cc
+6
-0
tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc
...pp/pre_activate/ascend/ir_fission/transdata_split_test.cc
+13
-0
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
...ivate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
+6
-0
tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
...s/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
+8
-3
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
浏览文件 @
7cb567eb
...
...
@@ -158,13 +158,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor)
std
::
shared_ptr
<
KernelBuildInfo
>
KernelBuildInfo
::
KernelBuildInfoBuilder
::
Build
()
{
return
kernel_build_info_
;
}
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetInputReshapeType
(
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetInput
s
ReshapeType
(
const
std
::
vector
<
std
::
vector
<
Axis
>>
&
input_reshape_type
)
{
MS_EXCEPTION_IF_NULL
(
kernel_build_info_
);
kernel_build_info_
->
input_reshape_type_
=
input_reshape_type
;
}
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetOutputReshapeType
(
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetOutput
s
ReshapeType
(
const
std
::
vector
<
std
::
vector
<
Axis
>>
&
output_reshape_type
)
{
MS_EXCEPTION_IF_NULL
(
kernel_build_info_
);
kernel_build_info_
->
output_reshape_type_
=
output_reshape_type
;
...
...
@@ -189,5 +189,36 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string
}
kernel_build_info_
->
outputs_format_
[
index
]
=
format
;
}
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetInputReshapeType
(
const
std
::
vector
<
Axis
>
&
input_reshape_type
,
size_t
index
)
{
if
(
index
>=
kernel_build_info_
->
input_reshape_type_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"index outof range!"
;
}
std
::
copy
(
input_reshape_type
.
begin
(),
input_reshape_type
.
end
(),
std
::
back_inserter
(
kernel_build_info_
->
input_reshape_type_
[
index
]));
}
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetOutputReshapeType
(
const
std
::
vector
<
Axis
>
&
output_reshape_type
,
size_t
index
)
{
if
(
index
>=
kernel_build_info_
->
output_reshape_type_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"index outof range!"
;
}
std
::
copy
(
output_reshape_type
.
begin
(),
output_reshape_type
.
end
(),
std
::
back_inserter
(
kernel_build_info_
->
output_reshape_type_
[
index
]));
}
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetOutputDeviceType
(
const
TypeId
&
output_device_type
,
size_t
index
)
{
if
(
index
>=
kernel_build_info_
->
outputs_device_type_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"index outof range!"
;
}
kernel_build_info_
->
outputs_device_type_
[
index
]
=
output_device_type
;
}
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetInputDeviceType
(
const
TypeId
&
input_device_type
,
size_t
index
)
{
if
(
index
>=
kernel_build_info_
->
inputs_device_type_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"index outof range!"
;
}
kernel_build_info_
->
inputs_device_type_
[
index
]
=
input_device_type
;
}
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
浏览文件 @
7cb567eb
...
...
@@ -71,6 +71,10 @@ class KernelBuildInfo {
std
::
vector
<
TypeId
>
GetAllOutputDeviceTypes
()
const
;
std
::
vector
<
std
::
vector
<
Axis
>>
GetAllOutputReshapeType
()
const
;
std
::
vector
<
std
::
vector
<
Axis
>>
GetAllInputReshapeType
()
const
;
OpPattern
op_pattern
()
const
{
return
op_pattern_
;
}
FusionType
fusion_type
()
const
{
return
fusion_type_
;
}
...
...
@@ -109,7 +113,22 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
KernelBuildInfoBuilder
()
{
kernel_build_info_
=
std
::
make_shared
<
KernelBuildInfo
>
();
}
explicit
KernelBuildInfoBuilder
(
std
::
shared_ptr
<
KernelBuildInfo
>
kernel_build_info
)
:
kernel_build_info_
(
std
::
move
(
kernel_build_info
))
{}
:
kernel_build_info_
(
std
::
make_shared
<
KernelBuildInfo
>
())
{
SetKernelType
(
kernel_build_info
->
kernel_type
());
SetFusionType
(
kernel_build_info
->
fusion_type
());
SetProcessor
(
kernel_build_info
->
processor
());
OpPattern
(
kernel_build_info
->
op_pattern
());
for
(
size_t
index
=
0
;
index
<
kernel_build_info
->
GetInputNum
();
++
index
)
{
kernel_build_info_
->
inputs_device_type_
.
emplace_back
(
kernel_build_info
->
GetInputDeviceType
(
index
));
kernel_build_info_
->
inputs_format_
.
emplace_back
(
kernel_build_info
->
GetInputFormat
(
index
));
kernel_build_info_
->
input_reshape_type_
.
emplace_back
(
kernel_build_info
->
GetInputReshapeType
(
index
));
}
for
(
size_t
index
=
0
;
index
<
kernel_build_info
->
GetOutputNum
();
++
index
)
{
kernel_build_info_
->
outputs_device_type_
.
emplace_back
(
kernel_build_info
->
GetOutputDeviceType
(
index
));
kernel_build_info_
->
outputs_format_
.
emplace_back
(
kernel_build_info
->
GetOutputFormat
(
index
));
kernel_build_info_
->
output_reshape_type_
.
emplace_back
(
kernel_build_info
->
GetOutputReshapeType
(
index
));
}
}
~
KernelBuildInfoBuilder
()
=
default
;
...
...
@@ -123,9 +142,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void
SetOutputsDeviceType
(
const
std
::
vector
<
TypeId
>
&
outputs_device_type
);
void
SetInputReshapeType
(
const
std
::
vector
<
std
::
vector
<
Axis
>>
&
input_reshape_type
);
void
SetInput
s
ReshapeType
(
const
std
::
vector
<
std
::
vector
<
Axis
>>
&
input_reshape_type
);
void
SetOutputReshapeType
(
const
std
::
vector
<
std
::
vector
<
Axis
>>
&
output_reshape_type
);
void
SetOutput
s
ReshapeType
(
const
std
::
vector
<
std
::
vector
<
Axis
>>
&
output_reshape_type
);
void
SetFusionType
(
FusionType
fusion_type
);
...
...
@@ -137,6 +156,14 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void
SetOutputFormat
(
const
std
::
string
&
format
,
size_t
index
);
void
SetInputReshapeType
(
const
std
::
vector
<
Axis
>
&
input_reshape_type
,
size_t
index
);
void
SetOutputReshapeType
(
const
std
::
vector
<
Axis
>
&
output_reshape_type
,
size_t
index
);
void
SetInputDeviceType
(
const
TypeId
&
input_device_type
,
size_t
index
);
void
SetOutputDeviceType
(
const
TypeId
&
output_device_type
,
size_t
index
);
std
::
shared_ptr
<
KernelBuildInfo
>
Build
();
private:
...
...
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
浏览文件 @
7cb567eb
...
...
@@ -118,7 +118,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
}
builder
.
SetInputsDeviceType
(
inputs_device_type
);
builder
.
SetInputsFormat
(
inputs_format
);
builder
.
SetInputReshapeType
(
inputs_reshape_type
);
builder
.
SetInput
s
ReshapeType
(
inputs_reshape_type
);
// output
std
::
vector
<
std
::
string
>
outputs_format
;
std
::
vector
<
TypeId
>
outputs_device_type
;
...
...
@@ -129,7 +129,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
}
builder
.
SetOutputsDeviceType
(
outputs_device_type
);
builder
.
SetOutputsFormat
(
outputs_format
);
builder
.
SetOutputReshapeType
(
outputs_reshape_type
);
builder
.
SetOutput
s
ReshapeType
(
outputs_reshape_type
);
kernel_info_list_
->
emplace_back
(
builder
.
Build
());
}
MS_LOG
(
INFO
)
<<
"end."
;
...
...
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
浏览文件 @
7cb567eb
...
...
@@ -59,6 +59,7 @@
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/pass/optimize_dependence.h"
#include "backend/optimizer/pass/erase_visit_attr.h"
...
...
@@ -228,6 +229,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
MergeCastToOp
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
LayerNormBetaGammaBackpropFusion
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
EraseVisitAttr
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
SplitUnsupportedTransData
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
ConvertUnSupportNodeToAICPU
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
RemoveInternalOutputCast
>
());
optimizer
->
AddPassManager
(
mixed_precision_pm
);
...
...
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
浏览文件 @
7cb567eb
...
...
@@ -174,8 +174,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
MS_EXCEPTION_IF_NULL
(
ori_build_info
);
auto
builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
(
ori_build_info
);
builder
->
SetInputsFormat
({
input_format
});
builder
->
SetInputReshapeType
({
reshape_type
});
builder
->
SetOutputReshapeType
({
reshape_type
});
builder
->
SetInput
s
ReshapeType
({
reshape_type
});
builder
->
SetOutput
s
ReshapeType
({
reshape_type
});
builder
->
SetOutputsFormat
({
output_format
});
if
(
type_id
!=
kTypeUnknown
)
{
builder
->
SetOutputsDeviceType
({
type_id
});
...
...
mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc
0 → 100644
浏览文件 @
7cb567eb
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include <vector>
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
namespace
mindspore
{
namespace
opt
{
const
BaseRef
SplitUnsupportedTransData
::
DefinePattern
()
const
{
VarPtr
X
=
std
::
make_shared
<
Var
>
();
return
VectorRef
({
prim
::
KPrimTransData
,
X
});
}
const
AnfNodePtr
SplitUnsupportedTransData
::
Process
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
EquivPtr
&
)
const
{
if
(
node
==
nullptr
||
!
node
->
isa
<
CNode
>
()
||
!
AnfAlgo
::
IsRealKernel
(
node
))
{
return
nullptr
;
}
auto
ori_trans_data
=
node
->
cast
<
CNodePtr
>
();
if
(
AnfAlgo
::
GetCNodeName
(
ori_trans_data
)
!=
prim
::
KPrimTransData
->
name
())
{
return
nullptr
;
}
auto
kernel_info
=
AnfAlgo
::
GetSelectKernelBuildInfo
(
ori_trans_data
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
if
(
kernel_info
->
GetInputNum
()
!=
1
||
kernel_info
->
GetOutputNum
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"Transdata node's kernel info's input and output format size is not 1"
<<
ori_trans_data
->
DebugString
();
}
return
SplitTransData
(
func_graph
,
ori_trans_data
);
}
AnfNodePtr
SplitUnsupportedTransData
::
SplitTransData
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
trans_node
)
const
{
auto
kernel_info
=
AnfAlgo
::
GetSelectKernelBuildInfo
(
trans_node
);
if
(
kHWSpecialFormatSet
.
find
(
kernel_info
->
GetInputFormat
(
0
))
==
kHWSpecialFormatSet
.
end
()
||
kHWSpecialFormatSet
.
find
(
kernel_info
->
GetOutputFormat
(
0
))
==
kHWSpecialFormatSet
.
end
())
{
return
trans_node
;
}
auto
builder_info_to_default
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
(
kernel_info
);
auto
builder_info_to_special_foramt
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
(
kernel_info
);
builder_info_to_default
->
SetOutputsFormat
({
kOpFormat_DEFAULT
});
builder_info_to_special_foramt
->
SetInputsFormat
({
kOpFormat_DEFAULT
});
std
::
vector
<
AnfNodePtr
>
next_trans_node_inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
KPrimTransData
->
name
())),
trans_node
};
auto
next_trans_node
=
func_graph
->
NewCNode
(
next_trans_node_inputs
);
next_trans_node
->
set_abstract
(
trans_node
->
abstract
());
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder_info_to_default
->
Build
(),
trans_node
.
get
());
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder_info_to_special_foramt
->
Build
(),
next_trans_node
.
get
());
return
next_trans_node
;
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h
0 → 100644
浏览文件 @
7cb567eb
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
#include "backend/optimizer/common/optimizer.h"
namespace
mindspore
{
namespace
opt
{
class
SplitUnsupportedTransData
:
public
PatternProcessPass
{
public:
explicit
SplitUnsupportedTransData
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"split_unsupported_transdata"
,
multigraph
)
{}
~
SplitUnsupportedTransData
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
private:
AnfNodePtr
SplitTransData
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
trans_node
)
const
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc
浏览文件 @
7cb567eb
...
...
@@ -50,6 +50,8 @@ class TestHWInsertTransOp : public BackendCommon {
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsFormat
({
format
,
format
});
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
(),
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({{},{}});
builder
.
SetOutputsReshapeType
({});
builder
.
SetOutputsFormat
({
format
});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
add
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
...
...
@@ -70,6 +72,8 @@ class TestHWInsertTransOp : public BackendCommon {
EXPECT_NE
(
ret
->
input
(
1
)
->
cast
<
CNodePtr
>
()
->
input
(
1
)
->
cast
<
CNodePtr
>
()
->
input
(
1
),
nullptr
);
auto
max_pool
=
ret
->
input
(
1
)
->
cast
<
CNodePtr
>
()
->
input
(
1
)
->
cast
<
CNodePtr
>
()
->
input
(
1
);
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsReshapeType
({{}});
builder
.
SetOutputsReshapeType
({{},{}});
builder
.
SetInputsFormat
({
kOpFormat_DEFAULT
});
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
format
,
format
});
...
...
@@ -88,6 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
~
MockInsertTransOpKernelSelectTrans4Dto5D
()
override
=
default
;
void
SelectKernel
(
const
CNodePtr
&
cnode
)
override
{
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsReshapeType
({{}});
builder
.
SetOutputsReshapeType
({{}});
builder
.
SetInputsFormat
({
"NCHW"
});
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
...
...
tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
浏览文件 @
7cb567eb
...
...
@@ -53,6 +53,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsFormat
({
kOpFormat_DEFAULT
,
kOpFormat_DEFAULT
});
builder
.
SetInputsDeviceType
({
kFloat32
->
type_id
(),
kFloat32
->
type_id
()});
builder
.
SetInputsReshapeType
({{},
{}});
builder
.
SetOutputsReshapeType
({{}});
builder
.
SetOutputsFormat
({
kOpFormat_NC1HWC0
});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
add
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
...
...
@@ -78,6 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
kg
->
AddInternalOutput
(
tuple_getitem1
,
max_pool
);
kg
->
AddInternalOutput
(
tuple_getitem2
,
max_pool
);
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsReshapeType
({{}});
builder
.
SetOutputsReshapeType
({{},
{}});
builder
.
SetInputsFormat
({
kOpFormat_DEFAULT
});
builder
.
SetInputsDeviceType
({
kFloat32
->
type_id
()});
builder
.
SetOutputsFormat
({
kOpFormat_NC1HWC0
,
kOpFormat_NC1HWC0
});
...
...
@@ -99,6 +103,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
kOpFormat_DEFAULT
});
builder
.
SetOutputsDeviceType
({
kFloat32
->
type_id
()});
builder
.
SetInputsReshapeType
({{}});
builder
.
SetOutputsReshapeType
({{}});
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
cnode
.
get
());
}
};
...
...
tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc
浏览文件 @
7cb567eb
...
...
@@ -51,6 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({});
builder
.
SetOutputsReshapeType
({});
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
cnode
.
get
());
}
else
{
KernelBuildInfoBuilder
builder
;
...
...
@@ -58,7 +60,10 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({});
builder
.
SetOutputsReshapeType
({});
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
cnode
.
get
());
}
}
};
...
...
@@ -74,6 +79,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
"NCHW"
});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({{}});
builder
.
SetOutputsReshapeType
({{}});
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
cnode
.
get
());
}
else
{
KernelBuildInfoBuilder
builder
;
...
...
@@ -81,6 +88,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
"NCHW"
});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({{}});
builder
.
SetOutputsReshapeType
({{}});
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
cnode
.
get
());
}
}
...
...
@@ -116,6 +125,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
builder
.
SetKernelType
(
KernelType
::
TBE_KERNEL
);
builder
.
SetFusionType
(
kernel
::
FusionType
::
ELEMWISE
);
builder
.
SetProcessor
(
kernel
::
Processor
::
AICORE
);
builder
.
SetInputsReshapeType
({{}});
builder
.
SetOutputsReshapeType
({{}});
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
kernel_info
->
set_select_kernel_build_info
(
builder
.
Build
());
transpose
->
set_kernel_info
(
kernel_info
);
...
...
@@ -162,6 +173,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
builder
.
SetKernelType
(
KernelType
::
TBE_KERNEL
);
builder
.
SetFusionType
(
kernel
::
FusionType
::
ELEMWISE
);
builder
.
SetProcessor
(
kernel
::
Processor
::
AICORE
);
builder
.
SetInputsReshapeType
({{}});
builder
.
SetOutputsReshapeType
({{}});
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
kernel_info
->
set_select_kernel_build_info
(
builder
.
Build
());
transpose
->
set_kernel_info
(
kernel_info
);
...
...
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
浏览文件 @
7cb567eb
...
...
@@ -58,6 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({});
builder
.
SetOutputsReshapeType
({});
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
cnode
.
get
());
}
else
{
KernelBuildInfoBuilder
builder
;
...
...
@@ -65,6 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({});
builder
.
SetOutputsReshapeType
({});
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
cnode
.
get
());
}
}
...
...
@@ -93,6 +97,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
EXPECT_NE
(
transpose
,
nullptr
);
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsReshapeType
({});
builder
.
SetOutputsReshapeType
({});
builder
.
SetInputsFormat
({
"NCHW"
});
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
...
...
tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
浏览文件 @
7cb567eb
...
...
@@ -56,6 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect {
~
MockEliminate5To4And4To5KernelSelect
()
override
=
default
;
void
SelectKernel
(
const
CNodePtr
&
cnode
)
override
{
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsReshapeType
({{}});
builder
.
SetOutputsReshapeType
({{}});
builder
.
SetInputsFormat
({
"NCHW"
});
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
...
...
@@ -102,7 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
(),
kFloat16
->
type_id
()});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({{},
{}});
builder
.
SetOutputsReshapeType
({{}});
sub
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
add
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
sub
.
get
());
...
...
@@ -168,7 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) {
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
(),
kFloat16
->
type_id
()});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({{},
{}});
builder
.
SetOutputsReshapeType
({{},
{}});
sub
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
add
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
sub
.
get
());
...
...
@@ -244,7 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) {
builder
.
SetOutputsFormat
({
"NC1HWC0"
});
builder
.
SetInputsDeviceType
({
kFloat16
->
type_id
(),
kFloat16
->
type_id
()});
builder
.
SetOutputsDeviceType
({
kFloat16
->
type_id
()});
builder
.
SetInputsReshapeType
({{},
{}});
builder
.
SetOutputsReshapeType
({{}});
sub
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
add
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
sub
.
get
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录