PaddlePaddle / Paddle
Commit 44044d80 (unverified)
Authored May 25, 2023 by zhoutianzi666; committed via GitHub on May 25, 2023.
Parent: f2ed4011

[Paddle Inference] Move down the transfer_layout (#52997)

* add transfer_elim
* transfer layout elimination
Changes: 6 files changed, 717 additions (+) and 2 deletions (−)

paddle/fluid/framework/ir/CMakeLists.txt                  +1    −0
paddle/fluid/framework/ir/transfer_layout_elim_pass.cc    +346  −0
paddle/fluid/framework/ir/transfer_layout_elim_pass.h     +42   −0
paddle/fluid/inference/api/paddle_pass_builder.cc         +3    −2
test/ir/inference/CMakeLists.txt                          +1    −0
test/ir/inference/test_transfer_layout_elim_pass.py       +324  −0
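
The pass rests on a simple algebraic fact: for layout-agnostic ops such as elementwise_add, a layout transfer commutes with the op, so one transfer_layout per input can be replaced by a single transfer_layout on the output (which may later cancel against an inverse transfer downstream). A minimal NumPy sketch of that identity — illustrative only, not part of the commit, modeling transfer_layout as a transpose:

import numpy as np

a = np.random.rand(2, 3, 4, 5)  # NCHW
b = np.random.rand(2, 3, 4, 5)  # NCHW

def to_nhwc(t):
    # Models transfer_layout with src NCHW, dst NHWC.
    return t.transpose(0, 2, 3, 1)

# Original graph: one transfer_layout per input, then the op.
before = to_nhwc(a) + to_nhwc(b)
# Rewritten graph: run the op in the original layout, then a single
# transfer_layout on the output.
after = to_nhwc(a + b)

assert np.allclose(before, after)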
paddle/fluid/framework/ir/CMakeLists.txt

@@ -107,6 +107,7 @@ pass_library(preln_residual_bias_fuse_pass inference)
 pass_library(constant_folding_pass inference)
 pass_library(auto_mixed_precision_pass inference)
 pass_library(conv2d_fusion_layout_transfer_pass inference)
+pass_library(transfer_layout_elim_pass inference)
 pass_library(silu_fuse_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
paddle/fluid/framework/ir/transfer_layout_elim_pass.cc (new file, mode 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/transfer_layout_elim_pass.h"
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace ir {

// (D) means deleted nodes
// (G) means generated node
//
// var0                 var0'                      var0  var0'
//   |                    |                          |     |
// transfer_layout0(D)  transfer_layout0'(D)         |     |
//   |                    |                          |     |
// var1(D)              var1'(D)             ->      |     |
//     \                 /                            \   /
//          op_node                                  op_node
//             |                                        |
//             |                                      var2
//             |                                        |
//             |                              transfer_layout(G)
//             |                                        |
//           var2                          var2'(var2 + suffix)(G)
//             |                                        |
//         other ops                                other ops
//
// Put transfer_layout after op_node.
// transfer_info is for the case when we need to know the direction of this
// transfer_layout: nchw_nhwc or nhwc_nchw.
void TransferLayoutElimPass::PutTranferlayoutAfterOp(
    Node *op_node, ir::Graph *graph, std::string *transfer_info) const {
  std::unordered_set<const Node *> remove_nodes;
  // Ensure op_node has only one output!
  int op_node_useful_output = 0;
  Node *var2;
  for (auto ele : op_node->outputs) {
    if (ele->outputs.size() >= 1) {
      op_node_useful_output++;
      var2 = ele;
    }
  }
  CHECK_EQ(op_node_useful_output == 1, true);

  // group_norm has 3 inputs, but we do not require a transfer_layout before
  // its Bias and Scale inputs, so we extract useful_var1s from
  // op_node->inputs.
  std::vector<Node *> useful_var1s;
  for (auto var1 : op_node->inputs) {
    // if (var1->inputs.size() >= 1 &&
    //     var1->inputs[0]->Op()->Type() == "transfer_layout") {
    //   useful_var1s.push_back(var1);
    // }
    useful_var1s.push_back(var1);
  }
  CHECK_EQ(useful_var1s.size() >= 1L, true);

  auto transfer_layout_opdesc = *useful_var1s[0]->inputs[0]->Op()->Proto();
  auto block = useful_var1s[0]->inputs[0]->Op()->Block();

  framework::OpDesc new_transfer_layout_desc(transfer_layout_opdesc, block);
  new_transfer_layout_desc.SetInput("X", {var2->Name()});

  // Do not use the line below; it may cause SetShape to fail in the netron
  // display.
  // auto *var2_desc = block->Var(var2->Name());
  auto *var2_desc = var2->Var();
  auto var2_shape = var2_desc->GetShape();
  CHECK_EQ(var2_shape.size() >= 4L, true);
  auto new_var2_shape = var2_shape;

  std::string suffix = "_nchw_to_nhwc";
  auto dst_layout = static_cast<DataLayout>(
      new_transfer_layout_desc.GetAttrIfExists<int>("dst_layout"));
  auto src_layout = static_cast<DataLayout>(
      new_transfer_layout_desc.GetAttrIfExists<int>("src_layout"));
  if (dst_layout == DataLayout::NCHW && src_layout == DataLayout::NHWC) {
    suffix = "_nhwc_to_nchw";
    if (transfer_info) *transfer_info = "nhwc_nchw";
    new_var2_shape[1] = var2_shape[2];
    new_var2_shape[2] = var2_shape[3];
    new_var2_shape[3] = var2_shape[1];
  } else if (dst_layout == DataLayout::NHWC && src_layout == DataLayout::NCHW) {
    suffix = "_nchw_to_nhwc";
    if (transfer_info) *transfer_info = "nchw_nhwc";
    new_var2_shape[1] = var2_shape[3];
    new_var2_shape[2] = var2_shape[1];
    new_var2_shape[3] = var2_shape[2];
  }
  var2_desc->SetShape(new_var2_shape);

  std::string var2_dot_name = var2->Name() + suffix;
  new_transfer_layout_desc.SetOutput("Out", {var2_dot_name});
  new_transfer_layout_desc.Flush();

  auto *var2_dot_desc = block->Var(var2_dot_name);
  var2_dot_desc->SetPersistable(false);
  // Set var2_dot_desc's shape to var2_shape.
  var2_dot_desc->SetShape(var2_shape);
  var2_dot_desc->SetDataType(var2->Var()->GetDataType());

  auto var2_dot = graph->CreateVarNode(var2_dot_desc);
  auto *new_transfer_layout_node =
      graph->CreateOpNode(&new_transfer_layout_desc);

  for (auto other_op : var2->outputs) {
    IR_NODE_UNLINK(var2, other_op);
    other_op->Op()->RenameInput(var2->Name(), var2_dot_name);
    IR_NODE_LINK_TO(var2_dot, other_op);
  }
  IR_NODE_LINK_TO(var2, new_transfer_layout_node);
  IR_NODE_LINK_TO(new_transfer_layout_node, var2_dot);

  for (auto var1 : useful_var1s) {
    auto transfer_layout0_op = var1->inputs[0];
    auto var0 = transfer_layout0_op->inputs[0];
    IR_NODE_UNLINK(var0, transfer_layout0_op);
    // IR_NODE_UNLINK(var1, op_node);
    IR_NODE_LINK_TO(var0, op_node);
    op_node->Op()->RenameInput(var1->Name(), var0->Name());
    remove_nodes.emplace(transfer_layout0_op);
    remove_nodes.emplace(var1);
  }

  GraphSafeRemoveNodes(graph, remove_nodes);
}

bool TransferLayoutElimPass::AllInputIsTransferlayout(
    const ir::Node *op_node) const {
  std::set<int> dst_layouts;
  std::set<int> src_layouts;

  auto *scope = param_scope();

  for (auto var : op_node->inputs) {
    // If this input is a 1D persistable tensor, we allow transfer_layout to
    // be absent before this var, but this branch is temporarily disabled.
    if (var->Var()->Persistable() && 0) {
      auto var_dims =
          scope->FindVar(var->Name())->GetMutable<phi::DenseTensor>()->dims();
      if (var_dims.size() == 1) {
        continue;
      }
    }

    if (var->inputs.size() != 1L) {
      return false;
    }
    if (var->outputs.size() != 1L) {
      return false;
    }
    if (var->inputs[0]->Name() != "transfer_layout") {
      return false;
    }
    auto transfer_layout_desc = var->inputs[0]->Op();
    dst_layouts.insert(
        transfer_layout_desc->GetAttrIfExists<int>("dst_layout"));
    src_layouts.insert(
        transfer_layout_desc->GetAttrIfExists<int>("src_layout"));
  }

  // Make sure the dst_layout and src_layout attributes are the same across
  // all inputs, so that these transfer_layouts can be moved down.
  return dst_layouts.size() == 1 && src_layouts.size() == 1;
}

// (D) means deleted nodes
// (G) means generated node
//
//  var0
//    |
//  transfer_layout0(D)
//    |
//  var1
//    |
//  transfer_layout1(D, op_node)
//    |
//  var2
//  |  |  |
// op0 op1 op2
void TransferLayoutElimPass::ElimTwoTranferlayout(Node *op_node,
                                                  ir::Graph *graph,
                                                  bool *modify) const {
  std::unordered_set<const Node *> remove_nodes;
  auto var1 = op_node->inputs[0];
  auto transfer_layout0 = var1->inputs[0];
  auto var0 = transfer_layout0->inputs[0];
  auto var2 = op_node->outputs[0];
  CHECK_EQ(transfer_layout0->Name() == "transfer_layout", true);
  CHECK_EQ(op_node->Name() == "transfer_layout", true);
  int dst0 = transfer_layout0->Op()->GetAttrIfExists<int>("dst_layout");
  int src0 = transfer_layout0->Op()->GetAttrIfExists<int>("src_layout");
  int dst1 = op_node->Op()->GetAttrIfExists<int>("dst_layout");
  int src1 = op_node->Op()->GetAttrIfExists<int>("src_layout");

  if (!(dst0 == src1 && dst1 == src0)) {
    // We can not eliminate these two transfer_layouts.
    *modify = false;
    return;
  }
  *modify = true;
  remove_nodes.emplace(transfer_layout0);
  remove_nodes.emplace(var1);
  remove_nodes.emplace(op_node);
  remove_nodes.emplace(var2);
  for (auto next_op : var2->outputs) {
    IR_NODE_LINK_TO(var0, next_op);
    next_op->Op()->RenameInput(var2->Name(), var0->Name());
  }
  GraphSafeRemoveNodes(graph, remove_nodes);
}

void TransferLayoutElimPass::ApplyImpl(ir::Graph *graph) const {
  const std::string pattern_name = "transfer_layout_elim_pass";
  FusePassBase::Init(pattern_name, graph);

  auto transfer_format = [&](std::string data_format) -> std::string {
    if (data_format == "NCHW") {
      return "NHWC";
    } else if (data_format == "NHWC") {
      return "NCHW";
    }
    return "";
  };

  while (true) {
    auto op_node_sorted = framework::ir::TopologyVarientSort(
        *graph, static_cast<framework::ir::SortKind>(0));
    bool modify = false;
    for (auto *op_node : op_node_sorted) {
      if (!op_node->IsOp()) continue;

      // For these Ops, you can move down the transfer_layout without changing
      // any attribute!
      std::vector<std::string> act_like_ops = {
          "elementwise_add",
          "hard_swish",
          "silu",
      };
      bool is_act_like_op =
          find(act_like_ops.begin(), act_like_ops.end(), op_node->Name()) !=
          act_like_ops.end();

      // For these Ops, you can move down the transfer_layout, but MUST change
      // the data_format attribute!
      std::vector<std::string> pool_like_ops = {
          // "pool2d",
          // "group_norm",
      };
      bool is_pool_like_op =
          find(pool_like_ops.begin(), pool_like_ops.end(), op_node->Name()) !=
          pool_like_ops.end();

      // For these Ops, you can move down the transfer_layout, but MUST change
      // the axis attribute!
      std::vector<std::string> concat_like_ops = {
          "concat",
      };
      bool is_concat_like_op = find(concat_like_ops.begin(),
                                    concat_like_ops.end(),
                                    op_node->Name()) != concat_like_ops.end();
      bool is_elim_op = op_node->Name() == "transfer_layout";

      if (!(is_act_like_op || is_concat_like_op || is_pool_like_op ||
            is_elim_op))
        continue;

      if (AllInputIsTransferlayout(op_node)) {
        if (is_concat_like_op) {
          std::string transfer_info;
          PutTranferlayoutAfterOp(op_node, graph, &transfer_info);
          int axis = op_node->Op()->GetAttrIfExists<int>("axis");
          int modify_axis = axis;
          if (transfer_info == "nhwc_nchw") {
            // The op now runs in NHWC, so remap an NCHW axis to NHWC.
            if (axis == 1) {
              modify_axis = 3;
            } else if (axis == 2) {
              modify_axis = 1;
            } else if (axis == 3) {
              modify_axis = 2;
            }
          } else if (transfer_info == "nchw_nhwc") {
            // The op now runs in NCHW, so remap an NHWC axis to NCHW.
            if (axis == 1) {
              modify_axis = 2;
            } else if (axis == 2) {
              modify_axis = 3;
            } else if (axis == 3) {
              modify_axis = 1;
            }
          }
          op_node->Op()->SetAttr("axis", modify_axis);
          modify = true;
          break;
        }
        if (is_pool_like_op) {
          PutTranferlayoutAfterOp(op_node, graph, nullptr);
          op_node->Op()->SetAttr(
              "data_format",
              transfer_format(
                  op_node->Op()->GetAttrIfExists<std::string>("data_format")));
          modify = true;
          break;
        }
        if (is_act_like_op) {
          PutTranferlayoutAfterOp(op_node, graph, nullptr);
          modify = true;
          break;
        }
        if (is_elim_op) {
          ElimTwoTranferlayout(op_node, graph, &modify);
          break;
        }
      }
    }
    if (!modify) break;
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(transfer_layout_elim_pass,
              paddle::framework::ir::TransferLayoutElimPass);
// Added below so that test_transfer_layout_elim_pass passes.
REGISTER_PASS_CAPABILITY(transfer_layout_elim_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination());
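
The concat handling in ApplyImpl above remaps the axis attribute because the op now runs in the pre-transfer layout. A NumPy sketch of the remap for the nchw_nhwc case — illustrative only, again modeling transfer_layout as a transpose:

import numpy as np

# Two NCHW inputs; the channel sizes differ so the concat axis matters.
a = np.random.rand(2, 3, 4, 5)
b = np.random.rand(2, 6, 4, 5)

def to_nhwc(t):
    return t.transpose(0, 2, 3, 1)  # NCHW -> NHWC

# Original graph: transfer both inputs to NHWC, then concat on the
# channel axis, which is axis 3 in NHWC.
before = np.concatenate([to_nhwc(a), to_nhwc(b)], axis=3)

# After the pass: concat in NCHW on the remapped axis (3 -> 1 in the
# nchw_nhwc table above), then a single transfer_layout on the output.
after = to_nhwc(np.concatenate([a, b], axis=1))

assert np.allclose(before, after)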
paddle/fluid/framework/ir/transfer_layout_elim_pass.h (new file, mode 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class Graph;

class TransferLayoutElimPass : public FusePassBase {
 public:
  virtual ~TransferLayoutElimPass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;
  bool AllInputIsTransferlayout(const Node *op_node) const;
  void PutTranferlayoutAfterOp(Node *op_node,
                               ir::Graph *graph,
                               std::string *transfer_info) const;
  void ElimTwoTranferlayout(Node *op_node,
                            ir::Graph *graph,
                            bool *modify) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -264,8 +264,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
 #endif  //
         "transpose_flatten_concat_fuse_pass",  //
         "conv2d_fusion_layout_transfer_pass",  //
-        "auto_mixed_precision_pass",           //
-        "inplace_op_var_pass",  // should be the last pass.
+        "transfer_layout_elim_pass",
+        "auto_mixed_precision_pass",  //
+        "inplace_op_var_pass",        // should be the last pass.
       });
   use_gpu_ = true;
test/ir/inference/CMakeLists.txt

@@ -216,6 +216,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
   set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240)
   set_tests_properties(test_reverse_roll_fuse_pass PROPERTIES TIMEOUT 120)
   set_tests_properties(test_inplace_op_pass PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_transfer_layout_elim_pass PROPERTIES TIMEOUT 300)
   set_tests_properties(test_simplify_with_basic_ops_pass_autoscan
                        PROPERTIES TIMEOUT 60)
test/ir/inference/test_transfer_layout_elim_pass.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from functools import partial

import hypothesis.strategies as st
import numpy as np
from auto_scan_test import CutlassAutoScanTest, PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig

os.environ['NVIDIA_TF32_OVERRIDE'] = '0'


class TestTransferElimPass0(PassAutoScanTest):
    r"""input0                 input1
          |                      |
      transfer_layout       transfer_layout
          |                      |
      transfer_layout_out0  transfer_layout_out1
                 \              /
                 elementwise_add
                        |
                elementwise_add_out
    """

    def sample_predictor_configs(self, program_config):
        # for gpu
        config = self.create_inference_config(use_gpu=True)
        yield config, ["elementwise_add", "transfer_layout"], (1e-4, 1e-5)

    def is_program_valid(self, prog_config):
        return True

    def sample_program_config(self, draw):
        transfer_layout0 = OpConfig(
            "transfer_layout",
            inputs={"X": ["input0"]},
            outputs={"Out": ["transfer_layout_out0"]},
            dst_layout=1,
            src_layout=2,
        )
        transfer_layout1 = OpConfig(
            "transfer_layout",
            inputs={"X": ["input1"]},
            outputs={"Out": ["transfer_layout_out1"]},
            dst_layout=1,
            src_layout=2,
        )
        add_op = OpConfig(
            "elementwise_add",
            inputs={
                "X": ["transfer_layout_out0"],
                "Y": ["transfer_layout_out1"],
            },
            outputs={"Out": ["elementwise_add_out"]},
            axis=-1,
        )
        ops = [transfer_layout0, transfer_layout1, add_op]
        x_shape = draw(
            st.lists(
                st.integers(min_value=10, max_value=100),
                min_size=4,
                max_size=4,
            )
        )
        program_config = ProgramConfig(
            ops=ops,
            weights={},
            inputs={
                "input0": TensorConfig(shape=x_shape),
                "input1": TensorConfig(shape=x_shape),
            },
            outputs=["elementwise_add_out"],
        )
        return program_config

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=30,
            passes=["transfer_layout_elim_pass"],
        )


class TestTransferElimPass1(PassAutoScanTest):
    r"""input0                 input1
          |                      |
      transfer_layout       transfer_layout
          |                      |
      transfer_layout_out0  transfer_layout_out1
                 \              /
                 elementwise_add
                        |
                elementwise_add_out
                        |
                 transfer_layout
                        |
               transfer_layout_out2
    """

    def sample_predictor_configs(self, program_config):
        # for gpu
        config = self.create_inference_config(use_gpu=True)
        yield config, ["elementwise_add"], (1e-4, 1e-5)

    def is_program_valid(self, prog_config):
        return True

    def sample_program_config(self, draw):
        transfer_layout0 = OpConfig(
            "transfer_layout",
            inputs={"X": ["input0"]},
            outputs={"Out": ["transfer_layout_out0"]},
            dst_layout=1,
            src_layout=2,
        )
        transfer_layout1 = OpConfig(
            "transfer_layout",
            inputs={"X": ["input1"]},
            outputs={"Out": ["transfer_layout_out1"]},
            dst_layout=1,
            src_layout=2,
        )
        add_op = OpConfig(
            "elementwise_add",
            inputs={
                "X": ["transfer_layout_out0"],
                "Y": ["transfer_layout_out1"],
            },
            outputs={"Out": ["elementwise_add_out"]},
            axis=-1,
        )
        transfer_layout2 = OpConfig(
            "transfer_layout",
            inputs={"X": ["elementwise_add_out"]},
            outputs={"Out": ["transfer_layout_out2"]},
            dst_layout=2,
            src_layout=1,
        )
        ops = [transfer_layout0, transfer_layout1, add_op, transfer_layout2]
        x_shape = draw(
            st.lists(
                st.integers(min_value=10, max_value=100),
                min_size=4,
                max_size=4,
            )
        )
        program_config = ProgramConfig(
            ops=ops,
            weights={},
            inputs={
                "input0": TensorConfig(shape=x_shape),
                "input1": TensorConfig(shape=x_shape),
            },
            outputs=["transfer_layout_out2"],
        )
        return program_config

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=30,
            passes=["transfer_layout_elim_pass"],
        )


class TestTransferElimPass2(PassAutoScanTest):
    r"""input0                 input1
          |                      |
      transfer_layout       transfer_layout
          |                      |
      transfer_layout_out0  transfer_layout_out1
                 \              /
                     concat
                        |
                    concat_out
    """

    def sample_predictor_configs(self, program_config):
        # for gpu
        config = self.create_inference_config(use_gpu=True)
        yield config, ["concat", "transfer_layout"], (1e-4, 1e-5)

    def is_program_valid(self, prog_config):
        return True

    def sample_program_config(self, draw):
        # nchw -> nhwc
        transfer_layout0 = OpConfig(
            "transfer_layout",
            inputs={"X": ["input0"]},
            outputs={"Out": ["transfer_layout_out0"]},
            dst_layout=1,
            src_layout=2,
        )
        transfer_layout1 = OpConfig(
            "transfer_layout",
            inputs={"X": ["input1"]},
            outputs={"Out": ["transfer_layout_out1"]},
            dst_layout=1,
            src_layout=2,
        )
        concat_op = OpConfig(
            "concat",
            inputs={"X": ["transfer_layout_out0", "transfer_layout_out1"]},
            outputs={"Out": ["concat_out"]},
            axis=1,
        )
        ops = [transfer_layout0, transfer_layout1, concat_op]
        x_shape = draw(
            st.lists(
                st.integers(min_value=10, max_value=100),
                min_size=4,
                max_size=4,
            )
        )
        program_config = ProgramConfig(
            ops=ops,
            weights={},
            inputs={
                "input0": TensorConfig(shape=x_shape),
                "input1": TensorConfig(shape=x_shape),
            },
            outputs=["concat_out"],
        )
        return program_config

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=30,
            passes=["transfer_layout_elim_pass"],
        )


class TestTransferElimPass3(CutlassAutoScanTest):
    def sample_program_configs(self, *args, **kwargs):
        def generate_input(input_shape):
            return (np.random.random(input_shape) - 0.5).astype(np.float32)

        # src_layout should be NCHW, because it is the model's input
        for dst_layout, src_layout in [[1, 2]]:
            for axis in [0, 1, 2, 3]:
                ops_config = [
                    {
                        "op_type": "transfer_layout",
                        "op_inputs": {"X": ["input0"]},
                        "op_outputs": {"Out": ["transfer_layout_out0"]},
                        "op_attrs": {
                            "dst_layout": dst_layout,
                            "src_layout": src_layout,
                        },
                    },
                    {
                        "op_type": "transfer_layout",
                        "op_inputs": {"X": ["input1"]},
                        "op_outputs": {"Out": ["transfer_layout_out1"]},
                        "op_attrs": {
                            "dst_layout": dst_layout,
                            "src_layout": src_layout,
                        },
                        # nchw -> nhwc
                    },
                    {
                        "op_type": "concat",
                        "op_inputs": {
                            "X": [
                                "transfer_layout_out0",
                                "transfer_layout_out1",
                            ]
                        },
                        "op_outputs": {"Out": ["concat_out0"]},
                        "op_attrs": {"axis": axis},
                    },
                ]
                ops = self.generate_op_config(ops_config)
                input_shape = [12, 13, 14, 15]
                program_config = ProgramConfig(
                    ops=ops,
                    weights={},
                    inputs={
                        "input0": TensorConfig(
                            data_gen=partial(generate_input, input_shape)
                        ),
                        "input1": TensorConfig(
                            data_gen=partial(generate_input, input_shape)
                        ),
                    },
                    outputs=["concat_out0"],
                )
                yield program_config

    def sample_predictor_configs(self, program_config):
        config = self.create_inference_config(use_gpu=True)
        config.enable_use_gpu(256, 0)
        yield config, (1e-2, 1e-2)

    def test(self, *args, **kwargs):
        self.run_test(quant=False, *args, **kwargs)


if __name__ == "__main__":
    unittest.main()
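
TestTransferElimPass1 above expects only ["elementwise_add"] to survive: once the input-side transfers are moved below the add, the generated output-side transfer is the exact inverse of transfer_layout2, and ElimTwoTranferlayout removes the pair. A NumPy sketch of why inverse transfers cancel — illustrative only:

import numpy as np

x = np.random.rand(2, 3, 4, 5)  # NCHW

def nchw_to_nhwc(t):
    return t.transpose(0, 2, 3, 1)

def nhwc_to_nchw(t):
    return t.transpose(0, 3, 1, 2)

# Two back-to-back transfers in opposite directions reduce to the identity,
# which is the condition ElimTwoTranferlayout checks via
# dst0 == src1 && dst1 == src0.
assert np.allclose(nhwc_to_nchw(nchw_to_nhwc(x)), x)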