Commit e43f7102 (unverified)
Authored by Yuanle Liu on Jan 16, 2023; committed via GitHub on Jan 16, 2023
[Paddle-TRT] support nhwc (#49633)
* add trt_support_nhwc_pass
Parent: 7de9420a

Changes: 10 changed files with 588 additions and 75 deletions (+588, -75)
paddle/fluid/framework/data_layout_transform.cc  (+18, -6)
paddle/fluid/framework/data_layout_transform.h  (+10, -11)
paddle/fluid/framework/data_layout_transform_test.cc  (+1, -0)
paddle/fluid/framework/ir/CMakeLists.txt  (+1, -3)
paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc  (+21, -51)
paddle/fluid/framework/ir/trt_support_nhwc_pass.cc  (+365, -0)
paddle/fluid/framework/ir/trt_support_nhwc_pass.h  (+35, -0)
paddle/fluid/inference/api/paddle_pass_builder.cc  (+2, -1)
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py  (+3, -3)
python/paddle/fluid/tests/unittests/ir/inference/test_trt_support_nhwc_pass.py  (+132, -0)
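The core of the change is the new trt_support_nhwc_pass, which rewrites an NHWC model into NCHW form before the TensorRT subgraph passes run: it flips data_format/data_layout attributes to NCHW, permutes variable shapes and 4-D weights where needed, and inserts transpose ops at the layout boundaries. The pass only activates when the graph actually contains an op whose data_format/data_layout attribute is "NHWC"; below is a small sketch of that check using the paddle static API (illustrative only, mirroring the NHWC model built in the new unit test at the end of this diff):

import paddle
import paddle.nn as nn
import paddle.static as static

paddle.enable_static()

image = static.data(name='img', shape=[None, 32, 32, 4], dtype='float32')
out = nn.Conv2D(in_channels=4, out_channels=4, kernel_size=3, data_format='NHWC')(image)

# trt_support_nhwc_pass only rewrites the graph when some op carries a
# "data_format"/"data_layout" attribute equal to "NHWC".
for op in static.default_main_program().global_block().ops:
    if op.has_attr('data_format'):
        print(op.type, op.attr('data_format'))  # e.g. conv2d NHWC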
paddle/fluid/framework/data_layout_transform.cc
@@ -14,7 +14,7 @@
 #include "paddle/fluid/framework/data_layout_transform.h"

-#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

 namespace paddle {
@@ -61,6 +61,18 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
       platform::errors::PreconditionNotMet(
           "TransDataLayout only support DataLayout transform on same place."));
+  TransDataLayout(kernel_type_for_var.layout(),
+                  expected_kernel_type.layout(),
+                  place,
+                  in,
+                  out);
+}
+
+void TransDataLayout(DataLayout from_layout,
+                     DataLayout to_layout,
+                     phi::Place place,
+                     const phi::DenseTensor& in,
+                     phi::DenseTensor* out) {
   PADDLE_ENFORCE_EQ(arity(in.dims()),
                     4,
@@ -73,8 +85,7 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
   auto src_dim = in.dims();
   std::vector<int64_t> dst_dim;
-  auto axis = GetAxis(kernel_type_for_var.layout(), expected_kernel_type.layout());
+  auto axis = GetAxis(from_layout, to_layout);
   dst_dim.resize(axis.size());
   for (size_t i = 0; i < axis.size(); i++) {
     dst_dim[i] = src_dim[axis[i]];
@@ -83,10 +94,11 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
   out->Resize(phi::make_ddim(dst_dim));
   out->mutable_data(place, in.dtype());
-  framework::VisitDataType(framework::TransToProtoVarType(in.dtype()),
-                           CastDataLayout(pool.Get(place), axis, in, out));
+  framework::VisitDataType(
+      static_cast<proto::VarType::Type>(phi::TransToProtoVarType(in.dtype())),
+      CastDataLayout(pool.Get(place), axis, in, out));
-  out->set_layout(expected_kernel_type.layout());
+  out->set_layout(to_layout);
 }
 }  // namespace framework
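The new TransDataLayout(from_layout, to_layout, place, in, out) overload permutes both the dims and the underlying data according to the axis vector returned by GetAxis, with dst_dim[i] = src_dim[axis[i]]. A rough numpy equivalent of that bookkeeping follows (get_axis is a hypothetical stand-in for the C++ helper, not Paddle API):

import numpy as np

def get_axis(from_layout, to_layout):
    # Hypothetical stand-in for the C++ GetAxis helper.
    if (from_layout, to_layout) == ("NCHW", "NHWC"):
        return [0, 2, 3, 1]
    if (from_layout, to_layout) == ("NHWC", "NCHW"):
        return [0, 3, 1, 2]
    raise ValueError("unsupported layout pair")

def trans_data_layout(from_layout, to_layout, tensor):
    axis = get_axis(from_layout, to_layout)
    dst_dim = [tensor.shape[a] for a in axis]   # dst_dim[i] = src_dim[axis[i]]
    out = tensor.transpose(axis)                # data is relocated the same way
    assert list(out.shape) == dst_dim
    return out

filter_nchw = np.random.rand(8, 4, 3, 3).astype(np.float32)  # [oc, ic, kh, kw]
filter_nhwc = trans_data_layout("NCHW", "NHWC", filter_nchw)  # [oc, kh, kw, ic]
print(filter_nhwc.shape)  # (8, 3, 3, 4)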
paddle/fluid/framework/data_layout_transform.h
@@ -14,21 +14,14 @@
 #pragma once

 #include <map>
 #include <unordered_map>
 #include <vector>

-#include "paddle/fluid/framework/op_kernel_type.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/framework/variable.h"
-#include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_factory.h"
 #include "paddle/phi/kernels/funcs/data_layout_transform.h"

-namespace paddle {
-namespace framework {
-class OpKernelType;
-}  // namespace framework
-}  // namespace paddle

 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/phi/backends/onednn/onednn_helper.h"
 #endif
@@ -60,5 +53,11 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
                      phi::DenseTensor* out,
                      const phi::Place& place);

+void TransDataLayout(phi::DataLayout from_layout,
+                     phi::DataLayout to_layout,
+                     phi::Place place,
+                     const phi::DenseTensor& in,
+                     phi::DenseTensor* out);

 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/data_layout_transform_test.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/data_layout_transform.h"

 #include "gtest/gtest.h"
+#include "paddle/fluid/platform/bfloat16.h"

 TEST(DataTransform, DataLayoutFunction) {
   auto place = paddle::platform::CPUPlace();
paddle/fluid/framework/ir/CMakeLists.txt
@@ -141,11 +141,9 @@ if(WITH_TENSORRT)
   pass_library(layernorm_shift_partition_fuse_pass inference)
   pass_library(reverse_roll_fuse_pass inference)
   pass_library(preln_layernorm_x_fuse_pass inference)
+  pass_library(trt_support_nhwc_pass inference)
   pass_library(elementwise_groupnorm_act_pass inference)
   pass_library(preln_elementwise_groupnorm_act_pass inference)
-endif()
-if(WITH_TENSORRT)
   pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
 endif()
paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
@@ -18,7 +18,6 @@
 #include <unordered_map>
 #include <unordered_set>

-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/phi/common/layout.h"
@@ -30,43 +29,11 @@ namespace framework {
 namespace ir {
 namespace {

-void TransDataLayout(DataLayout from_layout,
-                     DataLayout to_layout,
-                     const phi::DenseTensor &in,
-                     phi::DenseTensor *out) {
-  PADDLE_ENFORCE_EQ(
-      arity(in.dims()),
-      4,
-      platform::errors::InvalidArgument(
-          "Input dimension arity only can be 4, the input dimension is %s.",
-          in.dims()));
-  auto &pool = platform::DeviceContextPool::Instance();
-  auto src_dim = in.dims();
-  std::vector<int64_t> dst_dim;
-  auto axis = GetAxis(from_layout, to_layout);
-  dst_dim.resize(axis.size());
-  for (size_t i = 0; i < axis.size(); i++) {
-    dst_dim[i] = src_dim[axis[i]];
-  }
-  out->Resize(phi::make_ddim(dst_dim));
-  out->mutable_data(phi::CPUPlace(), in.dtype());
-  framework::VisitDataType(
-      framework::TransToProtoVarType(in.dtype()),
-      CastDataLayout(pool.Get(phi::CPUPlace()), axis, in, out));
-  out->set_layout(to_layout);
-}
-
 void InsertLayoutTransOp(ir::Graph *graph,
                          ir::Node *prev_node,
                          ir::Node *next_node,
-                         DataLayout from_layout,
-                         DataLayout to_layout,
+                         phi::DataLayout from_layout,
+                         phi::DataLayout to_layout,
                          framework::BlockDesc *block_desc,
                          std::unordered_map<ir::Node *, ir::Node *> *cache) {
   auto do_insert = [&](const std::string &in_var_name,
@@ -91,7 +58,7 @@ void InsertLayoutTransOp(ir::Graph *graph,
     op_out_var_desc->SetPersistable(false);
     op_out_var_desc->SetDataType(prev_node->Var()->GetDataType());
     auto to_shape = prev_node->Var()->GetShape();
-    if (from_layout == DataLayout::kNCHW) {
+    if (from_layout == phi::DataLayout::kNCHW) {
       auto n = to_shape[0];
       auto c = to_shape[1];
       auto h = to_shape[2];
@@ -117,12 +84,13 @@ void InsertLayoutTransOp(ir::Graph *graph,
     IR_NODE_UNLINK(prev_node, next_node);
   };

-  if (from_layout == DataLayout::kNCHW && to_layout == DataLayout::kNHWC) {
+  if (from_layout == phi::DataLayout::kNCHW &&
+      to_layout == phi::DataLayout::kNHWC) {
     auto in_var_name = prev_node->Var()->Name();
     auto out_var_name = in_var_name + "_nchw_to_nhwc";
     do_insert(in_var_name, out_var_name);
-  } else if (from_layout == DataLayout::kNHWC && to_layout == DataLayout::kNCHW) {
+  } else if (from_layout == phi::DataLayout::kNHWC &&
+             to_layout == phi::DataLayout::kNCHW) {
     auto in_var_name = prev_node->Var()->Name();
     auto out_var_name = in_var_name + "_nhwc_to_nchw";
     do_insert(in_var_name, out_var_name);
@@ -135,7 +103,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph,
       platform::errors::PreconditionNotMet("graph should not be nullptr."));
-  FusePassBase::Init("data_layout_transfer", graph);
+  FusePassBase::Init("conv2d_fusion_layout_transfer", graph);
   auto *scope = param_scope();

   // only float16 compute precision need insert transfer_layout.
@@ -170,7 +138,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   // Not support multiple block now.
   std::unordered_map<ir::Node *, ir::Node *> cache;
-  auto op_nodes = ir::TopologySortOperations(*graph);
+  auto op_nodes = TopologySortOperations(*graph);
   auto iter = op_nodes.cbegin();
   auto *block_desc = (*iter)->Op()->Block();
@@ -186,7 +154,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
         op_node->Op()->GetAttrIfExists<std::string>("data_format");
     if (data_format != "NCHW") return false;
     auto filter_names = op_node->Op()->Input("Filter");
-    constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
+    constexpr int NHWC_ALIGNMENT = 8;
     // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
     for (const auto &filter_name : filter_names) {
       auto *filter_var = scope->FindLocalVar(filter_name);
@@ -195,7 +163,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
       int oc = filter_tensor.dims()[0];
       int ic = filter_tensor.dims()[1];
       bool cutlass_can_support =
-          oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0;
+          oc % NHWC_ALIGNMENT == 0 && ic % NHWC_ALIGNMENT == 0;
       if (!cutlass_can_support) {
         return false;
       }
@@ -229,8 +197,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
     if (cuDNNIsValid(op_node)) {
       valid_ops.insert(op_node);
       auto *op_desc = op_node->Op();
-      auto nhwc_attr = framework::Attribute(std::string("NHWC"));
-      op_desc->SetAttr("data_format", nhwc_attr);
+      op_desc->SetAttr("data_format", std::string{"NHWC"});
       if (cutlass_enable && CutlassIsValid(op_node)) {
         op_desc->SetType("conv2d_fusion_cutlass");
       }
@@ -244,8 +211,11 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
         phi::DenseTensor temp_tensor = *filter_tensor;
         filter_tensor->clear();

-        TransDataLayout(
-            DataLayout::kNCHW, DataLayout::kNHWC, temp_tensor, filter_tensor);
+        framework::TransDataLayout(phi::DataLayout::kNCHW,
+                                   phi::DataLayout::kNHWC,
+                                   phi::CPUPlace{},
+                                   temp_tensor,
+                                   filter_tensor);
       }
       auto op_inputs = op_node->inputs;
       for (auto *in_var_node : op_inputs) {
@@ -290,8 +260,8 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
           InsertLayoutTransOp(graph,
                               in_var_node,
                               op_node,
-                              DataLayout::kNCHW,
-                              DataLayout::kNHWC,
+                              phi::DataLayout::kNCHW,
+                              phi::DataLayout::kNHWC,
                               block_desc,
                               &cache);
         }
@@ -304,8 +274,8 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
           InsertLayoutTransOp(graph,
                               in_var_node,
                               op_node,
-                              DataLayout::kNHWC,
-                              DataLayout::kNCHW,
+                              phi::DataLayout::kNHWC,
+                              phi::DataLayout::kNCHW,
                               block_desc,
                               &cache);
         }
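In this pass the alignment constant gates whether conv2d_fusion can run in NHWC: both the filter's output-channel and input-channel counts must be multiples of 8, and qualifying filters are then physically converted from NCHW to NHWC through the shared framework::TransDataLayout. A rough numpy sketch of those two steps (helper names are illustrative, not Paddle API):

import numpy as np

NHWC_ALIGNMENT = 8  # mirrors the constexpr in the pass

def filter_supports_nhwc(filt):
    # Filters are stored [oc, ic, kh, kw]; both channel counts must be
    # 8-aligned for the cutlass NHWC kernels.
    oc, ic = filt.shape[0], filt.shape[1]
    return oc % NHWC_ALIGNMENT == 0 and ic % NHWC_ALIGNMENT == 0

filt = np.random.rand(16, 8, 3, 3).astype(np.float16)
if filter_supports_nhwc(filt):
    # Same permutation the pass applies through framework::TransDataLayout.
    filt_nhwc = filt.transpose(0, 2, 3, 1)  # NCHW -> NHWC: [oc, kh, kw, ic]
    print(filt_nhwc.shape)                  # (16, 3, 3, 8)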
paddle/fluid/framework/ir/trt_support_nhwc_pass.cc
new file mode 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/trt_support_nhwc_pass.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace framework {
namespace ir {

namespace {

void DoInsertTransposeOp(ir::Graph *graph,
                         ir::Node *prev_node,
                         ir::Node *next_node,
                         phi::DataLayout from_layout,
                         phi::DataLayout to_layout,
                         framework::BlockDesc *block_desc,
                         std::unordered_map<ir::Node *, ir::Node *> *cache) {
  auto do_insert = [&](const std::string &in_var_name,
                       const std::string &out_var_name) {
    auto update_op_desc = [&](framework::OpDesc &desc,
                              const std::string &x_name,
                              const std::string &out_name,
                              const std::vector<int> &axis_attr) {
      desc.SetType("transpose");
      desc.SetInput("X", {x_name});
      desc.SetOutput("Out", {out_name});
      desc.SetAttr("axis", axis_attr);
      desc.SetAttr("use_mkldnn", false);
      desc.SetAttr("data_format", std::string{"AnyLayout"});
      desc.SetAttr("use_quantizer", false);
      desc.SetAttr("mkldnn_data_type", std::string{"float32"});
      desc.Flush();
    };
    CHECK_NOTNULL(block_desc);
    if (cache->count(prev_node) == 0) {
      framework::OpDesc op_desc(block_desc);
      if (from_layout == phi::DataLayout::kNCHW) {
        update_op_desc(op_desc, in_var_name, out_var_name, {0, 2, 3, 1});
      } else if (from_layout == phi::DataLayout::kNHWC) {
        update_op_desc(op_desc, in_var_name, out_var_name, {0, 3, 1, 2});
      }
      auto *op_node = graph->CreateOpNode(&op_desc);
      auto *op_out_var_desc = block_desc->Var(out_var_name);

      op_out_var_desc->SetPersistable(false);
      op_out_var_desc->SetDataType(prev_node->Var()->GetDataType());
      auto to_shape = prev_node->Var()->GetShape();
      if (from_layout == phi::DataLayout::kNCHW) {
        auto n = to_shape[0];
        auto c = to_shape[1];
        auto h = to_shape[2];
        auto w = to_shape[3];
        op_out_var_desc->SetShape({n, h, w, c});
      } else if (from_layout == phi::DataLayout::kNHWC) {
        auto n = to_shape[0];
        auto h = to_shape[1];
        auto w = to_shape[2];
        auto c = to_shape[3];
        op_out_var_desc->SetShape({n, c, h, w});
      }

      auto *op_out_var_node = graph->CreateVarNode(op_out_var_desc);
      IR_NODE_LINK_TO(op_node, op_out_var_node);
      cache->insert(std::make_pair(prev_node, op_out_var_node));
    }
    next_node->Op()->RenameInput(prev_node->Name(), cache->at(prev_node)->Name());
    IR_NODE_LINK_TO(prev_node, cache->at(prev_node)->inputs.front());
    IR_NODE_LINK_TO(cache->at(prev_node), next_node);
    IR_NODE_UNLINK(prev_node, next_node);
  };

  if (from_layout == phi::DataLayout::kNCHW &&
      to_layout == phi::DataLayout::kNHWC) {
    auto in_var_name = prev_node->Var()->Name();
    auto out_var_name = in_var_name + "_nchw_to_nhwc";
    do_insert(in_var_name, out_var_name);
  } else if (from_layout == phi::DataLayout::kNHWC &&
             to_layout == phi::DataLayout::kNCHW) {
    auto in_var_name = prev_node->Var()->Name();
    auto out_var_name = in_var_name + "_nhwc_to_nchw";
    do_insert(in_var_name, out_var_name);
  }
}

bool ModelLayoutIsNHWC(const std::vector<ir::Node *> &op_nodes) {
  for (auto *op_node : op_nodes) {
    if (op_node->IsOp()) {
      auto *op_desc = op_node->Op();
      std::string data_format;
      if (op_desc->HasAttr("data_format")) {
        data_format = op_desc->GetAttrIfExists<std::string>("data_format");
      } else if (op_desc->HasAttr("data_layout")) {
        data_format = op_desc->GetAttrIfExists<std::string>("data_layout");
      }
      if (data_format == "NHWC") {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::PreconditionNotMet(
                              "During the trt_support_nhwc_pass, the graph "
                              "should not be null."));
  FusePassBase::Init("trt_support_nhwc_pass", graph);

  auto *scope = param_scope();
  auto op_nodes = TopologySortOperations(*graph);

  if (!ModelLayoutIsNHWC(op_nodes)) {
    return;
  }

  //
  //
  // TODO(liuyuanle): Add other op if needed!
  //
  //
  std::unordered_set<std::string> need_trans_weights{"prelu"};
  std::unordered_set<std::string> not_trans_weights{"conv2d",
                                                    "pool2d",
                                                    "batch_norm",
                                                    "bilinear_interp",
                                                    "bilinear_interp_v2",
                                                    "nearest_interp",
                                                    "nearest_interp_v2"};
  // Ops must run under the original layout even though it has
  // data_format/data_layout attribute, otherwise it will be very troublesome!
  std::unordered_set<std::string> must_original_layout_ops{"affine_channel",
                                                           "softmax"};
  // OPs unrelated to layout are consistent according to the layout of input
  // var!
  std::unordered_set<std::string> any_layout_ops{"relu"};
  //
  //
  // TODO(liuyuanle): Add other op if needed!
  //
  //

  // Ops with "data_format" or "data_layout" attribute value of "NHWC"
  std::unordered_set<ir::Node *> transposed_ops;
  std::unordered_set<ir::Node *> vars_to_nchw;

  std::unordered_map<ir::Node *, ir::Node *> cache;
  // Not support multiple block now
  auto iter = op_nodes.cbegin();
  auto *block_desc = (*iter)->Op()->Block();

  for (auto *op_node : op_nodes) {
    CHECK_EQ(op_node->IsOp(), true);
    auto *op_desc = op_node->Op();

    std::string data_format;
    if (op_desc->HasAttr("data_format")) {
      data_format = op_desc->GetAttrIfExists<std::string>("data_format");
    } else if (op_desc->HasAttr("data_layout")) {
      data_format = op_desc->GetAttrIfExists<std::string>("data_layout");
    }

    bool input_shape_4{true};
    auto op_inputs = op_node->inputs;
    for (auto *in_var_node : op_inputs) {
      CHECK_EQ(in_var_node->IsVar(), true);
      if (in_var_node->Var()->Persistable()) continue;

      auto input_shape = in_var_node->Var()->GetShape();
      input_shape_4 &= (input_shape.size() == 4);
    }

    if (data_format != "NHWC" || !input_shape_4 ||
        any_layout_ops.count(op_desc->Type()) ||
        must_original_layout_ops.count(op_desc->Type())) {
      continue;
    }

    // Transpose NHWC --> NCHW
    //
    // Update current op
    transposed_ops.insert(op_node);
    if (op_desc->HasAttr("data_format")) {
      op_desc->SetAttr("data_format", std::string{"NCHW"});
      op_desc->Flush();
    } else if (op_desc->HasAttr("data_layout")) {
      op_desc->SetAttr("data_layout", std::string{"NCHW"});
      op_desc->Flush();
    }

    auto UpdateOutputVars = [&] {
      // Update output var of current op
      auto op_outputs = op_node->outputs;
      for (auto *out_var_node : op_outputs) {
        CHECK_EQ(out_var_node->IsVar(), true);
        if (out_var_node->Var()->Persistable()) continue;

        auto from_shape = out_var_node->Var()->GetShape();
        if (from_shape.size() == 4) {
          out_var_node->Var()->SetShape(
              {from_shape[0], from_shape[3], from_shape[1], from_shape[2]});
          vars_to_nchw.insert(out_var_node);
        }
      }
    };

    if (not_trans_weights.count(op_desc->Type())) {
      UpdateOutputVars();
    } else if (need_trans_weights.count(op_desc->Type())) {
      std::vector<std::string> weights;
      if (op_desc->Type() == "prelu") {
        weights.push_back("Alpha");
      }
      auto UpdateWeightVars = [&] {
        for (auto const &weight : weights) {
          // transfer weights
          auto weight_names = op_desc->Input(weight);
          for (const auto &weight_name : weight_names) {
            auto *weight_var = scope->FindLocalVar(weight_name);
            auto *weight_tensor = weight_var->GetMutable<phi::DenseTensor>();
            if (weight_tensor->dims().size() == 4) {
              phi::DenseTensor temp_tensor = *weight_tensor;
              weight_tensor->clear();

              framework::TransDataLayout(phi::DataLayout::kNHWC,
                                         phi::DataLayout::kNCHW,
                                         phi::CPUPlace{},
                                         temp_tensor,
                                         weight_tensor);
            }
          }
          auto op_inputs = op_node->inputs;
          for (auto *in_var_node : op_inputs) {
            CHECK_EQ(in_var_node->IsVar(), true);
            if (in_var_node->Var()->Persistable()) {
              if (std::find(weight_names.cbegin(),
                            weight_names.cend(),
                            in_var_node->Var()->Name()) !=
                  weight_names.cend()) {
                auto from_shape = in_var_node->Var()->GetShape();
                in_var_node->Var()->SetShape({from_shape[0],
                                              from_shape[2],
                                              from_shape[3],
                                              from_shape[1]});
              }
            }
          }
        }
      };
      UpdateWeightVars();
      UpdateOutputVars();
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "During the trt_support_nhwc_pass, %s op not supported. Please "
          "update the supported op lists.",
          op_desc->Type()));
    }
  }

  auto ProcessAnyLayoutOps = [&] {
    // Process any layout ops
    for (auto *op_node : op_nodes) {
      CHECK_EQ(op_node->IsOp(), true);
      auto op_inputs = op_node->inputs;
      for (auto *in_var_node : op_inputs) {
        CHECK_EQ(in_var_node->IsVar(), true);

        if (transposed_ops.count(op_node)) continue;
        if (vars_to_nchw.count(in_var_node) &&
            any_layout_ops.count(op_node->Op()->Type())) {
          transposed_ops.insert(op_node);
          // Update output var of current op
          auto op_outputs = op_node->outputs;
          for (auto *out_var_node : op_outputs) {
            CHECK_EQ(out_var_node->IsVar(), true);
            if (out_var_node->Var()->Persistable()) continue;

            auto from_shape = out_var_node->Var()->GetShape();
            if (from_shape.size() == 4) {
              out_var_node->Var()->SetShape(
                  {from_shape[0], from_shape[3], from_shape[1], from_shape[2]});
              vars_to_nchw.insert(out_var_node);
            }
          }
        }
      }
    }
  };
  ProcessAnyLayoutOps();

  auto InsertTransposeOp = [&] {
    // Insert transpose op
    for (auto *op_node : op_nodes) {
      CHECK_EQ(op_node->IsOp(), true);
      if (transposed_ops.count(op_node)) {
        auto op_inputs = op_node->inputs;
        for (auto *in_var_node : op_inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);

          if (in_var_node->Var()->Persistable()) continue;
          if (vars_to_nchw.count(in_var_node)) continue;

          DoInsertTransposeOp(graph,
                              in_var_node,
                              op_node,
                              phi::DataLayout::kNHWC,
                              phi::DataLayout::kNCHW,
                              block_desc,
                              &cache);
        }
      } else {
        auto op_inputs = op_node->inputs;
        for (auto *in_var_node : op_inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);

          if (vars_to_nchw.count(in_var_node)) {
            DoInsertTransposeOp(graph,
                                in_var_node,
                                op_node,
                                phi::DataLayout::kNCHW,
                                phi::DataLayout::kNHWC,
                                block_desc,
                                &cache);
          }
        }
      }
    }
  };
  InsertTransposeOp();
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(trt_support_nhwc_pass, paddle::framework::ir::TrtSupportNHWCPass);
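Taken together, the new pass classifies each NHWC op and decides how to rewrite it: ops in not_trans_weights only get their layout attribute flipped and their output shapes permuted, prelu additionally gets its 4-D Alpha weight transposed, any_layout ops follow whatever layout their inputs end up in, must_original_layout ops are left alone, and anything else raises Unimplemented. A hypothetical Python distillation of that decision table (illustrative only, not Paddle API):

# Hypothetical distillation of the per-op decision in TrtSupportNHWCPass::ApplyImpl.
NEED_TRANS_WEIGHTS = {"prelu"}
NOT_TRANS_WEIGHTS = {"conv2d", "pool2d", "batch_norm", "bilinear_interp",
                     "bilinear_interp_v2", "nearest_interp", "nearest_interp_v2"}
MUST_ORIGINAL_LAYOUT = {"affine_channel", "softmax"}
ANY_LAYOUT = {"relu"}

def classify(op_type, data_format, inputs_are_4d):
    if data_format != "NHWC" or not inputs_are_4d:
        return "skip"
    if op_type in ANY_LAYOUT or op_type in MUST_ORIGINAL_LAYOUT:
        return "skip"  # any-layout ops are fixed up later; original-layout ops stay put
    if op_type in NOT_TRANS_WEIGHTS:
        return "flip attr to NCHW + permute output shapes"
    if op_type in NEED_TRANS_WEIGHTS:
        return "flip attr to NCHW + transpose 4-D weights + permute output shapes"
    return "unsupported -> PADDLE_THROW(Unimplemented)"

print(classify("conv2d", "NHWC", True))
print(classify("prelu", "NHWC", True))
print(classify("softmax", "NHWC", True))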
paddle/fluid/framework/ir/trt_support_nhwc_pass.h
new file mode 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {

class TrtSupportNHWCPass : public FusePassBase {
 public:
  TrtSupportNHWCPass() = default;
  ~TrtSupportNHWCPass() = default;

 protected:
  void ApplyImpl(Graph *graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -85,7 +85,8 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
 void PaddlePassBuilder::ClearPasses() { passes_.clear(); }

 const std::vector<std::string> kTRTSubgraphPasses({
-      "adaptive_pool2d_convert_global_pass",  //
+      "trt_support_nhwc_pass",
+      "adaptive_pool2d_convert_global_pass",  //
       "shuffle_channel_detect_pass",          //
       "quant_conv2d_dequant_fuse_pass",       //
       "delete_fill_constant_op_pass",         //
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py
@@ -44,10 +44,10 @@ class TrtConvertBilinearInterpV2Test(TrtLayerAutoScanTest):
         for data_layout in ["NCHW", "NHWC"]:
             for scale_y in [2.0, 1.0]:
-                for scale_x in [2.0, 1.0]:
+                for scale_x in [2.0]:
                     scale = [scale_y, scale_x]
-                    for out_h in [32, 64, 128, 192]:
-                        for out_w in [32, 64]:
+                    for out_h in [32, 128]:
+                        for out_w in [64]:
                             dics = [
                                 {
                                     "data_layout": data_layout,
python/paddle/fluid/tests/unittests/ir/inference/test_trt_support_nhwc_pass.py
new file mode 100644
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import unittest

import numpy as np

import paddle
import paddle.inference as inference
import paddle.nn as nn
import paddle.static as static

paddle.enable_static()


class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2D(
            in_channels=4,
            out_channels=4,
            kernel_size=3,
            stride=2,
            padding=0,
            data_format='NHWC',
        )
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2D(
            in_channels=4,
            out_channels=2,
            kernel_size=3,
            stride=2,
            padding=0,
            data_format='NHWC',
        )
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2D(
            in_channels=2,
            out_channels=1,
            kernel_size=3,
            stride=2,
            padding=0,
            data_format='NHWC',
        )
        self.relu3 = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(729, 10)
        self.softmax = nn.Softmax()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.flatten(x)
        x = self.fc(x)
        x = self.softmax(x)
        return x


class TRTNHWCConvertTest(unittest.TestCase):
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.path = './inference_pass/nhwc_convert/infer_model'

    def create_model(self):
        image = static.data(
            name='img', shape=[None, 224, 224, 4], dtype='float32'
        )
        predict = SimpleNet()(image)
        exe = paddle.static.Executor(self.place)
        exe.run(paddle.static.default_startup_program())
        paddle.static.save_inference_model(self.path, [image], [predict], exe)

    def create_predictor(self):
        config = paddle.inference.Config(
            self.path + '.pdmodel', self.path + '.pdiparams'
        )
        config.enable_memory_optim()
        config.enable_use_gpu(100, 0)
        config.enable_tensorrt_engine(
            workspace_size=1 << 30,
            max_batch_size=1,
            min_subgraph_size=3,
            precision_mode=inference.PrecisionType.Float32,
            use_static=False,
            use_calib_mode=False,
        )
        predictor = inference.create_predictor(config)
        return predictor

    def infer(self, predictor, img):
        input_names = predictor.get_input_names()
        for i, name in enumerate(input_names):
            input_tensor = predictor.get_input_handle(name)
            input_tensor.reshape(img[i].shape)
            input_tensor.copy_from_cpu(img[i].copy())

        predictor.run()

        results = []
        output_names = predictor.get_output_names()
        for i, name in enumerate(output_names):
            output_tensor = predictor.get_output_handle(name)
            output_data = output_tensor.copy_to_cpu()
            results.append(output_data)
        return results

    def test_nhwc_convert(self):
        self.create_model()
        predictor = self.create_predictor()
        img = np.ones((1, 224, 224, 4), dtype=np.float32)
        result = self.infer(predictor, img=[img])

    def tearDown(self):
        shutil.rmtree('./inference_pass/nhwc_convert/')


if __name__ == '__main__':
    unittest.main()
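The test above only asserts that TensorRT inference runs end to end on the NHWC model. To see the NHWC attributes the pass will encounter, one option is to reload the saved program and list its ops before tearDown removes the directory; a hedged sketch, assuming the model was saved to the same path prefix as in the test:

import paddle
import paddle.static as static

paddle.enable_static()
exe = static.Executor(paddle.CPUPlace())
path = './inference_pass/nhwc_convert/infer_model'  # same prefix as the test above
[program, feed_names, fetch_targets] = static.load_inference_model(path, exe)
for op in program.global_block().ops:
    if op.has_attr('data_format'):
        print(op.type, op.attr('data_format'))  # the conv2d ops should report NHWC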