Commit 28ea9aad (unverified)
Authored on Dec 14, 2022 by Yuanle Liu; committed via GitHub on Dec 14, 2022
[Paddle Inference] rewrite convert_to_mixed_precision (#48853)
Parent: b9fad5da
Showing 15 changed files with 324 additions and 950 deletions (+324 −950)
Changed files:
  paddle/fluid/framework/ir/CMakeLists.txt                              +1    -1
  paddle/fluid/framework/ir/auto_mixed_precision_pass.cc              +163  -159
  paddle/fluid/framework/ir/auto_mixed_precision_pass.h                +26   -12
  paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc       +1    -1
  paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc       +1    -1
  paddle/fluid/inference/analysis/argument.h                            +1    -1
  paddle/fluid/inference/analysis/ir_pass_manager.cc                    +8    -7
  paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc   +8    -8
  paddle/fluid/inference/analysis/passes/CMakeLists.txt                 +1    -1
  paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc +64  -740
  paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h  +39    -9
  paddle/fluid/inference/api/analysis_config.cc                         +4    -4
  paddle/fluid/inference/api/analysis_predictor.cc                      +4    -4
  paddle/fluid/inference/api/paddle_analysis_config.h                   +1    -1
  paddle/fluid/inference/api/paddle_pass_builder.cc                     +2    -1
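For orientation before the per-file diffs: the rewritten converter is driven through the ConvertToMixedPrecisionPass constructor and its Run() method. The sketch below is illustrative only, inferred from the constructor signature and parameter order visible in the convert_to_mixed_precision.cc hunks further down; the file paths and black-list contents are made-up placeholders, and the exact header locations may differ by Paddle version.

// Hedged usage sketch; not part of the commit itself.
#include <string>
#include <unordered_set>

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

int main() {
  // Ops listed here are forced to stay at fp32 (the pass's "black list").
  std::unordered_set<std::string> black_list{"softmax"};  // placeholder entry

  paddle::inference::analysis::ConvertToMixedPrecisionPass pass(
      "model.pdmodel",         // model_file (placeholder path)
      "model.pdiparams",       // params_file
      "mixed.pdmodel",         // mixed_model_file (output)
      "mixed.pdiparams",       // mixed_params_file (output)
      phi::DataType::FLOAT16,  // mixed_precision: fp16 or bf16 only
      phi::Backend::GPU,       // backend: only GPU is supported
      true,                    // keep_io_types: feed/fetch stay fp32
      black_list);
  pass.Run();  // load the model, rewrite dtypes/weights, save the mixed model
  return 0;
}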
paddle/fluid/framework/ir/CMakeLists.txt
@@ -103,7 +103,7 @@ pass_library(delete_c_identity_op_pass inference)
 pass_library(preln_residual_bias_fuse_pass inference)
 pass_library(delete_fill_constant_op_pass inference)
 pass_library(constant_folding_pass inference)
-pass_library(float_to_half_pass inference)
+pass_library(auto_mixed_precision_pass inference)
 pass_library(conv2d_fusion_layout_transfer_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
paddle/fluid/framework/ir/float_to_half_pass.cc → paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/framework/ir/float_to_half_pass.h"
+#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
 
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/operator.h"
@@ -29,7 +29,7 @@ namespace ir {
 
 namespace {
 
-using VarType = FloatToHalfPass::VarType;
+using VarType = AutoMixedPrecisionPass::VarType;
 
 bool PhiKernelSupportPrecision(
     const std::string& op_type,
@@ -71,6 +71,23 @@ bool GpuKernelSupportPrecision(
   return support;
 }
 
+inline bool VarNodeHasDtype(Node* var_node) {
+  auto type = var_node->Var()->GetType();
+  return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
+         (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
+         (type == VarType::VOCAB);
+}
+
+inline bool IsFloatType(VarType::Type type) {
+  return (type == VarType::FP64) || (type == VarType::FP32);
+}
+
+inline bool IsHalfType(VarType::Type type) {
+  return (type == VarType::FP16) || (type == VarType::BF16);
+}
+
+};  // namespace
+
 void DoInsertCastOp(Graph* graph,
                     Node* var_node,
                     Node* op_node,
@@ -123,27 +140,26 @@ void DoInsertCastOp(Graph* graph,
   IR_NODE_UNLINK(var_node, op_node);
 }
 
-inline bool VarNodeHasDtype(Node* var_node) {
-  auto type = var_node->Var()->GetType();
-  return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
-         (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
-         (type == VarType::VOCAB);
-}
-
-inline bool IsFloatType(VarType::Type type) {
-  return (type == VarType::FP64) || (type == VarType::FP32);
-}
-
-inline bool IsHalfType(VarType::Type type) {
-  return (type == VarType::FP16) || (type == VarType::BF16);
-}
-
-};  // namespace
+bool OpSupportPrecision(const std::string& op_type,
+                        phi::Backend backend,
+                        phi::DataType precision,
+                        const std::unordered_set<std::string>& black_list) {
+  bool support = false;
+  if (black_list.count(op_type) == 0) {
+    if (backend == phi::Backend::GPU) {
+      support = GpuKernelSupportPrecision(op_type, precision);
+    } else {
+      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+          "Now, only support backend of GPU."));
+    }
+  }
+  return support;
+}
 
 // The set of ops that support fp16 calculation and are considered
 // numerically-dangerous, slower and whose effects may also be observed in
 // downstream ops.
-void FloatToHalfPass::SetDefaultBlacklist() const {
+void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
   black_list_.insert({
       // numerically-dangerous
       "acos",
@@ -175,12 +191,27 @@ void FloatToHalfPass::SetDefaultBlacklist() const {
   });
 }
 
-void FloatToHalfPass::Init(Graph* graph) const {
-  keep_io_types_ = true;
-  half_precision_ =
-      static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
+void AutoMixedPrecisionPass::Init(Graph* graph) const {
+  bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
+  if (enable_gpu_mixed) {
+    backend_ = phi::Backend::GPU;
+  }
+  skip_pass_ = !enable_gpu_mixed;
+
+  low_precision_ =
+      static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
+
   black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
   SetDefaultBlacklist();
+  VLOG(4) << "black_list has ";
+  for (const auto& name : black_list_) {
+    VLOG(4) << " - " << name;
+  }
+
+  keep_io_types_ = true;
+  if (Has("keep_io_types")) {
+    keep_io_types_ = Get<bool>("keep_io_types");
+  }
 
   auto graph_size = graph->SubGraphsSize();
   VLOG(4) << "graph size: " << graph_size;
@@ -204,24 +235,27 @@ void FloatToHalfPass::Init(Graph* graph) const {
     }
   }
 }
 
-void FloatToHalfPass::ApplyImpl(Graph* graph) const {
-  auto enable_gpu_half = Get<bool>("enable_gpu_half");
-  if (!enable_gpu_half) return;
-
-  PADDLE_ENFORCE_NOT_NULL(
-      graph,
-      platform::errors::PreconditionNotMet(
-          "During the float to half pass, the graph should not be nullptr."));
+void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::PreconditionNotMet(
+                              "During the auto_mixed_precision_pass, the graph "
+                              "should not be nullptr."));
   PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
                     true,
                     platform::errors::PreconditionNotMet(
-                        "During the float to half pass, the graph "
-                        "should be main graph."));
+                        "During the auto_mixed_precision_pass, the graph "
+                        "should be main graph."));
 
-  FusePassBase::Init("float_to_half", graph);
+  FusePassBase::Init("auto_mixed_precision", graph);
 
   Init(graph);
   VLOG(4) << "Init done";
+  if (skip_pass_) {
+    VLOG(3) << "Skip auto_mixed_precision_pass.";
+    return;
+  }
   SetOpUniqueType();
   VLOG(4) << "SetOpUniqueType done";
   GetOpPrecision();
@@ -240,19 +274,7 @@ void FloatToHalfPass::ApplyImpl(Graph* graph) const {
   VLOG(4) << "RestoreOpOriginType done";
 }
 
-bool FloatToHalfPass::OpSupportPrecision(const std::string& op_type,
-                                         phi::DataType precision,
-                                         phi::Backend backend) const {
-  bool support = false;
-  if (black_list_.count(op_type) == 0) {
-    if (backend == phi::Backend::GPU) {
-      support = GpuKernelSupportPrecision(op_type, precision);
-    }
-  }
-  return support;
-}
-
-void FloatToHalfPass::SetOpUniqueType() const {
+void AutoMixedPrecisionPass::SetOpUniqueType() const {
   int suffix = 0;
   for (const auto& nodes : all_op_nodes_) {
     for (auto* op_node : nodes) {
@@ -269,7 +291,7 @@ void FloatToHalfPass::SetOpUniqueType() const {
   }
 }
 
-void FloatToHalfPass::RestoreOpOriginType() const {
+void AutoMixedPrecisionPass::RestoreOpOriginType() const {
   for (const auto& nodes : all_op_nodes_) {
     for (auto* op_node : nodes) {
       auto op_type = op_node->Op()->Type();
@@ -281,7 +303,7 @@ void FloatToHalfPass::RestoreOpOriginType() const {
   }
 }
 
-inline std::string FloatToHalfPass::GetOpOriginalType(
+inline std::string AutoMixedPrecisionPass::GetOpOriginalType(
     const std::string& op_type) const {
   if (op_original_type_.count(op_type)) {
     return op_original_type_.at(op_type);
@@ -289,22 +311,21 @@ inline std::string FloatToHalfPass::GetOpOriginalType(
   return op_type;
 }
 
-void FloatToHalfPass::ProcessOpWithDtypeAttr() const {
+void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
   for (const auto& nodes : all_op_nodes_) {
     for (auto* op_node : nodes) {
       auto op_type = op_node->Op()->Type();
-      if (op_run_half_.count(op_type) == 0) continue;
+      if (op_run_low_precision_.count(op_type) == 0) continue;
 
       if (op_node->Op()->HasAttr("dtype")) {
         auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
         if (IsFloatType(static_cast<VarType::Type>(dtype))) {
           op_node->Op()->SetAttr(
               "dtype",
               static_cast<int>(
-                  framework::TransToProtoVarType(half_precision_)));
+                  framework::TransToProtoVarType(low_precision_)));
           op_node->Op()->Flush();
           VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
-                  << " --->" << static_cast<int>(half_precision_) << " )";
+                  << " --->" << static_cast<int>(low_precision_) << " )";
         }
       }
       if (op_node->Op()->HasAttr("out_dtype")) {
@@ -312,11 +333,10 @@ void FloatToHalfPass::ProcessOpWithDtypeAttr() const {
         if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
           op_node->Op()->SetAttr(
               "out_dtype",
               static_cast<int>(
-                  framework::TransToProtoVarType(half_precision_)));
+                  framework::TransToProtoVarType(low_precision_)));
           op_node->Op()->Flush();
           VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
-                  << out_dtype << " --->" << static_cast<int>(half_precision_)
-                  << " )";
+                  << out_dtype << " --->" << static_cast<int>(low_precision_)
+                  << " )";
         }
       }
@@ -324,37 +344,39 @@ void FloatToHalfPass::ProcessOpWithDtypeAttr() const {
     }
   }
 }
 
-void FloatToHalfPass::GetOpPrecision() const {
+void AutoMixedPrecisionPass::GetOpPrecision() const {
   for (const auto& nodes : all_op_nodes_) {
     for (auto* op_node : nodes) {
       auto op_type = op_node->Op()->Type();
-      bool support_half = true;
+      bool support_low_precision = true;
       if (GetOpOriginalType(op_type) == "feed" ||
          GetOpOriginalType(op_type) == "fetch") {
-        support_half = !keep_io_types_;
+        support_low_precision = !keep_io_types_;
       } else {
-        support_half =
-            OpSupportPrecision(GetOpOriginalType(op_type), half_precision_);
+        support_low_precision = OpSupportPrecision(
+            GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
       }
 
       if (op_node->Op()->HasAttr("dtype")) {
         auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
-        support_half =
-            support_half && IsFloatType(static_cast<VarType::Type>(dtype));
+        support_low_precision = support_low_precision &&
+                                IsFloatType(static_cast<VarType::Type>(dtype));
       } else if (op_node->Op()->HasAttr("out_dtype")) {
         auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
-        support_half =
-            support_half && IsFloatType(static_cast<VarType::Type>(out_dtype));
+        support_low_precision =
+            support_low_precision &&
+            IsFloatType(static_cast<VarType::Type>(out_dtype));
       } else {
         // if op's input var and output var is not dense tensor, the op should
-        // not run half.
+        // not run at low precision.
         for (auto* in_var_node : op_node->inputs) {
           CHECK_EQ(in_var_node->IsVar(), true);
           auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
           if (real_in_var_node->Var()->Persistable()) continue;
 
-          support_half = support_half && (real_in_var_node->Var()->GetType() ==
-                                          VarType::LOD_TENSOR);
+          support_low_precision =
+              support_low_precision &&
+              (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
         }
 
         for (auto* out_var_node : op_node->outputs) {
@@ -362,23 +384,25 @@ void FloatToHalfPass::GetOpPrecision() const {
          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          if (real_out_var_node->Var()->Persistable()) continue;
 
-          support_half = support_half && (real_out_var_node->Var()->GetType() ==
-                                          VarType::LOD_TENSOR);
+          support_low_precision =
+              support_low_precision &&
+              (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
        }
      }
 
-      if (support_half) {
-        op_run_half_.insert(op_type);
-        VLOG(4) << "support precision: " << op_type << " run at half";
+      if (support_low_precision) {
+        op_run_low_precision_.insert(op_type);
+        VLOG(4) << "support precision: " << op_type << " run at low precision";
      } else {
-        VLOG(4) << "support precision: " << op_type << " not run at half";
+        VLOG(4) << "support precision: " << op_type
+                << " not run at low precision";
      }
    }
  }
 }
 
-void FloatToHalfPass::UpdateOpPrecision() const {
-  std::unordered_set<std::string> vars_should_not_half;
+void AutoMixedPrecisionPass::UpdateOpPrecision() const {
+  std::unordered_set<std::string> vars_should_not_low_precision;
 
  // var -> the var's all input op
  std::unordered_map<std::string, std::vector<Node*>> var_input_ops;
@@ -401,30 +425,16 @@ void FloatToHalfPass::UpdateOpPrecision() const {
                << " is output of " << op_type;
      }
 
-      // the select_input op's input var should not convert to half. when
-      // op's output var is select_input op's input var, the op should not run
-      // half.
+      // the select_input op's input var should not convert to low precision.
+      // when op's output var is select_input op's input var, the op should
+      // not run at low precision.
      if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);
          if (in_var_node->Var()->Persistable()) continue;
          if (!VarNodeHasDtype(in_var_node)) continue;
 
-          vars_should_not_half.insert(in_var_node->Var()->Name());
+          vars_should_not_low_precision.insert(in_var_node->Var()->Name());
        }
      }
-
-      // when op_1 only support cpu kernel. if op_2's intput var is op_1's
-      // output var, then op_2 should not run half.
-      if (GetOpOriginalType(op_type) != "feed" &&
-          !GpuKernelSupportPrecision(GetOpOriginalType(op_type),
-                                     phi::DataType::FLOAT32)) {
-        for (auto* out_var_node : op_node->outputs) {
-          CHECK_EQ(out_var_node->IsVar(), true);
-          if (out_var_node->Var()->Persistable()) continue;
-          if (!VarNodeHasDtype(out_var_node)) continue;
-
-          vars_should_not_half.insert(out_var_node->Var()->Name());
-        }
-      }
    }
  }
@@ -437,25 +447,7 @@ void FloatToHalfPass::UpdateOpPrecision() const {
    precision_updated = false;
    for (const auto& nodes : all_op_nodes_) {
      for (auto* op_node : nodes) {
-        if (op_run_half_.count(op_node->Op()->Type()) == 0) continue;
-
-        for (auto* in_var_node : op_node->inputs) {
-          CHECK_EQ(in_var_node->IsVar(), true);
-          if (!VarNodeHasDtype(in_var_node)) continue;
-
-          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
-          if (real_in_var_node->Var()->Persistable()) continue;
-
-          if (vars_should_not_half.count(real_in_var_node->Var()->Name())) {
-            op_run_half_.erase(op_node->Op()->Type());
-            precision_updated = true;
-            VLOG(4) << op_node->Op()->Type()
-                    << " should not support half precision.";
-            break;
-          }
-        }
-
-        if (op_run_half_.count(op_node->Op()->Type()) == 0) continue;
+        if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
 
        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);
@@ -464,24 +456,25 @@ void FloatToHalfPass::UpdateOpPrecision() const {
          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          if (real_out_var_node->Var()->Persistable()) continue;
 
-          bool not_run_half = false;
+          bool not_run_low_precision = false;
          const auto& input_op_nodes =
              var_input_ops[real_out_var_node->Var()->Name()];
-          if (vars_should_not_half.count(real_out_var_node->Var()->Name())) {
-            not_run_half = true;
+          if (vars_should_not_low_precision.count(
+                  real_out_var_node->Var()->Name())) {
+            not_run_low_precision = true;
          } else {
            for (auto* node : input_op_nodes) {
-              if (op_run_half_.count(node->Op()->Type()) == 0) {
-                not_run_half = true;
+              if (op_run_low_precision_.count(node->Op()->Type()) == 0) {
+                not_run_low_precision = true;
                break;
              }
            }
          }
-          if (not_run_half) {
-            op_run_half_.erase(op_node->Op()->Type());
+          if (not_run_low_precision) {
+            op_run_low_precision_.erase(op_node->Op()->Type());
            precision_updated = true;
            VLOG(4) << op_node->Op()->Type()
-                    << " should not support half precision.";
+                    << " should not run at low precision.";
            break;
          }
        }
@@ -491,8 +484,8 @@ void FloatToHalfPass::UpdateOpPrecision() const {
 }
 
 // special ops, its weights should not be low precision.
-bool FloatToHalfPass::InputVarsNotConvert(Node* op_node,
-                                          const std::string& var_name) const {
+bool AutoMixedPrecisionPass::InputVarsNotConvert(
+    Node* op_node, const std::string& var_name) const {
   auto* op_desc = op_node->Op();
   if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
     auto vecs = op_desc->Input("Bias");
@@ -532,8 +525,8 @@ bool FloatToHalfPass::InputVarsNotConvert(Node* op_node,
   return false;
 }
 
-bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node,
-                                           const std::string& var_name) const {
+bool AutoMixedPrecisionPass::OutputVarsNotConvert(
+    Node* op_node, const std::string& var_name) const {
   auto* op_desc = op_node->Op();
   // batch_norm's input and output (variance and mean) are the same.
   if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
@@ -557,10 +550,14 @@ bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node,
   return false;
 }
 
-void FloatToHalfPass::SetVarPrecision() const {
+void AutoMixedPrecisionPass::SetVarPrecision() const {
   for (const auto& nodes : all_op_nodes_) {
     for (auto* op_node : nodes) {
-      if (op_run_half_.count(op_node->Op()->Type())) {
+      if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
+        continue;
+      }
+
+      if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);
@@ -573,11 +570,13 @@ void FloatToHalfPass::SetVarPrecision() const {
          if (real_in_var_node->Var()->Persistable()) {
            real_in_var_node->Var()->SetDataType(
-                framework::TransToProtoVarType(half_precision_));
-            vars_convert_to_half_.insert(in_var_name);
+                framework::TransToProtoVarType(low_precision_));
+            vars_convert_to_low_precision_.insert(in_var_name);
          }
        }
      }
 
+      if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);
@@ -589,9 +588,9 @@ void FloatToHalfPass::SetVarPrecision() const {
          if (OutputVarsNotConvert(op_node, out_var_name)) continue;
 
          real_out_var_node->Var()->SetDataType(
-              framework::TransToProtoVarType(half_precision_));
+              framework::TransToProtoVarType(low_precision_));
          if (real_out_var_node->Var()->Persistable()) {
-            vars_convert_to_half_.insert(out_var_name);
+            vars_convert_to_low_precision_.insert(out_var_name);
          }
        }
      }
@@ -606,24 +605,24 @@ void FloatToHalfPass::SetVarPrecision() const {
      if (!VarNodeHasDtype(var_node)) continue;
 
      auto var_name = var_node->Var()->Name();
-      if (vars_convert_to_half_.count(var_name)) {
+      if (vars_convert_to_low_precision_.count(var_name)) {
        var_node->Var()->SetDataType(
-            framework::TransToProtoVarType(half_precision_));
+            framework::TransToProtoVarType(low_precision_));
      }
    }
  }
 }
 
-void FloatToHalfPass::ConvertWeightsData() const {
+void AutoMixedPrecisionPass::ConvertWeightsData() const {
   auto* scope = param_scope();
-  PADDLE_ENFORCE_NOT_NULL(
-      scope,
-      platform::errors::PreconditionNotMet(
-          "During the float to half pass, the scope should not be null."));
+  PADDLE_ENFORCE_NOT_NULL(scope,
+                          platform::errors::PreconditionNotMet(
+                              "During the auto_mixed_precision_pass, the scope "
+                              "should not be null."));
 
   auto var_names = scope->LocalVarNames();
   for (const auto& var_name : var_names) {
-    if (vars_convert_to_half_.count(var_name)) {
+    if (vars_convert_to_low_precision_.count(var_name)) {
       VLOG(4) << var_name << "'s data type was convert to half";
 
       auto* var = scope->FindLocalVar(var_name);
@@ -631,25 +630,29 @@ void FloatToHalfPass::ConvertWeightsData() const {
       auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
 
-      phi::DenseTensor half_tensor;
-      half_tensor.Resize(origin_tensor->dims());
-      half_tensor.set_type(half_precision_);
+      phi::DenseTensor low_precision_tensor;
+      low_precision_tensor.Resize(origin_tensor->dims());
+      low_precision_tensor.set_type(low_precision_);
 
-      if (half_precision_ == phi::DataType::FLOAT16) {
-        auto* half_data =
-            half_tensor.mutable_data<phi::dtype::float16>(phi::CPUPlace{});
+      if (low_precision_ == phi::DataType::FLOAT16) {
+        auto* low_precision_data =
+            low_precision_tensor.mutable_data<phi::dtype::float16>(
+                phi::CPUPlace{});
        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
            auto* origin_data = origin_tensor->data<double>();
-            half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]);
+            low_precision_data[i] =
+                static_cast<phi::dtype::float16>(origin_data[i]);
          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
            auto* origin_data = origin_tensor->data<float>();
-            half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]);
+            low_precision_data[i] =
+                static_cast<phi::dtype::float16>(origin_data[i]);
          }
        }
-      } else if (half_precision_ == phi::DataType::BFLOAT16) {
+      } else if (low_precision_ == phi::DataType::BFLOAT16) {
        auto* half_data =
-            half_tensor.mutable_data<phi::dtype::bfloat16>(phi::CPUPlace{});
+            low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
+                phi::CPUPlace{});
        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
            auto* origin_data = origin_tensor->data<double>();
@@ -662,12 +665,12 @@ void FloatToHalfPass::ConvertWeightsData() const {
        }
      }
      origin_tensor->clear();
      paddle::framework::TensorCopySync(
-          half_tensor, phi::CPUPlace{}, origin_tensor);
+          low_precision_tensor, phi::CPUPlace{}, origin_tensor);
    }
  }
 }
 
-void FloatToHalfPass::InsertCastOp() const {
+void AutoMixedPrecisionPass::InsertCastOp() const {
   int suffix = 0;
   std::unordered_map<Node*, Node*> cache;
@@ -681,7 +684,7 @@ void FloatToHalfPass::InsertCastOp() const {
      if (op_node->Op()->HasAttr("sub_block")) continue;
 
      VLOG(4) << "process op: " << op_type
-              << " run half: " << op_run_half_.count(op_type);
+              << " run low precision: " << op_run_low_precision_.count(op_type);
 
      auto inputs = op_node->inputs;
      for (auto* in_var_node : inputs) {
@@ -696,17 +699,17 @@ void FloatToHalfPass::InsertCastOp() const {
        VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
                << " with type " << in_var_type;
 
-        if (IsFloatType(in_var_type) && op_run_half_.count(op_type)) {
+        if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) {
          DoInsertCastOp(subgraphes_[i],
                         in_var_node,
                         op_node,
                         in_var_type,
-                         framework::TransToProtoVarType(half_precision_),
+                         framework::TransToProtoVarType(low_precision_),
                         block_desc,
                         &suffix,
                         &cache);
-        } else if (IsHalfType(in_var_type) &&
-                   op_run_half_.count(op_type) == 0) {
+        } else if (IsHalfType(in_var_type) &&
+                   op_run_low_precision_.count(op_type) == 0) {
          DoInsertCastOp(subgraphes_[i],
                         in_var_node,
                         op_node,
@@ -738,4 +741,5 @@ void FloatToHalfPass::InsertCastOp() const {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(float_to_half_pass, paddle::framework::ir::FloatToHalfPass);
+REGISTER_PASS(auto_mixed_precision_pass,
+              paddle::framework::ir::AutoMixedPrecisionPass);
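The rename above is not only cosmetic: the pass now reads its configuration from pass attributes and skips itself when GPU mixed precision is disabled. Below is a hedged sketch, not from the commit, of driving the renamed pass directly through the pass registry, using the attribute names that AutoMixedPrecisionPass::Init() reads above; the values are placeholders.

// Illustrative only; mirrors the PassRegistry usage visible in
// ir_pass_manager.cc later in this diff.
#include <memory>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/phi/common/data_type.h"

namespace fw = paddle::framework;

std::unique_ptr<fw::ir::Graph> RunAutoMixedPrecision(
    std::unique_ptr<fw::ir::Graph> graph) {
  // The old pass id "float_to_half_pass" no longer exists after this commit.
  auto pass = fw::ir::PassRegistry::Instance().Get("auto_mixed_precision_pass");
  pass->Set("enable_gpu_mixed", new bool(true));
  pass->Set("mixed_precision_mode",
            new int(static_cast<int>(phi::DataType::FLOAT16)));
  pass->Set("mixed_black_list", new std::unordered_set<std::string>());
  // Apply() consumes and returns the graph pointer.
  graph.reset(pass->Apply(graph.release()));
  return graph;
}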
paddle/fluid/framework/ir/float_to_half_pass.h → paddle/fluid/framework/ir/auto_mixed_precision_pass.h
@@ -27,13 +27,13 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-class FloatToHalfPass : public FusePassBase {
+class AutoMixedPrecisionPass : public FusePassBase {
  public:
   using VarType = framework::proto::VarType;
 
  public:
-  FloatToHalfPass() = default;
-  ~FloatToHalfPass() = default;
+  AutoMixedPrecisionPass() = default;
+  ~AutoMixedPrecisionPass() = default;
 
  protected:
   void ApplyImpl(Graph* graph) const override;
@@ -43,10 +43,6 @@ class FloatToHalfPass : public FusePassBase {
   void SetDefaultBlacklist() const;
 
-  bool OpSupportPrecision(const std::string& op_type,
-                          phi::DataType precision,
-                          phi::Backend backend = phi::Backend::GPU) const;
-
   void SetOpUniqueType() const;
 
   void RestoreOpOriginType() const;
@@ -70,9 +66,13 @@ class FloatToHalfPass : public FusePassBase {
   void ConvertWeightsData() const;
 
  private:
-  mutable bool keep_io_types_;
+  mutable bool skip_pass_{false};
+
+  mutable bool keep_io_types_{false};
   // float16 or bfloat16 now
-  mutable phi::DataType half_precision_;
+  mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
+
+  mutable phi::Backend backend_{phi::Backend::GPU};
 
   mutable std::unordered_set<std::string> black_list_;
@@ -84,12 +84,26 @@ class FloatToHalfPass : public FusePassBase {
   mutable std::vector<std::vector<Node*>> all_op_nodes_;
   // op's unique type -> the op's origin type
   mutable std::unordered_map<std::string, std::string> op_original_type_;
-  // op's unique type -> whether the op run at half precision
-  mutable std::unordered_set<std::string> op_run_half_;
+  // op's unique type -> whether the op run at low precision
+  mutable std::unordered_set<std::string> op_run_low_precision_;
 
-  mutable std::unordered_set<std::string> vars_convert_to_half_;
+  mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
 };
 
+bool OpSupportPrecision(const std::string& op_type,
+                        phi::Backend backend,
+                        phi::DataType precision,
+                        const std::unordered_set<std::string>& black_list);
+
+void DoInsertCastOp(Graph* graph,
+                    Node* var_node,
+                    Node* op_node,
+                    proto::VarType::Type from_type,
+                    proto::VarType::Type to_type,
+                    framework::BlockDesc* block_desc,
+                    int* suffix,
+                    std::unordered_map<Node*, Node*>* cache);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
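Note the new free-function declarations at the end of the header: OpSupportPrecision and DoInsertCastOp are now exported so other components (convert_to_mixed_precision.cc and the TensorRT subgraph pass below) can reuse the precision query and cast-insertion logic. A hedged sketch, not from the commit, of querying kernel-precision support through the exported helper; the wrapper name and black-list entry are made up for illustration ("acos" does appear in the pass's default blacklist above):

// Hedged sketch exercising the helper declared in the header above.
#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"

bool Fp16Supported(const std::string& op_type) {
  // Ops on the black list always report unsupported.
  std::unordered_set<std::string> black_list{"acos"};  // example entry
  return paddle::framework::ir::OpSupportPrecision(
      op_type, phi::Backend::GPU, phi::DataType::FLOAT16, black_list);
}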
paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
@@ -142,7 +142,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   bool is_fp16_precision =
       static_cast<phi::DataType>(Get<int>("model_precision")) ==
           phi::DataType::FLOAT16 ||
-      Get<bool>("enable_gpu_half");
+      Get<bool>("enable_gpu_mixed");
   bool cutlass_enable = false;
 #ifdef PADDLE_WITH_CUTLASS
paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
@@ -165,7 +165,7 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   bool is_fp16_precision =
       static_cast<phi::DataType>(Get<int>("model_precision")) ==
           phi::DataType::FLOAT16 ||
-      Get<bool>("enable_gpu_half");
+      Get<bool>("enable_gpu_mixed");
   constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
   if (is_fp16_precision) {
 #ifdef PADDLE_WITH_CUTLASS
paddle/fluid/inference/analysis/argument.h
@@ -365,7 +365,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(mixed_black_list,
                       MixedBlackList,
                       std::unordered_set<std::string>);
-  DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
+  DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
   DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
 
   // cinn compiler related
paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -45,8 +45,10 @@ IRPassManager::IRPassManager(Argument *argument) {
 
 void IRPassManager::CreatePasses(Argument *argument,
                                  const std::vector<std::string> &passes) {
+  // For graph_viz_pass
   std::string pre_pass;
   int pass_num = 0;
+
   for (const std::string &pass_name : passes) {
     auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
     pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
@@ -87,14 +89,14 @@ void IRPassManager::CreatePasses(Argument *argument,
         argument->tensorrt_tuned_dynamic_shape();
     pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
 
-    // mixed precision related
-    pass->Set("model_precision", new int(argument->model_precision()));
+    // Mixed precision related.
     pass->Set(
         "mixed_black_list",
         new std::unordered_set<std::string>(argument->mixed_black_list()));
-    pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half()));
+    pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
     pass->Set("mixed_precision_mode",
               new int(argument->mixed_precision_mode()));
+    pass->Set("model_precision", new int(argument->model_precision()));
 
     if (pass_name == "graph_viz_pass") {
       std::string optim_cache_dir = argument->optim_cache_dir();
@@ -210,6 +212,7 @@ void IRPassManager::CreatePasses(Argument *argument,
                 new std::vector<std::string>(argument->tensorrt_disabled_ops()));
       pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
       pass->Set("trt_dla_core", new int(argument->tensorrt_dla_core()));
+
       // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
       // not run fp16.
       pass->Set("disable_trt_plugin_fp16",
@@ -238,8 +241,7 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("root_predictor_id", new int(argument->root_predictor_id()));
     } else if (pass_name == "build_cinn_pass") {
       pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
-    } else if (pass_name == "lite_subgraph_pass") {
+    }
+    if (pass_name == "lite_subgraph_pass") {
       bool lite_enable_int8 =
           argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
       pass->Set("program",
@@ -287,8 +289,7 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("nnadapter_model_cache_token",
                 new std::vector<std::string>(
                     argument->nnadapter_model_cache_token()));
-    } else if (pass_name == "fc_fuse_pass") {
+    }
+    if (pass_name == "fc_fuse_pass") {
       pass->Set("use_gpu", new bool(argument->use_gpu()));
       bool fc_mkldnn_pass = 0;
       for (const std::string &pass_n : passes) {
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -83,13 +83,13 @@ void OutputProcess(framework::ir::Graph *graph,
                            backend,
                            precision,
                            blacklist)) {
-        AddCastOp(graph,
-                  var_node,
-                  next_op,
-                  framework::proto::VarType::FP32,
-                  to_type,
-                  &suffix,
-                  block_desc,
-                  &var_to_cast_op_map);
+        InsertCastOp(graph,
+                     var_node,
+                     next_op,
+                     framework::proto::VarType::FP32,
+                     to_type,
+                     block_desc,
+                     &suffix,
+                     &var_to_cast_op_map);
         var_node->Var()->SetDataType(framework::proto::VarType::FP32);
       }
paddle/fluid/inference/analysis/passes/CMakeLists.txt
@@ -13,7 +13,7 @@ cc_library(
 cc_library(
   convert_to_mixed_precision
   SRCS convert_to_mixed_precision.cc
-  DEPS analysis_pass ir_graph_build_pass)
+  DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
 cc_library(
   ir_params_sync_among_devices_pass
   SRCS ir_params_sync_among_devices_pass.cc
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
浏览文件 @
28ea9aad
...
@@ -14,82 +14,17 @@
...
@@ -14,82 +14,17 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
namespace
{
ConvertToMixedPrecisionPass
::
ConvertToMixedPrecisionPass
(
using
VarType
=
framework
::
proto
::
VarType
;
bool
PhiKernelSupportPrecision
(
const
std
::
string
&
op_type
,
phi
::
Backend
backend
,
phi
::
DataType
data_type
,
phi
::
DataLayout
layout
=
phi
::
DataLayout
::
ALL_LAYOUT
)
{
auto
kernels
=
phi
::
KernelFactory
::
Instance
().
kernels
();
if
(
kernels
.
find
(
op_type
)
==
kernels
.
end
())
{
return
false
;
}
phi
::
KernelKey
kernel_key
(
backend
,
layout
,
data_type
);
return
phi
::
KernelFactory
::
Instance
().
HasKernel
(
op_type
,
kernel_key
);
}
bool
GpuKernelSupportPrecision
(
const
std
::
string
&
op_type
,
phi
::
DataType
data_type
,
phi
::
DataLayout
layout
=
phi
::
DataLayout
::
ALL_LAYOUT
)
{
auto
phi_op_type
=
phi
::
TransToPhiKernelName
(
op_type
);
bool
res
=
PhiKernelSupportPrecision
(
phi_op_type
,
phi
::
Backend
::
GPU
,
data_type
,
layout
);
res
|=
PhiKernelSupportPrecision
(
phi_op_type
,
phi
::
Backend
::
GPUDNN
,
data_type
,
layout
);
if
(
!
res
)
{
auto
&
all_kernels
=
framework
::
OperatorWithKernel
::
AllOpKernels
();
auto
it
=
all_kernels
.
find
(
op_type
);
if
(
it
!=
all_kernels
.
end
())
{
for
(
auto
&
kern_pair
:
it
->
second
)
{
if
(
platform
::
is_gpu_place
(
kern_pair
.
first
.
place_
)
&&
kern_pair
.
first
.
data_type_
==
VarType
::
FP16
)
{
res
=
true
;
break
;
}
}
}
}
return
res
;
}
class
ConvertToMixedPrecisionPass
{
using
BlockID
=
size_t
;
public:
explicit
ConvertToMixedPrecisionPass
(
const
std
::
string
&
model_file
,
const
std
::
string
&
model_file
,
const
std
::
string
&
params_file
,
const
std
::
string
&
params_file
,
const
std
::
string
&
mixed_model_file
,
const
std
::
string
&
mixed_model_file
,
...
@@ -105,571 +40,46 @@ class ConvertToMixedPrecisionPass {
...
@@ -105,571 +40,46 @@ class ConvertToMixedPrecisionPass {
mixed_precision_
(
mixed_precision
),
mixed_precision_
(
mixed_precision
),
backend_
(
backend
),
backend_
(
backend
),
keep_io_types_
(
keep_io_types
),
keep_io_types_
(
keep_io_types
),
black_list_
(
black_list
),
black_list_
(
black_list
)
{
place_
(
paddle
::
CPUPlace
()),
if
(
mixed_precision_
!=
phi
::
DataType
::
FLOAT16
&&
executor_
(
place_
)
{
mixed_precision_
!=
phi
::
DataType
::
BFLOAT16
)
{
VLOG
(
4
)
<<
"black_list has "
;
PADDLE_THROW
(
paddle
::
platform
::
errors
::
InvalidArgument
(
for
(
auto
&
name
:
black_list_
)
{
"mixed_precision currently not supported dtype %d, we now only "
VLOG
(
4
)
<<
" - "
<<
name
;
"support fp16 and bf16."
,
}
static_cast
<
int
>
(
mixed_precision_
)));
}
void
Run
();
private:
void
LoadAndPrepare
();
inline
bool
VarNodeHasDtype
(
framework
::
ir
::
Node
*
node
);
void
ConvertAllFp64ToFp32
(
framework
::
ir
::
Graph
*
graph
);
void
FixCastAttr
(
framework
::
ir
::
Graph
*
graph
);
void
SaveMixedModel
();
void
ConvertTensorDtype
(
BlockID
block_idx
);
void
ProcessInputNode
(
bool
support_precision
,
framework
::
ir
::
Node
*
in_node
,
framework
::
ir
::
Node
*
op_node
,
int
*
suffix
,
framework
::
BlockDesc
*
block_desc
,
VarType
::
Type
to_type
,
BlockID
block_idx
);
void
ProcessOutputNode
(
BlockID
block_idx
,
framework
::
ir
::
Node
*
var_node
,
VarType
::
Type
to_type
);
inline
bool
IsFloatVarType
(
VarType
::
Type
type
);
bool
OutShouldNotConvert
(
framework
::
ir
::
Node
*
var_node
);
// Just process special cases for weights conversion.
bool
WeightsShouldNotConvert
(
framework
::
ir
::
Node
*
var_node
);
// Return Node* which first appers in block.
framework
::
ir
::
Node
*
GetRealVarNode
(
framework
::
ir
::
Node
*
node
);
// Fallback to fp32 dtype when encounter circle (Not a DAG graph).
void
ProcessCircleCases
();
private:
std
::
string
model_file_
;
std
::
string
params_file_
;
std
::
string
mixed_model_file_
;
std
::
string
mixed_params_file_
;
phi
::
DataType
mixed_precision_
;
phi
::
Backend
backend_
;
bool
keep_io_types_
;
std
::
unordered_set
<
std
::
string
>
black_list_
;
paddle
::
CPUPlace
place_
;
framework
::
Executor
executor_
;
framework
::
Scope
scope_
;
std
::
unordered_map
<
std
::
string
,
framework
::
ir
::
Node
*>
name2node_
;
std
::
unordered_map
<
framework
::
ir
::
Node
*
,
framework
::
ir
::
Node
*>
cast_map_
;
int
suffix_
{
0
};
std
::
set
<
std
::
string
>
var_names_in_circles_
;
std
::
unique_ptr
<
framework
::
ProgramDesc
>
program_desc_
{
nullptr
};
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
main_graph_
{
nullptr
};
std
::
vector
<
framework
::
ir
::
Graph
*>
graphes_
;
};
framework
::
ir
::
Node
*
ConvertToMixedPrecisionPass
::
GetRealVarNode
(
framework
::
ir
::
Node
*
var_node
)
{
CHECK_EQ
(
var_node
->
IsVar
(),
true
);
if
(
name2node_
.
count
(
var_node
->
Name
()))
return
name2node_
[
var_node
->
Name
()];
return
var_node
;
}
inline
bool
ConvertToMixedPrecisionPass
::
VarNodeHasDtype
(
framework
::
ir
::
Node
*
var_node
)
{
CHECK_EQ
(
var_node
->
IsVar
(),
true
);
auto
type
=
var_node
->
Var
()
->
GetType
();
return
(
type
==
VarType
::
SELECTED_ROWS
)
||
(
type
==
VarType
::
LOD_TENSOR
)
||
(
type
==
VarType
::
LOD_TENSOR_ARRAY
)
||
(
type
==
VarType
::
STRINGS
)
||
(
type
==
VarType
::
VOCAB
);
}
void
ConvertToMixedPrecisionPass
::
ProcessInputNode
(
bool
support_precision
,
framework
::
ir
::
Node
*
in_node
,
framework
::
ir
::
Node
*
op_node
,
int
*
suffix
,
framework
::
BlockDesc
*
block_desc
,
VarType
::
Type
to_type
,
BlockID
block_idx
)
{
if
(
!
in_node
->
IsVar
())
return
;
auto
*
real_node
=
GetRealVarNode
(
in_node
);
if
(
!
VarNodeHasDtype
(
real_node
))
return
;
auto
*
graph
=
graphes_
[
block_idx
];
auto
*
in_var
=
real_node
->
Var
();
auto
in_var_type
=
in_var
->
GetDataType
();
auto
prev_type
=
in_var_type
;
if
(
support_precision
)
{
if
(
in_var
->
Persistable
()
&&
in_var_type
==
VarType
::
FP32
)
{
if
(
WeightsShouldNotConvert
(
in_node
))
return
;
in_var
->
SetDataType
(
to_type
);
in_var_type
=
to_type
;
VLOG
(
3
)
<<
" in_node name "
<<
in_var
->
Name
()
<<
" from "
<<
prev_type
<<
" to "
<<
to_type
;
}
else
if
(
!
in_var
->
Persistable
()
&&
IsFloatVarType
(
in_var_type
)
&&
in_var_type
!=
to_type
)
{
AddCastOp
(
graph
,
in_node
,
op_node
,
in_var_type
,
to_type
,
suffix
,
block_desc
,
&
cast_map_
);
VLOG
(
3
)
<<
" in_node name "
<<
in_var
->
Name
()
<<
"("
<<
prev_type
<<
") to "
<<
cast_map_
[
in_node
]
->
Name
()
<<
"("
<<
to_type
<<
")"
;
}
}
else
{
if
(
!
in_var
->
Persistable
()
&&
IsFloatVarType
(
in_var_type
)
&&
in_var_type
!=
to_type
)
{
AddCastOp
(
graph
,
in_node
,
op_node
,
in_var_type
,
to_type
,
suffix
,
block_desc
,
&
cast_map_
);
VLOG
(
3
)
<<
" in_node name "
<<
in_var
->
Name
()
<<
"("
<<
prev_type
<<
") to "
<<
cast_map_
[
in_node
]
->
Name
()
<<
"("
<<
to_type
<<
")"
;
}
}
}
void
ConvertToMixedPrecisionPass
::
ProcessOutputNode
(
BlockID
block_idx
,
framework
::
ir
::
Node
*
var_node
,
VarType
::
Type
to_type
)
{
if
(
!
var_node
->
IsVar
())
return
;
auto
*
real_node
=
GetRealVarNode
(
var_node
);
if
(
!
VarNodeHasDtype
(
real_node
))
return
;
auto
*
out_var
=
real_node
->
Var
();
auto
prev_type
=
out_var
->
GetDataType
();
if
(
out_var
->
GetDataType
()
==
VarType
::
FP32
)
{
if
(
OutShouldNotConvert
(
var_node
))
return
;
out_var
->
SetDataType
(
to_type
);
}
VLOG
(
3
)
<<
" out_node name "
<<
var_node
->
Name
()
<<
" from dtype "
<<
prev_type
<<
" to "
<<
out_var
->
GetDataType
();
}
// Just process special cases.
bool
ConvertToMixedPrecisionPass
::
OutShouldNotConvert
(
framework
::
ir
::
Node
*
var_node
)
{
auto
op_node
=
var_node
->
inputs
[
0
];
auto
*
op_desc
=
op_node
->
Op
();
// batch_norm's input and output (variance and mean) are the same.
if
(
op_desc
->
Type
()
==
"batch_norm"
)
{
auto
vecs
=
op_desc
->
Output
(
"MeanOut"
);
if
(
std
::
find
(
vecs
.
begin
(),
vecs
.
end
(),
var_node
->
Name
())
!=
vecs
.
end
())
{
return
true
;
}
vecs
=
op_desc
->
Output
(
"VarianceOut"
);
if
(
std
::
find
(
vecs
.
begin
(),
vecs
.
end
(),
var_node
->
Name
())
!=
vecs
.
end
())
{
return
true
;
}
vecs
=
op_desc
->
Output
(
"SavedMean"
);
if
(
std
::
find
(
vecs
.
begin
(),
vecs
.
end
(),
var_node
->
Name
())
!=
vecs
.
end
())
{
return
true
;
}
vecs
=
op_desc
->
Output
(
"SavedVariance"
);
if
(
std
::
find
(
vecs
.
begin
(),
vecs
.
end
(),
var_node
->
Name
())
!=
vecs
.
end
())
{
return
true
;
}
}
return
false
;
}
-bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(
-    framework::ir::Node* var_node) {
-  auto op_nodes = var_node->outputs;
-  for (auto* op_node : op_nodes) {
-    auto* op_desc = op_node->Op();
-    // batch_norm op's bias, mean, scale and variance just be float32, so we
-    // can not convert the dtype.
-    if (op_desc->Type() == "batch_norm") {
-      auto vecs = op_desc->Input("Bias");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-        return true;
-      }
-      vecs = op_desc->Input("Mean");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-        return true;
-      }
-      vecs = op_desc->Input("Scale");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-        return true;
-      }
-      vecs = op_desc->Input("Variance");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-        return true;
-      }
-    } else if (op_desc->Type() == "fused_multi_transformer") {
-      auto vecs = op_desc->Input("LnScale");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-        return true;
-      }
-      vecs = op_desc->Input("LnBias");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-        return true;
-      }
-      vecs = op_desc->Input("FFNLnScale");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-        return true;
-      }
-      vecs = op_desc->Input("FFNLnBias");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
+  if (backend_ != phi::Backend::GPU) {
+    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+        "mixed_precision currently not supported place %d, we now only "
+        "support gpu.",
+        static_cast<int>(backend_)));
+  }
+}
-inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) {
-  return (type == VarType::FP16) || (type == VarType::FP32) ||
-         (type == VarType::BF16);
-}
-
-void ConvertToMixedPrecisionPass::LoadAndPrepare() {
-  program_desc_ =
-      inference::Load(&executor_, &scope_, model_file_, params_file_);
-  main_graph_ = std::unique_ptr<framework::ir::Graph>(
-      new framework::ir::Graph(*program_desc_));
-  main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_);
-
-  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
-    auto* graph = main_graph_->GetSubGraph(i);
-    graphes_.push_back(graph);
-    for (auto* node : graph->Nodes()) {
-      if (!node->IsVar()) continue;
-      if (!name2node_.count(node->Name())) {
-        name2node_[node->Name()] = node;
-      }
-    }
-  }
-
-  ProcessCircleCases();
-}
+void ConvertToMixedPrecisionPass::LoadModel() {
+  framework::Executor exe{platform::CPUPlace{}};
+
+  auto program_desc =
+      inference::Load(&exe, &scope_, model_file_, params_file_);
+  main_graph_ = std::unique_ptr<framework::ir::Graph>(
+      new framework::ir::Graph(*program_desc));
+  main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_);
+}
-// Find var names which in circles.
-void ConvertToMixedPrecisionPass::ProcessCircleCases() {
-  std::vector<std::string> vars_in_circles;
-  for (size_t idx = 0; idx < program_desc_->Size(); ++idx) {
-    for (auto* op : program_desc_->Block(idx).AllOps()) {
-      // TODO(inference): batch_norm has circle, but we need to fuse it in conv
-      // op.
-      if (op->Type() == "batch_norm") continue;
-      const auto& in_names = op->InputArgumentNames();
-      const auto& out_names = op->OutputArgumentNames();
-      std::set<std::string> in_names_set(in_names.begin(), in_names.end());
-      std::set<std::string> out_names_set(out_names.begin(), out_names.end());
-      std::set_intersection(in_names_set.begin(),
-                            in_names_set.end(),
-                            out_names_set.begin(),
-                            out_names_set.end(),
-                            std::back_inserter(vars_in_circles));
-    }
-  }
-
-  for (auto& name : vars_in_circles) {
-    var_names_in_circles_.insert(name);
-  }
-  for (auto& name : var_names_in_circles_) {
-    LOG(INFO) << name
-              << " in circles, so we will skip process those vars and ops.";
-  }
-}
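Note: the "circle" test above is just the overlap between an op's input and output name sets, computed with std::set_intersection. A minimal standalone sketch (variable names hypothetical):

#include <algorithm>
#include <iostream>
#include <iterator>
#include <set>
#include <string>
#include <vector>

int main() {
  // An in-place op such as "y = op(x, state)" that also writes "state"
  // reports the overlapping name, so the pass skips converting it.
  std::set<std::string> in_names{"x", "state"};
  std::set<std::string> out_names{"y", "state"};
  std::vector<std::string> vars_in_circles;
  std::set_intersection(in_names.begin(), in_names.end(),
                        out_names.begin(), out_names.end(),
                        std::back_inserter(vars_in_circles));
  for (const auto& name : vars_in_circles) {
    std::cout << name << "\n";  // prints: state
  }
  return 0;
}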
-inline void ProcessConstantOpAttr(framework::ir::Node* op_node,
-                                  VarType::Type from_type,
-                                  VarType::Type to_type) {
-  if (!op_node->IsOp()) return;
-  auto op_type = op_node->Op()->Type();
-  if (op_type == "feed" || op_type == "fetch") return;
-
-  if (op_type == "fill_constant") {
-    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-        static_cast<int>(from_type))
-      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
-  } else if (op_type == "assign_value") {
-    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-        static_cast<int>(from_type))
-      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
-  } else if (op_type == "eye") {
-    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-        static_cast<int>(from_type))
-      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
-  } else if (op_type == "fill_any_like") {
-    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-        static_cast<int>(from_type))
-      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
-  } else if (op_type == "cast") {
-    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-        static_cast<int>(from_type))
-      op_node->Op()->SetAttr("in_dtype", static_cast<int>(to_type));
-    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-        static_cast<int>(from_type))
-      op_node->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
-  }
-}
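Note: the five branches above differ only in which dtype-valued attributes the op carries. A compact equivalent, as a fragment (hypothetical refactor, not part of this diff; assumes the same op_node/op_type locals and <map>, <string>, <vector>):

static const std::map<std::string, std::vector<std::string>> kDtypeAttrs = {
    {"fill_constant", {"dtype"}},
    {"assign_value", {"dtype"}},
    {"eye", {"dtype"}},
    {"fill_any_like", {"dtype"}},
    {"cast", {"in_dtype", "out_dtype"}},
};
auto it = kDtypeAttrs.find(op_type);
if (it != kDtypeAttrs.end()) {
  for (const auto& attr : it->second) {
    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr(attr)) ==
        static_cast<int>(from_type)) {
      op_node->Op()->SetAttr(attr, static_cast<int>(to_type));
    }
  }
}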
-void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
-    framework::ir::Graph* graph) {
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
-  for (auto* op_node : op_nodes) {
-    if (!op_node->IsOp()) continue;
-    auto op_type = op_node->Op()->Type();
-    ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32);
-    auto inputs = op_node->inputs;
-    for (auto* in_node : inputs) {
-      auto* in_var = in_node->Var();
-      if (!in_var->Persistable() && in_var->GetDataType() == VarType::FP64) {
-        in_var->SetDataType(VarType::FP32);
-      }
-    }
-  }
-}
 void ConvertToMixedPrecisionPass::Run() {
-  LoadAndPrepare();
+  LoadModel();
 
-  for (size_t i = 0; i < graphes_.size(); ++i) {
-    auto* graph = graphes_[i];
-    VLOG(2) << " -------- handle subgraph " << i << ", has "
-            << graph->Nodes().size() << " nodes --------";
+  framework::ir::AutoMixedPrecisionPass pass;
+  pass.Set("mixed_precision_mode",
+           new int{static_cast<int>(mixed_precision_)});
+  pass.Set("mixed_black_list",
+           new std::unordered_set<std::string>{black_list_});
+  pass.Set("enable_gpu_mixed", new bool{true});
+  pass.Set("keep_io_types", new bool{keep_io_types_});
 
-    ConvertAllFp64ToFp32(graph);
-    ConvertTensorDtype(i);
-    FixCastAttr(graph);
+  pass.Apply(main_graph_.get());
 
-    CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
-  }
 
   SaveMixedModel();
 }
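Note: the rewritten Run() is the whole story of this commit for the offline converter; the per-subgraph dtype surgery above is replaced by one application of framework::ir::AutoMixedPrecisionPass configured through four attributes. A condensed sketch of that calling convention (fragment; `graph` is assumed to be a framework::ir::Graph* you already own):

framework::ir::AutoMixedPrecisionPass pass;
pass.Set("mixed_precision_mode",
         new int{static_cast<int>(phi::DataType::FLOAT16)});
pass.Set("mixed_black_list",
         new std::unordered_set<std::string>{{"softmax"}});  // example entry
pass.Set("enable_gpu_mixed", new bool{true});
pass.Set("keep_io_types", new bool{true});
pass.Apply(graph);

As far as I can tell, Pass::Set takes ownership of the heap-allocated attribute and deletes it with the pass, which is why raw `new` is the idiomatic way to hand attributes in here.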
-void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
-  auto* graph = graphes_[block_idx];
-  VarType::Type to_type;
-  if (mixed_precision_ == phi::DataType::FLOAT16) {
-    to_type = VarType::FP16;
-  } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
-    to_type = VarType::BF16;
-  } else {
-    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
-        "mixed_precision currently not supported dtype %d, we now only "
-        "support fp16 and bf16.",
-        static_cast<int>(mixed_precision_)));
-  }
-
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
-  auto* block_desc = op_nodes[0]->Op()->Block();
-  int num_low_precision = 0;
-  std::vector<framework::ir::Node*> output_nodes;
-
-  for (auto* op_node : op_nodes) {
-    if (!op_node->IsOp()) continue;
-    auto op_type = op_node->Op()->Type();
-    VLOG(3) << "-------------------- op_type " << op_type << ", phi_type "
-            << phi::TransToPhiKernelName(op_type);
-    // 1. set input dtype.
-    if (op_type == "feed") {
-      auto feed_var = op_node->outputs[0]->Var();
-      if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) {
-        feed_var->SetDataType(to_type);
-      }
-    } else if (op_type == "fetch") {
-      auto* fetch_var = op_node->inputs[0];
-      output_nodes.push_back(fetch_var);
-      continue;
-    } else if (op_type == "cast") {
-      continue;
-    }
-    // We can not add cast operator before ops who have sub_block, as in
-    // sub_block we may get a var which may be transformer by cast op.
-    else if (op_node->Op()->HasAttr("sub_block")) {  // NOLINT
-      continue;
-    }
-    // 2. if op support fp16/bf16 and not in blacklist.
-    //    - cast weight to fp16/bf16.
-    //    - add cast op if the input dtype is not fp16/bf16.
-    //    - set output dtype.
-    else if (black_list_.count(op_type) == 0) {  // NOLINT
-      bool support_precision =
-          OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
-
-      // If op's output in circle, we should not convert to fp16.
-      for (auto* out_node : op_node->outputs) {
-        if (var_names_in_circles_.count(out_node->Name())) {
-          support_precision = false;
-          VLOG(2) << " op's output " << out_node->Name()
-                  << " is in circle, we can not support this case, just skip.";
-          break;
-        }
-      }
-
-      // If the op has no input or output of float type, we will not choose the
-      // low precision kernel.
-      if (support_precision) {
-        bool has_float_in_out{false};
-        for (auto* in_node : op_node->inputs) {
-          if (!in_node->IsVar()) continue;
-          if (in_node->Var()->GetType() != VarType::LOD_TENSOR) {
-            support_precision = false;
-            VLOG(2) << " op has tensor array input[" << in_node->Name()
-                    << "], just skip.";
-            break;
-          }
-          auto* real_node = GetRealVarNode(in_node);
-          if (real_node->Var()->GetDataType() == VarType::FP16 ||
-              real_node->Var()->GetDataType() == VarType::FP32 ||
-              real_node->Var()->GetDataType() == VarType::FP64 ||
-              real_node->Var()->GetDataType() == VarType::BF16) {
-            has_float_in_out = true;
-            break;
-          }
-        }
-        for (auto* out_node : op_node->outputs) {
-          if (!out_node->IsVar()) continue;
-          auto* real_node = GetRealVarNode(out_node);
-          if (real_node->Var()->GetDataType() == VarType::FP16 ||
-              real_node->Var()->GetDataType() == VarType::FP32 ||
-              real_node->Var()->GetDataType() == VarType::FP64 ||
-              real_node->Var()->GetDataType() == VarType::BF16) {
-            has_float_in_out = true;
-            break;
-          }
-        }
-        if (!has_float_in_out) {
-          support_precision = false;
-          VLOG(2) << " op doesn't has float input and output, just skip.";
-        }
-      }
-
-      VLOG(2) << "op type: " << op_type
-              << " support low precision: " << support_precision;
-
-      if (support_precision) {
-        ProcessConstantOpAttr(op_node, VarType::FP32, to_type);
-        VLOG(2) << " process input nodes:";
-        ++num_low_precision;
-        auto inputs = op_node->inputs;
-        for (auto* in_node : inputs) {
-          ProcessInputNode(
-              true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
-        }
-
-        VLOG(2) << " process output nodes:";
-        auto outputs = op_node->outputs;
-        for (auto* out_node : outputs) {
-          ProcessOutputNode(block_idx, out_node, to_type);
-        }
-      } else {
-        auto inputs = op_node->inputs;
-        for (auto* in_node : inputs) {
-          ProcessInputNode(false,
-                           in_node,
-                           op_node,
-                           &suffix_,
-                           block_desc,
-                           VarType::FP32,
-                           block_idx);
-        }
-      }
-    }
-    // 3. check op not support fp16/bf16 or in blacklist.
-    //    - add cast op if the input dtype is not fp32.
-    else {  // NOLINT
-      VLOG(3) << "not to run fp16 op_type: " << op_type << ", node input size "
-              << op_node->inputs.size();
-      auto in_nodes = op_node->inputs;
-      for (auto* in_node : in_nodes) {
-        auto* in_var = in_node->Var();
-        if (in_var->GetDataType() == to_type) {
-          AddCastOp(graph,
-                    in_node,
-                    op_node,
-                    to_type,
-                    VarType::FP32,
-                    &suffix_,
-                    block_desc,
-                    &cast_map_);
-          VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
-                  << cast_map_[in_node]->Name() << "(" << VarType::FP32 << ")";
-        }
-      }
-    }
-  }
-
-  // 4. if output_op's dtype is not compatible to output dtype, then just
-  // insert cast.
-  for (auto* node : output_nodes) {
-    framework::ir::Node* fetch_op{nullptr};
-    for (auto* op_node : node->outputs) {
-      if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
-        fetch_op = op_node;
-      }
-    }
-    CHECK_NOTNULL(fetch_op);
-    auto* var = node->Var();
-    if (keep_io_types_ && var->GetDataType() == to_type) {
-      // fp16/bf16 -> fp32.
-      AddCastOp(graph,
-                node,
-                fetch_op,
-                to_type,
-                VarType::FP32,
-                &suffix_,
-                block_desc,
-                &cast_map_);
-    } else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) {
-      // fp32 -> fp16/bf16
-      AddCastOp(graph,
-                node,
-                fetch_op,
-                VarType::FP32,
-                to_type,
-                &suffix_,
-                block_desc,
-                &cast_map_);
-    }
-  }
-
-  if (num_low_precision)
-    LOG(INFO) << "--- detected " << num_low_precision
-              << " low precision ops in " << block_idx << " subgraph";
-}
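Note on step 4 above: the fetch boundary gets one extra cast whenever the graph's internal precision disagrees with the requested IO precision. Schematically, with hypothetical variable names:

// keep_io_types_ == true: outputs must stay fp32, so a converted fetch var
// is cast back before the fetch op:
//   relu_out(fp16) -> cast(fp16->fp32) -> relu_out_cast.tmp_0(fp32) -> fetch
//
// keep_io_types_ == false: outputs follow the low-precision graph, so a var
// that stayed fp32 is cast forward:
//   relu_out(fp32) -> cast(fp32->fp16) -> relu_out_cast.tmp_1(fp16) -> fetch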
-// We modify op's input output precision, and we need to fix cast op in_dtype
-// and out_dtype attribute.
-// TODO(inference): we need a cast elimination pass.
-void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
-  for (auto* op_node : op_nodes) {
-    if (!op_node->IsOp()) continue;
-    auto op_type = op_node->Op()->Type();
-    if (op_type != "cast") continue;
-    auto input = op_node->inputs[0];
-    auto output = op_node->outputs[0];
-    op_node->Op()->SetAttr("in_dtype",
-                           static_cast<int>(input->Var()->GetDataType()));
-    op_node->Op()->SetAttr("out_dtype",
-                           static_cast<int>(output->Var()->GetDataType()));
-  }
-}
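Note: this fixup exists because earlier steps retype a cast op's surrounding vars without touching the op's own attributes. A schematic of the stale state it repairs (hypothetical names):

// Before FixCastAttr: var x was retyped fp32 -> fp16, but the cast still
// carries the attributes it was created with:
//   x(fp16) -> cast{in_dtype=FP32, out_dtype=FP16} -> y(fp16)
// After FixCastAttr, the attributes are re-derived from the actual vars:
//   x(fp16) -> cast{in_dtype=FP16, out_dtype=FP16} -> y(fp16)
// A fp16->fp16 cast is a no-op, hence the TODO above about a cast
// elimination pass.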
 void ConvertToMixedPrecisionPass::SaveMixedModel() {
   framework::ProgramDesc mixed_program_desc;
   framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
 
@@ -677,51 +87,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   auto parameters = scope_.LocalVarNames();
   std::sort(parameters.begin(), parameters.end());
 
-  std::unordered_set<std::string> weights_should_be_fp32;
-  for (auto* node : main_graph_->Nodes()) {
-    if (!node->IsVar()) continue;
-    if (VarNodeHasDtype(node)) {
-      if (node->Var()->Persistable() &&
-          node->Var()->GetDataType() == VarType::FP32) {
-        VLOG(2) << "weights keep to fp32: " << node->Name() << ", ptr "
-                << reinterpret_cast<void*>(node->Var());
-        weights_should_be_fp32.insert(node->Name());
-      }
-    }
-  }
-#define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                    \
-  mixed_tensor.set_type(DTYPE);                                               \
-  auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace());  \
-  for (int64_t i = 0; i < origin_tensor->numel(); i++) {                      \
-    mixed_data[i] = static_cast<dtype>(origin_data[i]);                       \
-  }                                                                           \
-  origin_tensor->clear();                                                     \
-  paddle::framework::TensorCopySync(                                          \
-      mixed_tensor, platform::CPUPlace(), origin_tensor)
-
-  for (const auto& param_name : parameters) {
-    if (weights_should_be_fp32.count(param_name)) continue;
-    auto* var = scope_.FindLocalVar(param_name);
-    if (var->IsType<phi::DenseTensor>()) {
-      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
-      if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue;
-      phi::DenseTensor mixed_tensor;
-      mixed_tensor.Resize(origin_tensor->dims());
-      auto* origin_data =
-          origin_tensor->mutable_data<float>(platform::CPUPlace());
-      if (mixed_precision_ == phi::DataType::FLOAT16) {
-        CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
-                             phi::dtype::float16);
-      } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
-        CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
-                             phi::dtype::bfloat16);
-      }
-    }
-  }
-#undef CONVERT_TENSOR_DTYPE
 
   auto SerializeParams = [&]() -> std::string {
     std::ostringstream os;
     phi::CPUContext ctx;
 
@@ -746,73 +111,32 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   StrToBinary(mixed_model_file_,
               mixed_program_desc.Proto()->SerializeAsString());
   StrToBinary(mixed_params_file_, SerializeParams());
 }
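Note: for reference, a hand expansion of the deleted CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16, phi::dtype::float16) at its use site above, with whitespace added:

mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
auto* mixed_data =
    mixed_tensor.mutable_data<phi::dtype::float16>(platform::CPUPlace());
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
  mixed_data[i] = static_cast<phi::dtype::float16>(origin_data[i]);
}
origin_tensor->clear();
paddle::framework::TensorCopySync(
    mixed_tensor, platform::CPUPlace(), origin_tensor);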
 }  // namespace
 
-void AddCastOp(
-    framework::ir::Graph* graph,
-    framework::ir::Node* node,
-    framework::ir::Node* next_op,
-    VarType::Type from_type,
-    VarType::Type to_type,
-    int* suffix,
-    framework::BlockDesc* block_desc,
-    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
-  auto update_cast_desc = [&](framework::OpDesc& desc,
-                              const std::string& x_name,
-                              const std::string& out_name,
-                              const int in_dtype,
-                              const int out_dtype) {
-    desc.SetType("cast");
-    desc.SetInput("X", {x_name});
-    desc.SetOutput("Out", {out_name});
-    desc.SetAttr("in_dtype", in_dtype);
-    desc.SetAttr("out_dtype", out_dtype);
-    desc.SetAttr("use_mkldnn", false);
-    desc.SetAttr("with_quant_attr", false);
-    desc.Flush();
-  };
-
-  if (map->count(node) == 0) {
-    // insert cast op before node.
-    std::string cast_input_name = node->Var()->Name();
-    std::string cast_output_name =
-        node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
-    CHECK_NOTNULL(block_desc);
-    framework::OpDesc cast_op_desc(block_desc);
-    update_cast_desc(cast_op_desc,
-                     cast_input_name,
-                     cast_output_name,
-                     static_cast<int>(from_type),
-                     static_cast<int>(to_type));
-    auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
-    auto* cast_output_vardesc = block_desc->Var(cast_output_name);
-    cast_output_vardesc->SetPersistable(false);
-    cast_output_vardesc->SetDataType(to_type);
-    cast_output_vardesc->SetShape(node->Var()->GetShape());
-    auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
-    IR_NODE_LINK_TO(cast_op_node, cast_output_node);
-    (*map)[node] = cast_output_node;
-  }
-  next_op->Op()->Rename(node->Name(), map->at(node)->Name());
-  IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
-  IR_NODE_UNLINK(node, next_op);
-  IR_NODE_LINK_TO(map->at(node), next_op);
-}
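Note: the net effect of the deleted AddCastOp, schematically (the _cast.tmp_N naming comes from the code above; relu is a hypothetical consumer):

// Before:  x ----------------------------------> relu
// After:   x -> cast(from_type->to_type) -> x_cast.tmp_0 -> relu
//
// The `map` argument memoizes the cast output per source node (the
// map->count(node) == 0 guard), so a second consumer of x at the same
// precision reuses x_cast.tmp_0 instead of creating another cast op.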
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& blacklist) {
-  auto phi_op_type = phi::TransToPhiKernelName(op_type);
-  bool support_precision = false;
-  if (blacklist.count(op_type) == 0) {
-    if (backend == phi::Backend::GPU)
-      support_precision = GpuKernelSupportPrecision(op_type, precision);
-    else
-      support_precision =
-          PhiKernelSupportPrecision(phi_op_type, backend, precision);
-  }
-  return support_precision;
-}
+                        const std::unordered_set<std::string>& black_list) {
+  return framework::ir::OpSupportPrecision(
+      op_type, backend, precision, black_list);
+}
 
+void InsertCastOp(
+    framework::ir::Graph* graph,
+    framework::ir::Node* var_node,
+    framework::ir::Node* op_node,
+    framework::proto::VarType::Type from_type,
+    framework::proto::VarType::Type to_type,
+    framework::BlockDesc* block_desc,
+    int* suffix,
+    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited) {
+  framework::ir::DoInsertCastOp(graph,
+                                var_node,
+                                op_node,
+                                from_type,
+                                to_type,
+                                block_desc,
+                                suffix,
+                                visited);
+}
 
 void ConvertToMixedPrecision(
 ...
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
@@ -15,14 +15,12 @@
 #pragma once
 
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_helper.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/framework/scope.h"
 #include "paddle/phi/common/backend.h"
 #include "paddle/phi/common/data_type.h"
@@ -30,20 +28,52 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
+class ConvertToMixedPrecisionPass {
+ public:
+  explicit ConvertToMixedPrecisionPass(
+      const std::string& model_file,
+      const std::string& params_file,
+      const std::string& mixed_model_file,
+      const std::string& mixed_params_file,
+      phi::DataType mixed_precision,
+      phi::Backend backend,
+      bool keep_io_types,
+      const std::unordered_set<std::string>& black_list);
+
+  void Run();
+
+ private:
+  void LoadModel();
+  void SaveMixedModel();
+
+ private:
+  std::string model_file_;
+  std::string params_file_;
+  std::string mixed_model_file_;
+  std::string mixed_params_file_;
+  phi::DataType mixed_precision_;
+  phi::Backend backend_;
+  bool keep_io_types_;
+  std::unordered_set<std::string> black_list_;
+
+  framework::Scope scope_;
+  std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
+};
+
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& blacklist);
+                        const std::unordered_set<std::string>& black_list);
 
-void AddCastOp(
-    framework::ir::Graph* graph,
-    framework::ir::Node* node,
-    framework::ir::Node* next_op,
-    framework::proto::VarType::Type from_type,
-    framework::proto::VarType::Type to_type,
-    int* suffix,
-    framework::BlockDesc* block_desc,
-    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map);
+void InsertCastOp(
+    framework::ir::Graph* graph,
+    framework::ir::Node* var_node,
+    framework::ir::Node* op_node,
+    framework::proto::VarType::Type from_type,
+    framework::proto::VarType::Type to_type,
+    framework::BlockDesc* block_desc,
+    int* suffix,
+    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited);
 
 void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& params_file,
 ...
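Note: a minimal offline-conversion sketch against the class declared above (file paths are placeholders; error handling omitted; only the GPU backend is accepted, per the check in the .cc):

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

int main() {
  paddle::inference::analysis::ConvertToMixedPrecisionPass pass(
      "model.pdmodel",          // model_file
      "model.pdiparams",        // params_file
      "mixed_model.pdmodel",    // mixed_model_file
      "mixed_model.pdiparams",  // mixed_params_file
      phi::DataType::FLOAT16,   // mixed_precision
      phi::Backend::GPU,        // backend
      true,                     // keep_io_types
      {});                      // black_list (empty: convert everything legal)
  pass.Run();
  return 0;
}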
paddle/fluid/inference/api/analysis_config.cc
@@ -99,7 +99,7 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
     // default
   } else if (precision_mode == Precision::kHalf ||
              precision_mode == Precision::kBf16) {
-    enable_gpu_half_ = true;
+    enable_gpu_mixed_ = true;
   } else {
     LOG(ERROR)
         << "The Paddle-GPU inference currently only supports "
@@ -396,7 +396,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   // Mixed precision related.
   CP_MEMBER(mixed_black_list_);
-  CP_MEMBER(enable_gpu_half_);
+  CP_MEMBER(enable_gpu_mixed_);
   CP_MEMBER(mixed_precision_mode_);
 
   CP_MEMBER(enable_memory_optim_);
@@ -1017,7 +1017,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << params_file_;
 
   ss << use_gpu_;
-  ss << enable_gpu_half_;
+  ss << enable_gpu_mixed_;
   ss << use_external_stream_;
   ss << exec_stream_;
   ss << use_fc_padding_;
@@ -1234,7 +1234,7 @@ std::string AnalysisConfig::Summary() {
   os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
   if (use_gpu_) {
     os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
-    os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)});
+    os.InsertRow({"enable_gpu_mixed_", std::to_string(enable_gpu_mixed_)});
     os.InsertRow({"memory_pool_init_size",
                   std::to_string(memory_pool_init_size_mb_) + "MB"});
     os.InsertRow(
...
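Note: from the user side, the enable_gpu_mixed_ flag above is driven by the precision argument of EnableUseGpu. A minimal usage sketch (assumes a GPU build; the include path and model paths depend on your install and are placeholders):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  // Precision::kHalf (or kBf16) is what flips enable_gpu_mixed_ in the
  // diff above; kFloat32 leaves the model untouched.
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/500,
                      /*device_id=*/0,
                      paddle_infer::PrecisionType::kHalf);
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}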
paddle/fluid/inference/api/analysis_predictor.cc
@@ -1277,10 +1277,10 @@ void AnalysisPredictor::PrepareArgument() {
   if (!config_.ir_optim()) {
     argument_.SetEnableIrOptim(false);
-    if (config_.enable_gpu_half_) {
+    if (config_.enable_gpu_mixed_) {
       argument_.SetEnableIrOptim(true);
       pass_builder->ClearPasses();
-      pass_builder->AppendPass("float_to_half_pass");
+      pass_builder->AppendPass("auto_mixed_precision_pass");
       LOG(INFO)
           << "This model run in Paddle-GPU mixed precision mode with no ir "
              "optimization.";
@@ -1291,7 +1291,7 @@ void AnalysisPredictor::PrepareArgument() {
     if (config_.ir_debug_) {
       pass_builder->TurnOnDebug();
     }
-    if (config_.enable_gpu_half_) {
+    if (config_.enable_gpu_mixed_) {
       LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
     }
   }
@@ -1303,7 +1303,7 @@ void AnalysisPredictor::PrepareArgument() {
   // mixed precison.
   argument_.SetModelPrecision(static_cast<int>(model_precision_));
   argument_.SetMixedBlackList(config_.mixed_black_list_);
-  argument_.SetEnableGPUHalf(config_.enable_gpu_half_);
+  argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
   argument_.SetMixedPrecisionMode(static_cast<int>(
       paddle::ConvertPrecision(config_.mixed_precision_mode_)));
 }
...
paddle/fluid/inference/api/paddle_analysis_config.h
@@ -1049,7 +1049,7 @@ struct PD_INFER_DECL AnalysisConfig {
   bool use_gpu_{false};
   int gpu_device_id_{0};
   uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
-  bool enable_gpu_half_{false};
+  bool enable_gpu_mixed_{false};
   bool thread_local_stream_{false};
 
   bool use_cudnn_{false};
...
paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -245,7 +245,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
-        "float_to_half_pass",                  //
+        "constant_folding_pass",               //
+        "auto_mixed_precision_pass",           //
   });
 
   use_gpu_ = true;
...
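Note: since these are ordinary entries in GpuPassStrategy, they can still be toggled per config; a two-line sketch using the public PaddlePassBuilder API (config is an AnalysisConfig as above):

// Opt a model out of the automatic mixed-precision rewrite while keeping
// the rest of the GPU pass pipeline intact.
config.pass_builder()->DeletePass("auto_mixed_precision_pass");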