机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 007f3614 (unverified)
Author: whs · Committed via GitHub on June 20, 2022

Add passes and plugins for distributed inference of NLU (#43049)

Parent: ec3e0a13
Showing 31 changed files with 2,064 additions and 28 deletions (+2064 -28)
paddle/fluid/framework/ir/CMakeLists.txt                                   +2    -0
paddle/fluid/framework/ir/delete_c_identity_op_pass.cc                     +127  -0
paddle/fluid/framework/ir/delete_c_identity_op_pass.h                      +52   -0
paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc                 +215  -0
paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h                  +87   -0
paddle/fluid/inference/api/analysis_predictor.cc                           +4    -0
paddle/fluid/inference/api/paddle_pass_builder.cc                          +2    -0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt                     +2    -0
paddle/fluid/inference/tensorrt/convert/c_allreduce_op.cc                  +93   -0
paddle/fluid/inference/tensorrt/convert/fc_op.cc                           +0    -3
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc             +0    -3
paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc             +95   -0
paddle/fluid/inference/tensorrt/convert/reshape_op.cc                      +1    -0
paddle/fluid/inference/tensorrt/convert/transpose_op.cc                    +1    -0
paddle/fluid/inference/tensorrt/op_teller.cc                               +10   -2
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt                      +3    -1
paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu            +192  -0
paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h             +105  -0
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu       +298  -0
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h        +154  -0
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h       +103  -0
python/paddle/fluid/tests/unittests/CMakeLists.txt                         +2    -1
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt            +12   -0
python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py         +9    -9
python/paddle/fluid/tests/unittests/ir/inference/program_config.py         +2    -0
python/paddle/fluid/tests/unittests/ir/inference/test_delete_c_identity_op_pass.py       +58   -0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_c_allreduce_infer_script.py    +110  -0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_c_allreduce.py         +74   -0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py +181  -0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_skip_layernorm.py      +9    -9
python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py          +61   -0
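Both new TensorRT converters in this commit refuse to build in static-shape mode, so the new passes only take effect when the predictor is configured for GPU with the TensorRT engine and dynamic input shapes. A minimal, hedged sketch of such a configuration follows; the model paths, tensor name "x", and shape ranges are placeholders, not part of this commit:

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "paddle_inference_api.h"  // paddle_infer

std::shared_ptr<paddle_infer::Predictor> BuildPredictor() {
  // Placeholder model files; substitute a real exported NLU model.
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/1000, /*device_id=*/0);
  // Enabling TensorRT makes kTRTSubgraphPasses run, which now includes
  // delete_c_identity_op_pass and preln_residual_bias_fuse_pass.
  config.EnableTensorRTEngine(/*workspace_size=*/1 << 30,
                              /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle_infer::PrecisionType::kHalf,
                              /*use_static=*/false,
                              /*use_calib_mode=*/false);
  // Dynamic shape ranges are required by both new converters; values here
  // are illustrative only.
  std::map<std::string, std::vector<int>> min_shape{{"x", {1, 1, 768}}};
  std::map<std::string, std::vector<int>> max_shape{{"x", {16, 128, 768}}};
  std::map<std::string, std::vector<int>> opt_shape{{"x", {8, 64, 768}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
  return paddle_infer::CreatePredictor(config);
}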
paddle/fluid/framework/ir/CMakeLists.txt

@@ -143,6 +143,8 @@ pass_library(delete_quant_dequant_filter_op_pass inference)
 pass_library(delete_weight_dequant_linear_op_pass inference)
 pass_library(delete_quant_dequant_linear_op_pass inference)
 pass_library(delete_dropout_op_pass inference)
+pass_library(delete_c_identity_op_pass inference)
+pass_library(preln_residual_bias_fuse_pass inference)
 pass_library(delete_fill_constant_op_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
paddle/fluid/framework/ir/delete_c_identity_op_pass.cc  (new file, 0 → 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/delete_c_identity_op_pass.h"

#include <algorithm>  // for std::find
#include <string>

#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace ir {

namespace patterns {

void DeleteCIdentityOpPattern::operator()() {
  auto any_op_out = pattern->NewNode(any_op_out_repr())
                        ->assert_is_op_input("c_identity", "X")
                        ->AsInput();
  auto c_identity_op =
      pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity");
  auto c_identity_op_out = pattern->NewNode(c_identity_op_out_repr())
                               ->assert_is_op_output("c_identity", "Out")
                               ->AsIntermediate();
  auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();

  c_identity_op->LinksFrom({any_op_out});
  c_identity_op_out->LinksFrom({c_identity_op});
  any_op2->LinksFrom({c_identity_op_out});
}

}  // namespace patterns

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES                 \
  GET_IR_NODE(any_op_out);        \
  GET_IR_NODE(c_identity_op);     \
  GET_IR_NODE(c_identity_op_out); \
  GET_IR_NODE(any_op2);

void DeleteCIdentityOpPass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "delete_c_identity_op_pattern";
  FusePassBase::Init(pattern_name, graph);

  GraphPatternDetector gpd;
  patterns::DeleteCIdentityOpPattern pattern(gpd.mutable_pattern(),
                                             pattern_name);
  pattern();

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;
    IR_NODE_LINK_TO(any_op_out, any_op2);
    std::string any_op_out_name = any_op_out->Var()->Name();
    std::string c_identity_op_out_name = c_identity_op_out->Var()->Name();

    auto* any_op2_desc = any_op2->Op();
    auto var_map = any_op2_desc->Inputs();
    std::string arg_name = "";
    for (auto& name_m : var_map) {
      if (std::find(name_m.second.begin(), name_m.second.end(),
                    c_identity_op_out_name) != name_m.second.end()) {
        arg_name = name_m.first;
      }
    }
    if (arg_name.size() == 0) {
      LOG(INFO) << "Delete c_identity op pass: can not find the input "
                << c_identity_op_out_name;
      return;
    }

    // modify the any_op2's inputs
    for (auto& name_m : var_map) {
      if (std::find(name_m.second.begin(), name_m.second.end(),
                    c_identity_op_out_name) != name_m.second.end()) {
        std::vector<std::string> new_inputs;
        for (auto& i_n : name_m.second) {
          if (i_n != c_identity_op_out_name) {
            new_inputs.push_back(i_n);
          }
        }
        new_inputs.push_back(any_op_out_name);
        any_op2_desc->SetInput(name_m.first, new_inputs);
        any_op2_desc->Flush();
      }
    }
    any_op2_desc->Flush();
    // Delete the unneeded nodes.
    GraphSafeRemoveNodes(graph, {c_identity_op, c_identity_op_out});
  };

  gpd(graph, handler);
}

DeleteCIdentityOpPass::DeleteCIdentityOpPass() {
  AddOpCompat(OpCompat("c_identity"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(delete_c_identity_op_pass,
              paddle::framework::ir::DeleteCIdentityOpPass);
REGISTER_PASS_CAPABILITY(delete_c_identity_op_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().LE(
            "c_identity", 1));
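For context, a hedged sketch of how a registered IR pass like the one above is typically exercised in isolation; the graph construction is elided and the PassRegistry usage follows the framework's usual test pattern, not code from this commit:

#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

// Sketch: fetch the pass registered above by name and apply it to a graph.
// Assumes `program` contains at least one `x -> c_identity -> op2` chain.
void RunDeleteCIdentityPass(const paddle::framework::ProgramDesc& program) {
  auto graph = std::make_unique<paddle::framework::ir::Graph>(program);
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "delete_c_identity_op_pass");
  // After Apply, each matched chain is rewired to `x -> op2` and the
  // c_identity op and its output var are removed from the graph.
  graph.reset(pass->Apply(graph.release()));
}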
paddle/fluid/framework/ir/delete_c_identity_op_pass.h  (new file, 0 → 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

namespace patterns {

struct DeleteCIdentityOpPattern : public PatternBase {
  DeleteCIdentityOpPattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "delete_c_identity_op_pattern") {}

  void operator()();

  PATTERN_DECL_NODE(any_op_out);
  PATTERN_DECL_NODE(c_identity_op);
  PATTERN_DECL_NODE(c_identity_op_out);
  PATTERN_DECL_NODE(any_op2);
};
}  // namespace patterns

class Graph;

class DeleteCIdentityOpPass : public FusePassBase {
 public:
  DeleteCIdentityOpPass();
  virtual ~DeleteCIdentityOpPass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc  (new file, 0 → 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h"

#include <string>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace ir {
class Node;
}  // namespace ir
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct PrelnResidualBias : public PatternBase {
  PrelnResidualBias(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "preln_residual_bias") {}

  void operator()(PDNode *x, PDNode *y);

  // declare operator node's name
  PATTERN_DECL_NODE(elementwise_bias);
  PATTERN_DECL_NODE(elementwise0);
  PATTERN_DECL_NODE(elementwise1);
  PATTERN_DECL_NODE(layer_norm);
  // declare variable node's name
  PATTERN_DECL_NODE(elementwise0_out);
  PATTERN_DECL_NODE(elementwise1_out);
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_out);
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
};

void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
  // Create nodes for elementwise add op.
  x->assert_is_op_input("elementwise_add");
  y->assert_is_op_input("elementwise_add", "X");
  auto *elementwise0 =
      pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add");
  auto *elementwise_bias_var =
      pattern->NewNode(elementwise_bias_repr())
          ->assert_is_op_input("elementwise_add", "Y");
  auto *elementwise0_out_var = pattern->NewNode(elementwise0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->assert_is_op_input("elementwise_add")
                                   ->assert_more([](Node *x) {
                                     if (x->outputs.size() == 1) {
                                       return true;
                                     } else {
                                       return false;
                                     }
                                   });
  auto *elementwise1 =
      pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add");
  auto *elementwise1_out_var = pattern->NewNode(elementwise1_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->assert_is_op_input("layer_norm", "X");
  // Add links for elementwise_add op.
  elementwise0->LinksFrom({y, elementwise_bias_var})
      .LinksTo({elementwise0_out_var});
  elementwise1->LinksFrom({x, elementwise0_out_var})
      .LinksTo({elementwise1_out_var});
  // Create nodes for layer_norm op.
  auto *layer_norm =
      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
  auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
                                  ->AsInput()
                                  ->assert_is_persistable_var()
                                  ->assert_is_op_input("layer_norm", "Bias");
  auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
                                   ->AsInput()
                                   ->assert_is_persistable_var()
                                   ->assert_is_op_input("layer_norm", "Scale");
  auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
                                 ->AsOutput()
                                 ->assert_is_op_output("layer_norm", "Y");
  auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
                                  ->AsOutput()
                                  ->assert_is_op_output("layer_norm", "Mean");
  auto *layer_norm_variance_var =
      pattern->NewNode(layer_norm_variance_repr())
          ->AsOutput()
          ->assert_is_op_output("layer_norm", "Variance");
  // Add links for layer_norm op.
  layer_norm
      ->LinksFrom(
          {elementwise1_out_var, layer_norm_bias_var, layer_norm_scale_var})
      .LinksTo(
          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
}

}  // namespace patterns

void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("preln_residual_bias_fuse", graph);

  int found_subgraph_count = 0;
  GraphPatternDetector gpd;
  auto *x = gpd.mutable_pattern()
                ->NewNode("preln_residual_bias_fuse/x")
                ->AsInput()
                ->assert_is_op_input("elementwise_add")
                ->assert_var_not_persistable();
  auto *y = gpd.mutable_pattern()
                ->NewNode("preln_residual_bias_fuse/y")
                ->AsInput()
                ->assert_is_op_input("elementwise_add", "X")
                ->assert_var_not_persistable();
  patterns::PrelnResidualBias fused_pattern(gpd.mutable_pattern(),
                                            "preln_residual_bias_fuse");
  fused_pattern(x, y);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }
    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "preln_residual_bias pass in op compat failed.";
      return;
    }

    VLOG(4) << "handle PrelnResidualBias fuse";
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_bias, elementwise_bias,
                              fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise0, elementwise0, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise0_out, elementwise0_out,
                              fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise1_out, elementwise1_out,
                              fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
                              fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
                              fused_pattern);

    std::unordered_set<const Node *> del_node_set;
    // Create an PrelnResidualBias op node
    OpDesc new_desc;
    new_desc.SetType("preln_residual_bias");
    // inputs
    new_desc.SetInput("X", {subgraph.at(x)->Name()});
    new_desc.SetInput("Y", {subgraph.at(y)->Name()});
    new_desc.SetInput("Scale", {layer_norm_scale->Name()});
    new_desc.SetInput("Bias", {layer_norm_bias->Name()});
    new_desc.SetInput("EleBias", {elementwise_bias->Name()});
    // outputs
    new_desc.SetOutput("Out_0", {layer_norm_out->Name()});
    new_desc.SetOutput("Out_1", {elementwise1_out->Name()});
    // attrs
    new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
    new_desc.SetAttr("begin_norm_axis",
                     layer_norm->Op()->GetAttr("begin_norm_axis"));

    auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.

    del_node_set.insert(elementwise0);
    del_node_set.insert(elementwise1);
    del_node_set.insert(elementwise0_out);
    del_node_set.insert(layer_norm);
    del_node_set.insert(layer_norm_mean);
    del_node_set.insert(layer_norm_variance);
    GraphSafeRemoveNodes(graph, del_node_set);

    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
    IR_NODE_LINK_TO(subgraph.at(y), fused_node);
    IR_NODE_LINK_TO(elementwise_bias, fused_node);
    IR_NODE_LINK_TO(layer_norm_scale, fused_node);
    IR_NODE_LINK_TO(layer_norm_bias, fused_node);
    IR_NODE_LINK_TO(fused_node, layer_norm_out);
    IR_NODE_LINK_TO(fused_node, elementwise1_out);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(preln_residual_bias_fuse_pass,
              paddle::framework::ir::PrelnResidualBiasFusePass);
REGISTER_PASS_CAPABILITY(preln_residual_bias_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("elementwise_add", 1)
            .EQ("layer_norm", 0));
paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h  (new file, 0 → 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

//            other_op2
//     |          |                   |           |
// other_op1   elementwise_add    other_op1   other_op2
//     |          |        fuse       \           /
//     |------elementwise_add    ->  preln_residual_bias
//         |          |                 |           |
//    other_op4   layer_norm       other_op4    other_op3
//                    |
//                other_op3

class Graph;

class PrelnResidualBiasFusePass : public FusePassBase {
 public:
  PrelnResidualBiasFusePass() {
    AddOpCompat(OpCompat("elementwise_add"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Y")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddAttr("axis")
        .IsIntIn({0, -1})
        .End();
    AddOpCompat(OpCompat("layer_norm"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Scale")
        .IsTensor()
        .End()
        .AddInput("Bias")
        .IsTensor()
        .End()
        .AddOutput("Y")
        .IsTensor()
        .End()
        .AddOutput("Mean")
        .IsTensor()
        .End()
        .AddOutput("Variance")
        .IsTensor()
        .End()
        .AddAttr("epsilon")
        .IsNumGE(0.0f)
        .IsNumLE(0.001f)
        .End()
        .AddAttr("begin_norm_axis")
        .IsNumGT(0)
        .End();
  }
  virtual ~PrelnResidualBiasFusePass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
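In equation form, the subgraph matched above and the fused preln_residual_bias op both compute (my notation, summarizing the pattern; not text from the commit):

\begin{aligned}
\mathrm{Out}_1 &= x + \bigl(y + b_{\mathrm{ele}}\bigr), \\
\mathrm{Out}_0 &= \mathrm{LayerNorm}_{\gamma,\beta,\varepsilon}\!\left(\mathrm{Out}_1\right),
\end{aligned}

where b_ele is the EleBias input, gamma and beta are the layer_norm Scale and Bias, and epsilon and begin_norm_axis are carried over as attributes. The Mean and Variance outputs of the original layer_norm are dropped by the fusion.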
paddle/fluid/inference/api/analysis_predictor.cc

@@ -1897,6 +1897,7 @@ USE_TRT_CONVERTER(elementwise_max_tensor);
 USE_TRT_CONVERTER(elementwise_min_tensor);
 USE_TRT_CONVERTER(elementwise_pow_tensor);
 USE_TRT_CONVERTER(transpose);
+USE_TRT_CONVERTER(transpose2);
 USE_TRT_CONVERTER(flatten);
 USE_TRT_CONVERTER(flatten_contiguous_range);
 USE_TRT_CONVERTER(matmul);
@@ -1945,6 +1946,7 @@ USE_TRT_CONVERTER(nearest_interp);
 USE_TRT_CONVERTER(nearest_interp_v2);
 USE_TRT_CONVERTER(bilinear_interp_v2);
 USE_TRT_CONVERTER(reshape);
+USE_TRT_CONVERTER(reshape2);
 USE_TRT_CONVERTER(reduce_sum);
 USE_TRT_CONVERTER(gather_nd);
 USE_TRT_CONVERTER(reduce_mean);
@@ -1956,6 +1958,8 @@ USE_TRT_CONVERTER(deformable_conv);
 USE_TRT_CONVERTER(pool3d)
 USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
 USE_TRT_CONVERTER(preln_skip_layernorm)
+USE_TRT_CONVERTER(preln_residual_bias)
+USE_TRT_CONVERTER(c_allreduce_sum)
 USE_TRT_CONVERTER(roll)
 USE_TRT_CONVERTER(strided_slice)
 USE_TRT_CONVERTER(transformer_input_convert)
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -97,10 +97,12 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "simplify_with_basic_ops_pass",                 //
       "trt_embedding_eltwise_layernorm_fuse_pass",    //
       "preln_embedding_eltwise_layernorm_fuse_pass",  //
+      "delete_c_identity_op_pass",                    //
       "trt_multihead_matmul_fuse_pass_v2",            //
       "trt_multihead_matmul_fuse_pass_v3",            //
       "trt_skip_layernorm_fuse_pass",                 //
       "preln_skip_layernorm_fuse_pass",               //
+      "preln_residual_bias_fuse_pass",                //
       // "set_transformer_input_convert_pass",        //
       "conv_bn_fuse_pass",                            //
       "unsqueeze2_eltwise_fuse_pass",                 //
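If one of the newly registered passes misbehaves for a particular model, it can be disabled per predictor through the pass-builder API; a hedged sketch, with the pass name taken from the list above and everything else illustrative:

#include "paddle_inference_api.h"  // paddle_infer

// Sketch: opt a single predictor out of the new fuse pass.
void DisablePrelnFuse(paddle_infer::Config* config) {
  config->pass_builder()->DeletePass("preln_residual_bias_fuse_pass");
}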
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

@@ -62,6 +62,8 @@ list(
     transformer_input_convert_op.cc
     remove_padding_op.cc
     recover_padding_op.cc
+    preln_residual_bias.cc
+    c_allreduce_op.cc
     top_k_op.cc
     squeeze2_op.cc
     unsqueeze2_op.cc)
paddle/fluid/inference/tensorrt/convert/c_allreduce_op.cc  (new file, 0 → 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

using ReduceType = paddle::inference::tensorrt::plugin::ReduceType;
std::map<std::string, ReduceType> op_to_reduce_type = {
    {"c_allreduce_sum", paddle::inference::tensorrt::plugin::kRedSum},
    {"c_allreduce_max", paddle::inference::tensorrt::plugin::kRedMax},
    {"c_allreduce_min", paddle::inference::tensorrt::plugin::kRedMin},
    {"c_allreduce_prod", paddle::inference::tensorrt::plugin::kRedProd}};

class CAllReduceOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert fluid c_allreduce op to tensorrt layer";
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "Unsupported static mode. Please set dynamic shape of inputs."));
    }
    ReduceType red_type = op_to_reduce_type[op.type()];
    std::string name = op.type();

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    int input_num = op_desc.Input("X").size();
    PADDLE_ENFORCE_EQ(input_num, 1,
                      platform::errors::InvalidArgument(
                          "The input X's size must equal to 1 in TRT "
                          "c_allreduce op. But received X's size %d.",
                          input_num));
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    // Get output
    size_t output_num = op_desc.Output("Out").size();
    PADDLE_ENFORCE_EQ(output_num, 1UL,
                      platform::errors::InvalidArgument(
                          "The output Out's size must equal to 1 in TRT "
                          "c_allreduce op. But received Out's size %u.",
                          output_num));
    // Get attrs
    int ring_id = BOOST_GET_CONST(int, op_desc.GetAttr("ring_id"));
    bool use_calc_stream =
        BOOST_GET_CONST(bool, op_desc.GetAttr("use_calc_stream"));
    nvinfer1::ILayer* layer = nullptr;
#if IS_TRT_VERSION_GE(6000)
    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
    if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
      with_fp16 = true;
    }
    plugin::CAllReducePluginDynamic* plugin =
        new plugin::CAllReducePluginDynamic(ring_id, use_calc_stream, red_type,
                                            with_fp16);
    layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "You are running the TRT Dynamic Shape mode, need to confirm that "
        "your TRT version is no less than 6.0"));
#endif
    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, name, {output_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(c_allreduce_sum, CAllReduceOpConverter);
REGISTER_TRT_OP_CONVERTER(c_allreduce_max, CAllReduceOpConverter);
REGISTER_TRT_OP_CONVERTER(c_allreduce_min, CAllReduceOpConverter);
REGISTER_TRT_OP_CONVERTER(c_allreduce_prod, CAllReduceOpConverter);
paddle/fluid/inference/tensorrt/convert/fc_op.cc

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
...
paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc  (new file, 0 → 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

using half = paddle::platform::float16;

class PrelnResidualBiasOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert fused preln_residual_bias op to tensorrt layer";
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "Unsupported static mode. Please set dynamic shape of inputs."));
    }
    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
    auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
    std::vector<nvinfer1::ITensor*> inputs;
    inputs.push_back(input1);
    inputs.push_back(input2);

    auto get_persistable_data = [&](const std::string& arg_name,
                                    framework::DDim* dims) -> float* {
      std::string var_name = op_desc.Input(arg_name).front();
      auto* temp_var = scope.FindVar(var_name);
      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
      (*dims) = temp_tensor->dims();
      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
      return temp_data;
    };

    framework::DDim bias_dims, scale_dims, ele_bias_dims;
    auto* bias = get_persistable_data("Bias", &bias_dims);
    auto* scale = get_persistable_data("Scale", &scale_dims);
    auto* ele_bias = get_persistable_data("EleBias", &ele_bias_dims);

    int bias_size = phi::product(bias_dims);
    int scale_size = phi::product(scale_dims);
    int ele_bias_size = phi::product(ele_bias_dims);
    float epsilon = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
    if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
      with_fp16 = true;
    }

    nvinfer1::ILayer* layer = nullptr;
    plugin::DynamicPluginTensorRT* plugin = nullptr;
    if (with_fp16) {
      auto half_ele_bias_data = new half[bias_size];
      for (int i = 0; i < bias_size; i++) {
        half_ele_bias_data[i] = static_cast<half>(ele_bias[i]);
      }
      plugin = new plugin::PrelnResidualBiasPluginDynamic(
          bias, scale, half_ele_bias_data, bias_size, scale_size,
          ele_bias_size, epsilon, with_fp16);
    } else {
      plugin = new plugin::PrelnResidualBiasPluginDynamic(
          bias, scale, ele_bias, bias_size, scale_size, ele_bias_size, epsilon,
          with_fp16);
    }

    std::vector<nvinfer1::ITensor*> plugin_inputs;
    plugin_inputs.emplace_back(input1);
    plugin_inputs.emplace_back(input2);
    layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
    std::vector<std::string> output_names;
    output_names.push_back(op_desc.Output("Out_0")[0]);
    output_names.push_back(op_desc.Output("Out_1")[0]);
    RreplenishLayerAndOutput(layer, "preln_residual_bias", output_names,
                             test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(preln_residual_bias, PrelnResidualBiasOpConverter);
paddle/fluid/inference/tensorrt/convert/reshape_op.cc

@@ -61,3 +61,4 @@ class ReshapeOpConverter : public OpConverter {
 }  // namespace paddle

 REGISTER_TRT_OP_CONVERTER(reshape, ReshapeOpConverter);
+REGISTER_TRT_OP_CONVERTER(reshape2, ReshapeOpConverter);
paddle/fluid/inference/tensorrt/convert/transpose_op.cc

@@ -60,3 +60,4 @@ class TransposeOpConverter : public OpConverter {
 }  // namespace paddle

 REGISTER_TRT_OP_CONVERTER(transpose, TransposeOpConverter);
+REGISTER_TRT_OP_CONVERTER(transpose2, TransposeOpConverter);
paddle/fluid/inference/tensorrt/op_teller.cc

@@ -156,6 +156,11 @@ struct SimpleOpTypeSetTeller : public Teller {
       "slice",
       "strided_slice",
       "fused_preln_embedding_eltwise_layernorm",
+      "preln_residual_bias",
+      "c_allreduce_sum",
+      "c_allreduce_min",
+      "c_allreduce_max",
+      "c_allreduce_prod",
       "roll",
       "preln_skip_layernorm",
       "transformer_input_convert",
@@ -254,6 +259,11 @@ struct SimpleOpTypeSetTeller : public Teller {
       "strided_slice",
       "fused_preln_embedding_eltwise_layernorm",
       "preln_skip_layernorm",
+      "preln_residual_bias",
+      "c_allreduce_sum",
+      "c_allreduce_min",
+      "c_allreduce_max",
+      "c_allreduce_prod",
       "roll",
       "multiclass_nms3",
       "transformer_input_convert",
@@ -1994,9 +2004,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
   return false;
 }

 OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); }

 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt

@@ -27,7 +27,9 @@ list(
     matmul_op_int8_plugin.cu
     transformer_input_convert_plugin.cu
     remove_padding_plugin.cu
-    recover_padding_plugin.cu)
+    recover_padding_plugin.cu
+    c_allreduce_op_plugin.cu
+    preln_residual_bias_plugin.cu)

 if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
   list(APPEND TRT_FILES spmm_plugin.cu)
paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu  (new file, 0 → 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstring>

#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h"
#include "paddle/fluid/platform/collective_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if defined(PADDLE_WITH_NCCL)
inline ncclDataType_t NvInferDtypeToNCCLDType(nvinfer1::DataType type) {
  if (type == nvinfer1::DataType::kFLOAT) {
    return ncclFloat;
  } else if (type == nvinfer1::DataType::kHALF) {
    return ncclFloat16;
  } else if (type == nvinfer1::DataType::kINT8) {
    return ncclInt8;
  } else if (type == nvinfer1::DataType::kINT32) {
    return ncclInt32;
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "This datatype in nccl is not supported."));
  }
}
#endif

CAllReducePluginDynamic::CAllReducePluginDynamic(void const* serialData,
                                                 size_t serialLength) {
  DeserializeValue(&serialData, &serialLength, &ring_id_);
  DeserializeValue(&serialData, &serialLength, &use_calc_stream_);
  DeserializeValue(&serialData, &serialLength, &red_type_);
  DeserializeValue(&serialData, &serialLength, &with_fp16_);
}

nvinfer1::IPluginV2DynamicExt* CAllReducePluginDynamic::clone() const
    TRT_NOEXCEPT {
  return new CAllReducePluginDynamic(ring_id_, use_calc_stream_, red_type_,
                                     with_fp16_);
}

const char* CAllReducePluginDynamic::getPluginType() const TRT_NOEXCEPT {
  return "c_allreduce_plugin_dynamic";
}

int CAllReducePluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }

int CAllReducePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

size_t CAllReducePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  // Must cover every field written by serialize() below, including
  // with_fp16_ (a stray semicolon in the original dropped that term).
  return SerializedSize(ring_id_) + SerializedSize(use_calc_stream_) +
         SerializedSize(red_type_) + SerializedSize(with_fp16_);
}

void CAllReducePluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {
  SerializeValue(&buffer, ring_id_);
  SerializeValue(&buffer, use_calc_stream_);
  SerializeValue(&buffer, red_type_);
  SerializeValue(&buffer, with_fp16_);
}

nvinfer1::DimsExprs CAllReducePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  return inputs[0];
}

bool CAllReducePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of CAllReduce plugin should not be nullptr."));
  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));
  const nvinfer1::PluginTensorDesc& in = in_out[pos];
  if (pos == 0 || pos == 1) {
    if (with_fp16_) {
      return (in.type == nvinfer1::DataType::kHALF) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    } else {
      return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
  }
}

void CAllReducePluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out,
    int nbOutputs) TRT_NOEXCEPT {}

size_t CAllReducePluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const
    TRT_NOEXCEPT {
  return 0;
}

void CAllReducePluginDynamic::destroy() TRT_NOEXCEPT { delete this; }

nvinfer1::DataType CAllReducePluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types, int nb_inputs) const
    TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The CAllReduce Plugin only has one input, so the "
                        "index value should be 0, but get %d.",
                        index));
  return input_types[0];
}

int CAllReducePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
#if defined(PADDLE_WITH_NCCL)
  auto input_dims = input_desc[0].dims;
  size_t numel = ProductDim(input_dims);

  auto input_type = input_desc[0].type;
  void* sendbuff = const_cast<void*>(inputs[0]);
  void* recvbuff = outputs[0];

  ncclDataType_t dtype = NvInferDtypeToNCCLDType(input_type);
  ncclRedOp_t nccl_red_type = ncclSum;
  switch (red_type_) {
    case kRedSum:
      nccl_red_type = ncclSum;
      break;
    case kRedMax:
      nccl_red_type = ncclMax;
      break;
    case kRedMin:
      nccl_red_type = ncclMin;
      break;
    case kRedProd:
      nccl_red_type = ncclProd;
      break;
    default:
      PADDLE_THROW(platform::errors::InvalidArgument("Invalid reduce type: %d",
                                                     red_type_));
  }

  auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);
  // Honor use_calc_stream_: run on the TRT stream or the comm stream
  // (the original passed `stream` unconditionally, leaving custream unused).
  cudaStream_t custream = use_calc_stream_ ? stream : comm->stream();
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
      sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(),
      custream));
#endif
  return (cudaGetLastError() != cudaSuccess);
}

const char* CAllReducePluginDynamicCreator::getPluginName() const
    TRT_NOEXCEPT {
  return "c_allreduce_plugin_dynamic";
}

const char* CAllReducePluginDynamicCreator::getPluginVersion() const
    TRT_NOEXCEPT {
  return "1";
}

nvinfer1::IPluginV2* CAllReducePluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  auto plugin = new CAllReducePluginDynamic(serial_data, serial_length);
  return plugin;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
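The enqueue path above resolves its communicator with platform::NCCLCommContext::Instance().Get(ring_id_), so a communicator must be registered under that ring id before the engine executes. A hedged sketch of that setup, assuming the usual collective_helper CreateComm interface; the wrapper function and its arguments are illustrative, not part of this commit:

#include "paddle/fluid/platform/collective_helper.h"

// Sketch: register the NCCL communicator the plugin will look up.
// Assumes `nccl_id` was produced by ncclGetUniqueId on rank 0 and shared
// with every rank out of band, as in the usual NCCL bootstrap.
void InitAllReduceRing(ncclUniqueId* nccl_id, int nranks, int rank,
                       int dev_id, int ring_id) {
  // After this call, NCCLCommContext::Instance().Get(ring_id) inside
  // CAllReducePluginDynamic::enqueue() returns a usable communicator.
  paddle::platform::NCCLCommContext::Instance().CreateComm(
      nccl_id, nranks, rank, dev_id, ring_id);
}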
paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h  (new file, 0 → 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdio.h>

#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };

class CAllReducePluginDynamic : public DynamicPluginTensorRT {
 private:
  int ring_id_;
  bool use_calc_stream_;
  ReduceType red_type_;

 public:
  explicit CAllReducePluginDynamic(const int ring_id,
                                   const bool use_calc_stream,
                                   const ReduceType red_type,
                                   const bool with_fp16) {
    ring_id_ = ring_id;
    use_calc_stream_ = use_calc_stream;
    red_type_ = red_type;
    with_fp16_ = with_fp16;
  }
  CAllReducePluginDynamic(void const* serialData, size_t serialLength);

  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
  const char* getPluginType() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  int initialize() TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void* buffer) const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const
      TRT_NOEXCEPT override;
  void destroy() TRT_NOEXCEPT override;
};

class CAllReducePluginDynamicCreator : public TensorRTPluginCreator {
 public:
  const char* getPluginName() const TRT_NOEXCEPT override;
  const char* getPluginVersion() const TRT_NOEXCEPT override;
  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override;
};
REGISTER_TRT_PLUGIN_V2(CAllReducePluginDynamicCreator);

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu  (new file, 0 → 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime.h>
#include <stdio.h>

#include <cassert>
#include <cub/cub.cuh>  // NOLINT
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h"
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

using half = phi::dtype::float16;

#if IS_TRT_VERSION_GE(6000)
int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT {
  cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
  cudaMemcpy(bias_gpu_, bias_.data(), bias_size_ * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
  cudaMemcpy(scale_gpu_, scale_.data(), scale_size_ * sizeof(float),
             cudaMemcpyHostToDevice);
  if (with_fp16_) {
    cudaMalloc(&ele_bias_gpu_, sizeof(half) * ele_bias_size_);
    cudaMemcpy(ele_bias_gpu_, fp16_ele_bias_.data(),
               ele_bias_size_ * sizeof(half), cudaMemcpyHostToDevice);
  } else {
    cudaMalloc(&ele_bias_gpu_, sizeof(float) * ele_bias_size_);
    cudaMemcpy(ele_bias_gpu_, fp32_ele_bias_.data(),
               ele_bias_size_ * sizeof(float), cudaMemcpyHostToDevice);
  }
  return 0;
}

void PrelnResidualBiasPluginDynamic::terminate() TRT_NOEXCEPT {
  if (bias_gpu_) {
    cudaFree(bias_gpu_);
    bias_gpu_ = nullptr;
  }
  if (scale_gpu_) {
    cudaFree(scale_gpu_);
    scale_gpu_ = nullptr;
  }
  if (ele_bias_gpu_) {
    cudaFree(ele_bias_gpu_);
    ele_bias_gpu_ = nullptr;
  }
}

nvinfer1::IPluginV2DynamicExt* PrelnResidualBiasPluginDynamic::clone() const
    TRT_NOEXCEPT {
  PrelnResidualBiasPluginDynamic* ptr = nullptr;
  if (with_fp16_) {
    ptr = new PrelnResidualBiasPluginDynamic(
        bias_.data(), scale_.data(), fp16_ele_bias_.data(), bias_size_,
        scale_size_, ele_bias_size_, eps_, with_fp16_);
  } else {
    ptr = new PrelnResidualBiasPluginDynamic(
        bias_.data(), scale_.data(), fp32_ele_bias_.data(), bias_size_,
        scale_size_, ele_bias_size_, eps_, with_fp16_);
  }
  ptr->bias_gpu_ = bias_gpu_;
  ptr->scale_gpu_ = scale_gpu_;
  ptr->ele_bias_gpu_ = ele_bias_gpu_;
  return ptr;
}

const char* PrelnResidualBiasPluginDynamic::getPluginType() const
    TRT_NOEXCEPT {
  return "preln_residual_bias_plugin_dynamic";
}

int PrelnResidualBiasPluginDynamic::getNbOutputs() const TRT_NOEXCEPT {
  return 2;
}

size_t PrelnResidualBiasPluginDynamic::getSerializationSize() const
    TRT_NOEXCEPT {
  size_t ser_size = SerializedSize(bias_) + SerializedSize(scale_) +
                    SerializedSize(fp32_ele_bias_) +
                    SerializedSize(fp16_ele_bias_) +
                    SerializedSize(bias_size_) + SerializedSize(scale_size_) +
                    SerializedSize(ele_bias_size_) + SerializedSize(eps_) +
                    SerializedSize(with_fp16_);
  return ser_size;
}

void PrelnResidualBiasPluginDynamic::serialize(void* buffer) const
    TRT_NOEXCEPT {
  SerializeValue(&buffer, bias_);
  SerializeValue(&buffer, scale_);
  SerializeValue(&buffer, fp32_ele_bias_);
  SerializeValue(&buffer, fp16_ele_bias_);
  SerializeValue(&buffer, bias_size_);
  SerializeValue(&buffer, scale_size_);
  SerializeValue(&buffer, ele_bias_size_);
  SerializeValue(&buffer, eps_);
  SerializeValue(&buffer, with_fp16_);
}

nvinfer1::DimsExprs PrelnResidualBiasPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  if (output_index < 2) {
    return inputs[0];
  } else {
    // moving mean and var
    nvinfer1::DimsExprs ret;
    ret.nbDims = 1;
    ret.d[0] = inputs[0].d[2];
    return ret;
  }
}

bool PrelnResidualBiasPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of preln_residual_bias plugin should not be nullptr."));
  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));
  const nvinfer1::PluginTensorDesc& in = in_out[pos];
  if (pos == 0) {
    if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
      return (in.type == nvinfer1::DataType::kHALF) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
#else
      PADDLE_THROW(
          platform::errors::Fatal("TRT plugin supported FP16 is not available "
                                  "while with_fp16 is set true."));
#endif
    } else {
      return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
  }
  const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
  if (pos == 1) {
    return in.type == prev.type && in.format == prev.format;
  }
  // output
  return in.type == prev.type && in.format == prev.format;
}

void PrelnResidualBiasPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nb_inputs,
    const nvinfer1::DynamicPluginTensorDesc* out,
    int nb_outputs) TRT_NOEXCEPT {}

size_t PrelnResidualBiasPluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nb_inputs,
    const nvinfer1::PluginTensorDesc* outputs, int nb_outputs) const
    TRT_NOEXCEPT {
  return 0;
}

nvinfer1::DataType PrelnResidualBiasPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types, int nb_inputs) const
    TRT_NOEXCEPT {
  return input_types[0];
}

void PrelnResidualBiasPluginDynamic::destroy() TRT_NOEXCEPT { delete this; }

int PrelnResidualBiasPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
  auto input_dims = input_desc[0].dims;
  int hidden = input_dims.d[2];
  const size_t rows =
      static_cast<size_t>(input_dims.d[0] * input_dims.d[1]);  // batch * seq_length
  const size_t cols = static_cast<size_t>(input_dims.d[2]);
  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. PrelnResidualBias-->fp32";
    const float* input1 = static_cast<const float*>(inputs[0]);
    const float* input2 = static_cast<const float*>(inputs[1]);
    uint64_t seed = 0;
    const float dropout_prob = 0.;
    const bool is_upscale_in_train = false;
    const bool is_test = true;
    const uint64_t increment = 0;
    const float epsilon = eps_;
    const float* src = input2;
    const float* residual = input1;
    const float* bias = static_cast<float*>(ele_bias_gpu_);
    const float* scale = scale_gpu_;
    const float* layernorm_bias = bias_gpu_;
    uint8_t* mask_data = nullptr;
    float* dst = static_cast<float*>(outputs[1]);
    float* layernorm_dst = static_cast<float*>(outputs[0]);
    float* mean = nullptr;
    float* var = nullptr;
    const int VecSize = 8;
    paddle::operators::FusedLayernormResidualDropoutBiasFunctor<
        float, uint8_t, VecSize, float, false>()(
        rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
        increment, epsilon, src, residual, bias, scale, layernorm_bias,
        mask_data, dst, layernorm_dst, mean, var, stream);
  } else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
    VLOG(1) << "TRT Plugin DataType selected. PrelnResidualBias-->fp16";
    const half* input1 = static_cast<const half*>(inputs[0]);
    const half* input2 = static_cast<const half*>(inputs[1]);
    uint64_t seed = 0;
    const float dropout_prob = 0.;
    const bool is_upscale_in_train = false;
    const bool is_test = true;
    const uint64_t increment = 0;
    const float epsilon = eps_;
    const half* src = input2;
    const half* residual = input1;
    const half* bias = static_cast<half*>(ele_bias_gpu_);
    const float* scale = scale_gpu_;
    const float* layernorm_bias = bias_gpu_;
    uint8_t* mask_data = nullptr;
    half* dst = static_cast<half*>(outputs[1]);
    half* layernorm_dst = static_cast<half*>(outputs[0]);
    float* mean = nullptr;
    float* var = nullptr;
    const int VecSize = 8;
    paddle::operators::FusedLayernormResidualDropoutBiasFunctor<
        half, uint8_t, VecSize, float, false>()(
        rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
        increment, epsilon, src, residual, bias, scale, layernorm_bias,
        mask_data, dst, layernorm_dst, mean, var, stream);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The Ernie(Bert) tensorRT plugin should be "
        "compiled with CUDA version >= 10.0 when running with fp16. "
        "Please recompile it or try to use fp32 by set "
        "config.SetTRTDynamicShapeInfo(min_input_shape, "
        "max_input_shape, opt_input_shape, true"));
#endif
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("The PrelnResidualBias TRT Plugin's input type "
                                "should be float or half."));
  }
  return cudaGetLastError() != cudaSuccess;
}

const char* PrelnResidualBiasPluginDynamicCreator::getPluginName() const
    TRT_NOEXCEPT {
  return "preln_residual_bias_plugin_dynamic";
}

const char* PrelnResidualBiasPluginDynamicCreator::getPluginVersion() const
    TRT_NOEXCEPT {
  return "1";
}

nvinfer1::IPluginV2* PrelnResidualBiasPluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  return new PrelnResidualBiasPluginDynamic(serial_data, serial_length);
}
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h
0 → 100644
浏览文件 @
007f3614
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
using
half
=
phi
::
dtype
::
float16
;
#if IS_TRT_VERSION_GE(6000)
class PrelnResidualBiasPluginDynamic : public DynamicPluginTensorRT {
 public:
  // Constructor for the fp16 elementwise-bias variant.
  explicit PrelnResidualBiasPluginDynamic(const float *bias,
                                          const float *scale,
                                          const half *ele_bias,
                                          int bias_size,
                                          int scale_size,
                                          int ele_bias_size,
                                          const float eps,
                                          bool with_fp16)
      : bias_size_(bias_size),
        scale_size_(scale_size),
        ele_bias_size_(ele_bias_size),
        eps_(eps) {
    with_fp16_ = with_fp16;
    bias_.resize(bias_size);
    scale_.resize(scale_size);
    fp16_ele_bias_.resize(ele_bias_size);
    std::copy(ele_bias, ele_bias + ele_bias_size, fp16_ele_bias_.data());
    std::copy(bias, bias + bias_size, bias_.data());
    std::copy(scale, scale + scale_size, scale_.data());
  }

  // Constructor for the fp32 elementwise-bias variant.
  explicit PrelnResidualBiasPluginDynamic(const float *bias,
                                          const float *scale,
                                          const float *ele_bias,
                                          int bias_size,
                                          int scale_size,
                                          int ele_bias_size,
                                          const float eps,
                                          bool with_fp16)
      : bias_size_(bias_size),
        scale_size_(scale_size),
        ele_bias_size_(ele_bias_size),
        eps_(eps) {
    with_fp16_ = with_fp16;
    bias_.resize(bias_size);
    scale_.resize(scale_size);
    fp32_ele_bias_.resize(ele_bias_size);
    std::copy(ele_bias, ele_bias + ele_bias_size, fp32_ele_bias_.data());
    std::copy(bias, bias + bias_size, bias_.data());
    std::copy(scale, scale + scale_size, scale_.data());
  }

  // Deserialization constructor: fields must be read back in the same
  // order in which serialize() wrote them.
  PrelnResidualBiasPluginDynamic(void const *serial_data,
                                 size_t serial_length) {
    DeserializeValue(&serial_data, &serial_length, &bias_);
    DeserializeValue(&serial_data, &serial_length, &scale_);
    DeserializeValue(&serial_data, &serial_length, &fp32_ele_bias_);
    DeserializeValue(&serial_data, &serial_length, &fp16_ele_bias_);
    DeserializeValue(&serial_data, &serial_length, &bias_size_);
    DeserializeValue(&serial_data, &serial_length, &scale_size_);
    DeserializeValue(&serial_data, &serial_length, &ele_bias_size_);
    DeserializeValue(&serial_data, &serial_length, &eps_);
    DeserializeValue(&serial_data, &serial_length, &with_fp16_);
  }
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  const char *getPluginType() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  int initialize() TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
      nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc *in_out,
                                 int nb_inputs,
                                 int nb_outputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
                       int nb_inputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nb_outputs) TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
                          int nb_inputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nb_outputs) const TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *input_desc,
              const nvinfer1::PluginTensorDesc *output_desc,
              const void *const *inputs, void *const *outputs,
              void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(
      int index, const nvinfer1::DataType *input_types,
      int nb_inputs) const TRT_NOEXCEPT override;
  void destroy() TRT_NOEXCEPT override;
  void terminate() TRT_NOEXCEPT override;
 private:
  std::vector<float> bias_;
  std::vector<float> scale_;
  std::vector<float> fp32_ele_bias_;
  std::vector<half> fp16_ele_bias_;

  float *bias_gpu_{nullptr};
  float *scale_gpu_{nullptr};
  void *ele_bias_gpu_{nullptr};

  int bias_size_;
  int scale_size_;
  int ele_bias_size_;
  float eps_;
  bool with_fp16_;
};
class PrelnResidualBiasPluginDynamicCreator : public TensorRTPluginCreator {
 public:
  const char *getPluginName() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  nvinfer1::IPluginV2 *deserializePlugin(const char *name,
                                         const void *serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override;
};

REGISTER_TRT_PLUGIN_V2(PrelnResidualBiasPluginDynamicCreator);
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
...
@@ -155,6 +155,109 @@ __global__ void FusedLayernormResidualDropoutBias(
      invvar);
}
/**
* @brief layernorm(residual + dropout(src + bias));
* @param
* rows: batch_size * seq_len
* cols: feature_size or hidden_size
* src: [rows, cols], inputs
* bias: [cols], linear bias, can be null
* residual:[rows, cols]
* mask: [rows, cols], dropout result
* dst: [rows, cols], residual + dropout(src+bias)
* layernorm_dst: [rows, cols], layernorm result
* layernorm_bias: [cols], layernorm bias, can be null
* scale: [cols]: layernorm scale, can be null
*/
template <typename T,
          typename MaskType,
          int VecSize,
          typename U,
          bool ScaleBiasWithSameTypeX = false>
__global__ void FusedLayernormResidualDropoutBiasInfer(
    const size_t rows, const size_t cols, uint64_t seed,
    const float dropout_prob, const bool is_upscale_in_train,
    const bool is_test, const uint64_t increment, const float epsilon,
    const T *src, const T *residual, const T *bias,
    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
    MaskType *mask, T *dst, T *layernorm_dst) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;
  int idx = row_id * cols + col_id;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
  T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);

  __shared__ U mean_share;
  __shared__ U var_share;
  __shared__ U shared_mean[32];
  __shared__ U shared_var[32];

  phi::funcs::ReluFunctor<T> relu;
  U mean_val = 0;
  U var_val = 0;
  for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
    FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, true, false,
                                      phi::funcs::ReluFunctor<T>>(
        row_id, i, cols, &state, dropout_prob, factor, src, residual, bias,
        dst, mask, is_test, &mean_val, &var_val, relu);
  }

  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
  var_val = BlockReduceSum<U>(var_val, shared_var);
  if (threadIdx.x == 0) {
    auto scale = static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
        static_cast<float>(1.) / static_cast<float>(cols));
    auto tmp = mean_val * static_cast<U>(scale);
    mean_share = static_cast<U>(tmp);
    var_share = static_cast<U>(var_val * static_cast<U>(scale) -
                               mean_share * mean_share);
    var_share = var_share > U(0) ? var_share : U(0);
  }
  __syncthreads();

  mean_val = mean_share;
  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));

  // calculate layernorm_dst
  CalcLayernormY<T, VecSize, U, ScaleBiasWithSameTypeX>(
      scale, layernorm_bias, dst, layernorm_dst, row_id, col_id, cols,
      mean_val, invvar);
}
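As a plain reference for the semantics documented in the comment above, here is a minimal NumPy sketch of layernorm(residual + dropout(src + bias)) with upscale-in-train dropout. It mirrors the row-wise mean/variance the kernel computes, but makes no attempt to match its vectorization or Philox RNG; all values below are illustrative:

import numpy as np

rows, cols, p, eps = 4, 8, 0.1, 1e-5   # rows = batch_size * seq_len
rng = np.random.default_rng(0)
src = rng.standard_normal((rows, cols)).astype(np.float32)
residual = rng.standard_normal((rows, cols)).astype(np.float32)
bias = rng.standard_normal(cols).astype(np.float32)   # linear bias
scale = np.ones(cols, np.float32)                     # layernorm scale
ln_bias = np.zeros(cols, np.float32)                  # layernorm bias

mask = rng.random((rows, cols)) >= p                  # kept elements
dst = residual + np.where(mask, (src + bias) / (1 - p), 0.0)
mean = dst.mean(axis=1, keepdims=True)
var = dst.var(axis=1, keepdims=True)                  # E[x^2] - mean^2
layernorm_dst = (dst - mean) / np.sqrt(var + eps) * scale + ln_bias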
template <typename T,
          typename MaskType,
          int VecSize,
          typename U,
          bool ScaleBiasWithSameTypeX = false>
struct FusedLayernormResidualDropoutBiasFunctor {
  void operator()(const size_t rows,
                  const size_t cols,
                  uint64_t seed,
                  const float dropout_prob,
                  const bool is_upscale_in_train,
                  const bool is_test,
                  const uint64_t increment,
                  const float epsilon,
                  const T *src,
                  const T *residual,
                  const T *bias,
                  const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
                  const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>
                      *layernorm_bias,
                  MaskType *mask,
                  T *dst,
                  T *layernorm_dst,
                  LayerNormParamType<T> *mean,
                  LayerNormParamType<T> *var,
                  cudaStream_t stream) {
    int blockDim = GetDesiredBlockDim(cols / VecSize);
    if (mean != nullptr && var != nullptr) {
      FusedLayernormResidualDropoutBias<T, MaskType, VecSize, U,
                                        ScaleBiasWithSameTypeX>
          <<<rows, blockDim, 0, stream>>>(
              rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
              increment, epsilon, src, residual, bias, scale, layernorm_bias,
              mask, dst, layernorm_dst, mean, var);
    } else {
      // Inference-only path: no mean/var buffers are requested, so launch
      // the kernel variant that does not write them out.
      FusedLayernormResidualDropoutBiasInfer<T, MaskType, VecSize, U,
                                             ScaleBiasWithSameTypeX>
          <<<rows, blockDim, 0, stream>>>(
              rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
              increment, epsilon, src, residual, bias, scale, layernorm_bias,
              mask, dst, layernorm_dst);
    }
  }
};
template struct FusedLayernormResidualDropoutBiasFunctor<
    paddle::platform::float16, uint8_t, 8, float, false>;
/*
 * @brief layernorm(residual + dropout(x));
 * Conditions:
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
...
@@ -231,7 +231,7 @@ if(NOT WITH_DISTRIBUTE OR WIN32)
  list(REMOVE_ITEM TEST_OPS test_fleet_rolemaker_2)
  list(REMOVE_ITEM TEST_OPS test_fleet_utils)
  list(REMOVE_ITEM TEST_OPS test_collective_cpu_barrier_with_gloo)
  list(REMOVE_ITEM TEST_OPS test_delete_c_identity_op_pass)
  # TODO: Fix these unittests that fail on Windows
  list(REMOVE_ITEM TEST_OPS test_fake_init_op)
endif()
...
@@ -244,6 +244,7 @@ endif()
if(WIN32)
  list(REMOVE_ITEM TEST_OPS test_complex_matmul)
  list(REMOVE_ITEM TEST_OPS test_ops_nms)
  list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
endif()

list(REMOVE_ITEM TEST_OPS test_fleet_checkpoint)
...
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
...
@@ -16,6 +16,18 @@ file(
  "test_trt_convert_*.py")
string(REPLACE ".py" "" TEST_TRT_CONVERTER "${TEST_TRT_CONVERTER}")

if(NOT WITH_DISTRIBUTE)
  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_delete_c_identity_op_pass")
  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
       "test_trt_convert_preln_residual_bias")
  list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_bias")
  list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_bias")
  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_c_allreduce")
  list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_c_allreduce")
  list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_c_allreduce")
endif()

# Only for cpu(mkl + openblas)
set(TEST_INFERENCE_CPU_UT "test_mul_lstm_fuse_pass" "test_mul_gru_fuse_pass")
...
python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
...
@@ -605,7 +605,7 @@ class TrtLayerAutoScanTest(AutoScanTest):
             dic['use_trt'] = False
         return str(dic)

-    def run_test(self, quant=False, *args, **kwargs):
+    def run_test(self, quant=False, skip_baseline=False, *args, **kwargs):
         status = True
         run_flags = []
         for prog_config in self.sample_program_configs(*args, **kwargs):
...
@@ -636,14 +636,14 @@ class TrtLayerAutoScanTest(AutoScanTest):
             }
             results: List[Dict[str, np.ndarray]] = []

-            # baseline: gpu run
-            logging.info('RUN program_config: ' + str(prog_config))
-            gpu_config = self.create_inference_config(use_trt=False)
-            results.append(
-                self.run_test_config(model, params, prog_config, gpu_config,
-                                     feed_data))
-            self.success_log('RUN_GPU_BASELINE done')
+            if not skip_baseline:
+                # baseline: gpu run
+                logging.info('RUN program_config: ' + str(prog_config))
+                gpu_config = self.create_inference_config(use_trt=False)
+                results.append(
+                    self.run_test_config(model, params, prog_config,
+                                         gpu_config, feed_data))
+                self.success_log('RUN_GPU_BASELINE done')

             for pred_config, nodes_num, threshold in self.sample_predictor_configs(
                     prog_config):
...
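A minimal sketch of how a converter test could use the new flag (the subclass name is hypothetical, not from this commit): passing skip_baseline=True makes run_test() bypass the baseline GPU run shown above and exercise only the configs yielded by sample_predictor_configs().

class MyTrtConverterTest(TrtLayerAutoScanTest):  # hypothetical subclass

    def test(self):
        # Skip the fp32 GPU baseline; compare the TRT configs against
        # whatever reference the test provides itself.
        self.run_test(skip_baseline=True)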
python/paddle/fluid/tests/unittests/ir/inference/program_config.py
...
@@ -226,6 +226,7 @@ def create_fake_model(program_config):
        var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR)
        var_desc.set_dtype(convert_np_dtype_to_dtype_(tensor_config.dtype))
        var_desc.set_shape(tensor_config.shape)
+       print(f"name: {name}; shape: {tensor_config.shape}")
        var_desc.set_need_check_feed(True)
        if tensor_config.lod is not None:
            var_desc.set_lod_level(len(tensor_config.lod))
...
@@ -323,6 +324,7 @@ def create_fake_model(program_config):
    with fluid.scope_guard(scope):
        executor.run(util_program)
        params = scope.find_var("out_var_0").get_bytes()
    return model, params
...
python/paddle/fluid/tests/unittests/ir/inference/test_delete_c_identity_op_pass.py (new file)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import paddle.inference as paddle_infer
import unittest
import hypothesis.strategies as st


class TestDeleteCIdentityPass(PassAutoScanTest):

    def sample_predictor_configs(self, program_config):
        config = self.create_trt_inference_config()
        config.enable_tensorrt_engine(
            max_batch_size=8,
            workspace_size=0,
            min_subgraph_size=0,
            precision_mode=paddle_infer.PrecisionType.Float32,
            use_static=False,
            use_calib_mode=False)
        yield config, ['relu'], (1e-5, 1e-5)

    def sample_program_config(self, draw):
        n = draw(st.integers(min_value=1, max_value=2))
        relu_op = OpConfig("relu",
                           inputs={"X": ["relu_x"]},
                           outputs={"Out": ["relu_out"]})
        c_identity_op = OpConfig("c_identity",
                                 inputs={"X": ["relu_out"]},
                                 outputs={"Out": ["id_out"]})
        program_config = ProgramConfig(
            ops=[relu_op, c_identity_op],
            weights={},
            inputs={"relu_x": TensorConfig(shape=[n])},
            outputs=["id_out"])
        return program_config

    def test(self):
        self.run_and_statis(max_examples=2,
                            min_success_num=2,
                            passes=["delete_c_identity_op_pass"])


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_c_allreduce_infer_script.py (new file)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import tempfile
import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed import ReduceOp
from paddle.distributed import init_parallel_env
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.inference import PrecisionType
from paddle.fluid import core


def run(op_type, precision):
    fleet.init(is_collective=True)
    paddle.enable_static()
    main_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    block = main_program.blocks[0]
    with paddle.static.program_guard(main_program, startup_program):
        data = paddle.static.data(name='data', shape=[3, 4], dtype='float32')
        c_data = block.create_var(
            shape=data.shape,
            dtype=data.dtype,
            type=data.type,
            lod_level=data.lod_level,
            persistable=False,
            is_data=False,
            initializer=paddle.nn.initializer.Constant(value=1.0))
        block.append_op(type=op_type,
                        inputs={'X': data},
                        outputs={'Out': c_data},
                        attrs={
                            'ring_id': 0,
                            'use_calc_stream': True,
                            'use_model_parallel': True
                        })
        out = paddle.static.nn.fc(
            x=c_data,
            size=1,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.5)))
        mean = paddle.mean(out)
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_program)

    nranks = 2
    current_endpoint = "127.0.0.1:600" + str(fleet.worker_index())
    trainer_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]
    dist_config = core.DistConfig()
    dist_config.set_carrier_id("inference")
    dist_config.set_endpoints(trainer_endpoints, current_endpoint)
    dist_config.set_ranks(nranks, fleet.worker_index())
    dist_config.enable_dist_model(True)

    with tempfile.TemporaryDirectory(prefix="allreduce_") as tmpdir:
        paddle.static.save_inference_model(os.path.join(tmpdir, "model"),
                                           [data], [mean],
                                           exe,
                                           program=main_program)
        config = Config(os.path.join(tmpdir, "model.pdmodel"),
                        os.path.join(tmpdir, "model.pdiparams"))
        config.enable_memory_optim()
        config.enable_use_gpu(1000, fleet.worker_index())
        config.set_dist_config(dist_config)
        config.enable_tensorrt_engine(
            workspace_size=1 << 30,
            max_batch_size=1,
            min_subgraph_size=1,
            precision_mode=PrecisionType.Half
            if precision == "fp16" else PrecisionType.Int8,
            use_static=False,
            use_calib_mode=False)
        config.set_trt_dynamic_shape_info({"data": [3, 4]}, {"data": [3, 4]},
                                          {"data": [3, 4]})
        predictor = create_predictor(config)
        input_names = predictor.get_input_names()
        input_tensor = predictor.get_input_handle("data")
        input_tensor.reshape([3, 4])
        input_tensor.copy_from_cpu(np.ones([3, 4]).astype(np.float32))
        predictor.run()
        output_names = predictor.get_output_names()
        output_handle = predictor.get_output_handle(output_names[0])
        output_data = output_handle.copy_to_cpu()  # a numpy.ndarray
        print(f"c_allreduce_out={output_data[0]}")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # This script is only meant to be launched by
        # test_trt_convert_c_allreduce.py.
        sys.exit(0)
    op_type = sys.argv[1]
    precision = sys.argv[2]
    run(op_type, precision)
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_c_allreduce.py (new file)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
import pickle
import os
import unittest
import paddle


class TestDistTRT(unittest.TestCase):

    def setUp(self):
        self.init_case()
        self.script = "test_trt_c_allreduce_infer_script.py"

    def init_case(self):
        self.op_type = "c_allreduce_sum"
        self.target_value = 4.
        self.precision = "fp16"

    def test_run(self):
        env = dict(os.environ)
        env["CUDA_VISIBLE_DEVICES"] = "0,1"
        cmd = f"python -u -m paddle.distributed.fleet.launch --gpus 0,1 {self.script} {self.op_type} {self.precision}"
        cmd = cmd.split(" ")
        local_proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)
        local_out, local_err = local_proc.communicate()
        for line in local_out.decode("utf-8").split("\n"):
            results = line.split("=")
            if len(results) == 2 and results[0] == "c_allreduce_out":
                self.assertEqual(float(results[1]), self.target_value)


class TestMin(TestDistTRT):

    def init_case(self):
        self.op_type = "c_allreduce_min"
        self.target_value = 2.
        self.precision = "int8"


#class TestMax(TestDistTRT):
#
#    def init_case(self):
#        self.op_type = "c_allreduce_max"
#        self.target_value = 2.
#        self.precision = "fp16"
#
#
#class TestProd(TestDistTRT):
#
#    def init_case(self):
#        self.op_type = "c_allreduce_prod"
#        self.target_value = 2.
#        self.precision = "fp16"

if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
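The target values follow directly from the constants in the infer script above: each of the two ranks feeds an all-ones [3, 4] tensor, so c_allreduce_sum yields 2 everywhere and c_allreduce_min yields 1; the size-1 fc layer with constant weight 0.5 then maps every row to 4 * 2 * 0.5 = 4.0 (sum) or 4 * 1 * 0.5 = 2.0 (min), which is also the final mean. A quick NumPy check:

import numpy as np

data = np.ones((3, 4), np.float32)       # per-rank input
w = np.full((4, 1), 0.5, np.float32)     # fc weight: size=1, constant 0.5
print(float(((2 * data) @ w).mean()))    # 4.0 -> TestDistTRT.target_value
print(float(((1 * data) @ w).mean()))    # 2.0 -> TestMin.target_value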
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py (new file)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest


class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        inputs = program_config.inputs
        weights = program_config.weights
        outputs = program_config.outputs
        attrs = [
            program_config.ops[i].attrs
            for i in range(len(program_config.ops))
        ]
        # The rank of the input must be greater than begin_norm_axis.
        if 'begin_norm_axis' in attrs[0] and attrs[0]['begin_norm_axis'] >= 0:
            if len(inputs['inputX_data'].shape) <= attrs[0]['begin_norm_axis']:
                return False
        return True

    def sample_program_configs(self):

        def generate_input1(attrs: List[Dict[str, Any]], batch):
            return np.ones([batch, 128, 768]).astype(np.float32)

        def generate_input2(attrs: List[Dict[str, Any]], batch):
            return np.ones([batch, 128, 768]).astype(np.float32)

        def generate_weight1(attrs: List[Dict[str, Any]]):
            return np.random.random([768]).astype(np.float32)

        def generate_weight2(attrs: List[Dict[str, Any]]):
            return np.random.random([768]).astype(np.float32)

        for batch in [4]:
            for epsilon in [1e-5]:
                for begin_norm_axis in [2]:
                    for enable_int8 in [False, True]:
                        dics = [{
                            "epsilon": epsilon,
                            "begin_norm_axis": begin_norm_axis,
                        }, {}]
                        ops_config = [{
                            "op_type": "elementwise_add",
                            "op_inputs": {
                                "X": ["inputX_data"],
                                "Y": ["EleBias"]
                            },
                            "op_outputs": {
                                "Out": ["bias_out"]
                            },
                            "op_attrs": {
                                "axis": -1
                            }
                        }, {
                            "op_type": "elementwise_add",
                            "op_inputs": {
                                "X": ["bias_out"],
                                "Y": ["inputY_data"]
                            },
                            "op_outputs": {
                                "Out": ["ele_out"]
                            },
                            "op_attrs": {
                                "axis": -1
                            }
                        }, {
                            "op_type": "layer_norm",
                            "op_inputs": {
                                "X": ["ele_out"],
                                "Bias": ["Bias"],
                                "Scale": ["Scale"]
                            },
                            "op_outputs": {
                                "Y": ["layernorm_out"],
                                "Mean": ["Mean"],
                                "Variance": ["Variance"]
                            },
                            "op_attrs": dics[0]
                        }]
                        ops = self.generate_op_config(ops_config)
                        program_config = ProgramConfig(
                            ops=ops,
                            weights={
                                "Bias":
                                TensorConfig(
                                    data_gen=partial(generate_weight1, dics)),
                                "Scale":
                                TensorConfig(
                                    data_gen=partial(generate_weight2, dics)),
                                "EleBias":
                                TensorConfig(
                                    data_gen=partial(generate_weight2, dics))
                            },
                            inputs={
                                "inputX_data":
                                TensorConfig(data_gen=partial(
                                    generate_input1, dics, batch)),
                                "inputY_data":
                                TensorConfig(data_gen=partial(
                                    generate_input2, dics, batch))
                            },
                            outputs=["ele_out", "layernorm_out"])

                        yield program_config

    def sample_predictor_configs(
            self, program_config) -> (paddle_infer.Config, List[int], float):

        def generate_dynamic_shape(attrs):
            self.dynamic_shape.min_input_shape = {
                "inputX_data": [4, 128, 768],
                "inputY_data": [4, 128, 768],
                "Bias": [768],
                "Scale": [768]
            }
            self.dynamic_shape.max_input_shape = {
                "inputX_data": [4, 128, 768],
                "inputY_data": [4, 128, 768],
                "Bias": [768],
                "Scale": [768]
            }
            self.dynamic_shape.opt_input_shape = {
                "inputX_data": [4, 128, 768],
                "inputY_data": [4, 128, 768],
                "Bias": [768],
                "Scale": [768]
            }

        def clear_dynamic_shape():
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            return 1, 4

        attrs = [
            program_config.ops[i].attrs
            for i in range(len(program_config.ops))
        ]

        # only dynamic_shape is supported
        generate_dynamic_shape(attrs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True), 1e-2  # atol=1e-2 while rtol is 1e-8
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True), 1e-2  # atol=1e-2 while rtol is 1e-8

    def add_skip_trt_case(self):
        pass

    def test(self):
        self.add_skip_trt_case()
        self.run_test()


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_skip_layernorm.py
...
@@ -190,15 +190,15 @@ class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
         attrs = [
             program_config.ops[i].attrs
             for i in range(len(program_config.ops))
         ]

-        # for static_shape
-        clear_dynamic_shape()
-        self.trt_param.precision = paddle_infer.PrecisionType.Float32
-        yield self.create_inference_config(), generate_trt_nodes_num(
-            attrs, False), 1e-5
-        self.trt_param.precision = paddle_infer.PrecisionType.Half
-        yield self.create_inference_config(), generate_trt_nodes_num(
-            attrs, False), 1e-5
+        # # for static_shape
+        # clear_dynamic_shape()
+        # self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        # yield self.create_inference_config(), generate_trt_nodes_num(
+        #     attrs, False), 1e-5
+        # self.trt_param.precision = paddle_infer.PrecisionType.Half
+        # yield self.create_inference_config(), generate_trt_nodes_num(
+        #     attrs, False), 1e-5

         # for dynamic_shape
         generate_dynamic_shape(attrs)
...
python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py (new file)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from pass_test import PassTest
import paddle


class PrelnResidualBiasFusePassTest(PassTest):

    def setUp(self):
        paddle.enable_static()
        with paddle.static.program_guard(self.main_program,
                                         self.startup_program):
            x = paddle.static.data(name="x",
                                   shape=[128, 768],
                                   dtype="float32",
                                   lod_level=0)
            bias = paddle.static.create_parameter(shape=[768], dtype='float32')
            y = paddle.static.data(name="y",
                                   shape=[128, 768],
                                   dtype="float32",
                                   lod_level=0)
            x = x + bias
            elementwise_out = x + y
            out = paddle.static.nn.layer_norm(input=elementwise_out)

        self.fetch_list = [out, elementwise_out]
        self.pass_names = "preln_residual_bias_fuse_pass"
        self.fused_op_type = "preln_residual_bias"
        self.num_fused_ops = 1
        # self.graph_attrs = {
        #     "embedding_eltwise_layernorm_fuse_pass_flag": True,
        #     "multihead_matmul_fuse_pass_flag": True
        # }

    def test_check_program(self):
        use_gpu_set = [False]
        if paddle.device.is_compiled_with_cuda():
            use_gpu_set.append(True)
        for use_gpu in use_gpu_set:
            place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
            opt_program = self._apply_ir_passes()
            self.check_program(opt_program)


if __name__ == "__main__":
    unittest.main()