PaddlePaddle/Paddle · Commit 0972d6ac (unverified)
Authored by Yuanle Liu on Oct 27, 2022; committed via GitHub on Oct 27, 2022.

[Paddle Inference] improve convert_to_mixed_precision (#47333)

Parent commit: 5429d145
Showing 2 changed files with 223 additions and 221 deletions:

  paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc  (+221, -219)
  paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h   (+2, -2)
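For context, a minimal call-site sketch of the public entry point as it is declared after this change (the signature shape is taken from the header diff further down; the model paths and the black-list contents are placeholder assumptions, and the previously defaulted `keep_io_types` and `black_list` arguments now have to be passed explicitly):

```cpp
#include <string>
#include <unordered_set>

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

// Hypothetical caller: convert a saved fp32 inference model to fp16.
void ConvertModelToFp16() {
  const std::string model_file = "/path/to/model.pdmodel";      // placeholder
  const std::string params_file = "/path/to/model.pdiparams";   // placeholder

  // Ops listed here keep running in fp32. Per the constructor change in this
  // diff, the pass itself no longer appends extra entries to this set.
  const std::unordered_set<std::string> black_list{"fill_constant", "assign"};

  paddle::inference::analysis::ConvertToMixedPrecision(
      model_file,
      params_file,
      "/path/to/mixed.pdmodel",      // placeholder output model
      "/path/to/mixed.pdiparams",    // placeholder output params
      phi::DataType::FLOAT16,
      phi::Backend::GPU,
      /*keep_io_types=*/true,        // no longer defaulted; pass explicitly
      black_list);
}
```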
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc

@@ -42,13 +42,13 @@
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/tensor_meta.h"
 
-using namespace paddle::framework;  // NOLINT
-
 namespace paddle {
 namespace inference {
 namespace analysis {
 namespace {
+using VarType = framework::proto::VarType;
+
 bool PhiKernelSupportPrecision(
     const std::string& op_type,
     phi::Backend backend,
@@ -73,13 +73,14 @@ bool GpuKernelSupportPrecision(
       phi_op_type, phi::Backend::GPUDNN, data_type, layout);
   if (!res) {
-    auto& all_kernels = OperatorWithKernel::AllOpKernels();
+    auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
     auto it = all_kernels.find(op_type);
     if (it != all_kernels.end()) {
       for (auto& kern_pair : it->second) {
         if (platform::is_gpu_place(kern_pair.first.place_) &&
-            kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
+            kern_pair.first.data_type_ == VarType::FP16) {
           res = true;
           break;
         }
       }
     }

@@ -88,6 +89,8 @@ bool GpuKernelSupportPrecision(
 }
 
 class ConvertToMixedPrecisionPass {
+  using BlockID = size_t;
+
  public:
   explicit ConvertToMixedPrecisionPass(
       const std::string& model_file,
@@ -97,7 +100,7 @@ class ConvertToMixedPrecisionPass {
       phi::DataType mixed_precision,
       phi::Backend backend,
       bool keep_io_types,
-      std::unordered_set<std::string> black_list)
+      const std::unordered_set<std::string>& black_list)
       : model_file_(model_file),
         params_file_(params_file),
         mixed_model_file_(mixed_model_file),

@@ -107,45 +110,40 @@ class ConvertToMixedPrecisionPass {
         keep_io_types_(keep_io_types),
         black_list_(black_list),
         place_(paddle::CPUPlace()),
-        executor_(place_) {
-    black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
-  }
+        executor_(place_) {}
 
   void Run();
 
  private:
   void LoadAndPrepare();
-  inline bool NodeVarHasDtype(framework::ir::Node* node);
+  inline bool VarNodeHasDtype(framework::ir::Node* node);
   void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
   void FixCastAttr(framework::ir::Graph* graph);
   void SaveMixedModel();
-  void ConvertTensorDtype(int block_idx);
+  void ConvertTensorDtype(BlockID block_idx);
   void ProcessInputNode(bool support_precision,
-                        ir::Node* in_node,
-                        ir::Node* op_node,
+                        framework::ir::Node* in_node,
+                        framework::ir::Node* op_node,
                         int* suffix,
                         framework::BlockDesc* block_desc,
-                        framework::proto::VarType::Type to_type,
-                        int block_idx);
-  void ProcessOutputNode(int block_idx,
-                         ir::Node* var_node,
-                         framework::proto::VarType::Type to_type);
-  inline bool IsFloatVarType(framework::proto::VarType::Type type);
+                        VarType::Type to_type,
+                        BlockID block_idx);
+  void ProcessOutputNode(BlockID block_idx,
+                         framework::ir::Node* var_node,
+                         VarType::Type to_type);
+  inline bool IsFloatVarType(VarType::Type type);
 
-  bool OutShouldNotConvert(ir::Node* var_node);
+  bool OutShouldNotConvert(framework::ir::Node* var_node);
   // Just process special cases for weights conversion.
-  bool WeightsShouldNotConvert(ir::Node* var_node);
+  bool WeightsShouldNotConvert(framework::ir::Node* var_node);
 
   // To support multi block, we need to consider a lot of special cases.
   // Return Node* which first appers in block.
-  framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
+  framework::ir::Node* GetRealVarNode(BlockID block_idx,
+                                      framework::ir::Node* node);
   void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(int block_idx,
+  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
                                         framework::ir::Node* op_node);
 
  private:
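One behavioral consequence visible in the hunk above: the constructor used to force six ops (assign, fill_constant, assign_value, eye, fill_any_like, fill_constant_batch_size_like) into `black_list_`, and after the change its body is empty; the diff alone does not show whether that logic moved elsewhere. A hedged sketch of how a caller could reproduce the old behavior explicitly if it still relies on it (the helper name and call shape are illustrative only):

```cpp
#include <string>
#include <unordered_set>

// Illustrative helper: merge in the op set the old constructor hard-coded.
std::unordered_set<std::string> WithLegacyBlackList(
    std::unordered_set<std::string> black_list) {
  for (const char* op : {"assign", "fill_constant", "assign_value", "eye",
                         "fill_any_like", "fill_constant_batch_size_like"}) {
    black_list.insert(op);
  }
  return black_list;
}

// Usage: pass WithLegacyBlackList(user_black_list) as the black_list argument
// of ConvertToMixedPrecision if the previously implicit exclusions are wanted.
```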
@@ -167,11 +165,10 @@ class ConvertToMixedPrecisionPass {
   framework::Scope scope_;
 
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string,
-                     std::pair<framework::proto::VarType::Type, int>>
-      vars_in_multi_block_map_;
-  std::vector<std::unordered_map<std::string, std::vector<std::string>>>
-      vars_appear_multi_in_one_block_;
+  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
+      vars_in_multi_block_with_pair_;
+  std::unordered_map<std::string, std::vector<std::string>>
+      vars_in_multi_block_with_ops_;
   int suffix_{0};
 
   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
@@ -179,91 +176,84 @@ class ConvertToMixedPrecisionPass {
   std::vector<framework::ir::Graph*> graphes_;
 };
 
-framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
-    int block_idx, framework::ir::Node* node) {
-  if (vars_in_multi_block_map_.count(node->Name())) {
-    int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
-    if (block_idx != var_origin_block_id) {
-      auto graph = graphes_[var_origin_block_id];
-      for (auto nd : graph->Nodes()) {
-        if (nd->Name() == node->Name()) {
-          return nd;
+framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
+    BlockID block_idx, framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+
+  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
+    auto origin_blockId =
+        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
+    if (block_idx != origin_blockId) {
+      auto* graph = graphes_[origin_blockId];
+      for (auto* node : graph->Nodes()) {
+        if (node->Name() == var_node->Name()) {
+          return node;
         }
       }
     }
   }
 
-  return node;
+  return var_node;
 }
 
-inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
-    framework::ir::Node* node) {
-  if (node->IsVar() &&
-      (node->Var()->GetType() ==
-           paddle::framework::proto::VarType::SELECTED_ROWS ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
-    return true;
-  }
-  return false;
+inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
+    framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+  auto type = var_node->Var()->GetType();
+  return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
+         (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
+         (type == VarType::VOCAB);
 }
 
 // op1(fp32) -> var1, op2(fp16) -> var1
 // if and only if op1 and op2 both support fp16, we convert op1 and op2's
 // precision.
 inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    int block_idx, framework::ir::Node* op_node) {
-  bool ret{false};
-
-  for (auto* out : op_node->outputs) {
-    auto* real_node = GetRealNode(block_idx, out);
-    if (!real_node->Var()->Persistable() &&
-        vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
-      for (auto op_type :
-           vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
-        if (OpSupportPrecision(
-                op_type, backend_, mixed_precision_, black_list_)) {
-          ret = true;
-          VLOG(2) << out->Name()
-                  << " is multi precision op's out, so we skip convert to fp16";
-          break;
-        }
-      }
-    }
-    if (ret) break;
-  }
-  return ret;
+    BlockID block_idx, framework::ir::Node* op_node) {
+  CHECK_EQ(op_node->IsOp(), true);
+
+  for (auto* var_node : op_node->outputs) {
+    if (!var_node->IsVar()) continue;
+    auto* real_var_node = GetRealVarNode(block_idx, var_node);
+    if (!real_var_node->Var()->Persistable() &&
+        vars_in_multi_block_with_ops_.count(var_node->Name())) {
+      for (const auto& op_type :
+           vars_in_multi_block_with_ops_.at(var_node->Name())) {
+        if (!OpSupportPrecision(
+                op_type, backend_, mixed_precision_, black_list_)) {
+          VLOG(2) << var_node->Name()
+                  << " is multi precision op's out, so we skip convert to fp16";
+          return true;
+        }
+      }
+    }
+  }
+  return false;
 }
 
 void ConvertToMixedPrecisionPass::ProcessInputNode(
     bool support_precision,
-    ir::Node* in_node,
-    ir::Node* op_node,
+    framework::ir::Node* in_node,
+    framework::ir::Node* op_node,
     int* suffix,
     framework::BlockDesc* block_desc,
-    framework::proto::VarType::Type to_type,
-    int block_idx) {
-  auto* real_node = GetRealNode(block_idx, in_node);
-  if (!NodeVarHasDtype(real_node)) return;
-  auto graph = graphes_[block_idx];
+    VarType::Type to_type,
+    BlockID block_idx) {
+  if (!in_node->IsVar()) return;
+  auto* real_node = GetRealVarNode(block_idx, in_node);
+  if (!VarNodeHasDtype(real_node)) return;
+  auto* graph = graphes_[block_idx];
   bool is_main_block = block_idx == 0;
   auto* in_var = real_node->Var();
   auto in_var_type = in_var->GetDataType();
   auto prev_type = in_var_type;
-  bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
+  bool is_in_multi_block =
+      vars_in_multi_block_with_pair_.count(in_var->Name());
 
   if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
+    in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
   }
 
   if (support_precision) {
-    if (in_var->Persistable() &&
-        in_var_type == framework::proto::VarType::FP32) {
+    if (in_var->Persistable() && in_var_type == VarType::FP32) {
       if (WeightsShouldNotConvert(in_node)) return;
       in_var->SetDataType(to_type);
       in_var_type = to_type;
@@ -300,14 +290,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
 }
 
 void ConvertToMixedPrecisionPass::ProcessOutputNode(
-    int block_idx,
-    ir::Node* var_node,
-    framework::proto::VarType::Type to_type) {
-  auto* real_node = GetRealNode(block_idx, var_node);
-  if (!NodeVarHasDtype(real_node)) return;
+    BlockID block_idx,
+    framework::ir::Node* var_node,
+    VarType::Type to_type) {
+  if (!var_node->IsVar()) return;
+  auto* real_node = GetRealVarNode(block_idx, var_node);
+  if (!VarNodeHasDtype(real_node)) return;
   auto* out_var = real_node->Var();
   auto prev_type = out_var->GetDataType();
-  if (out_var->GetDataType() == framework::proto::VarType::FP32) {
+  if (out_var->GetDataType() == VarType::FP32) {
     if (OutShouldNotConvert(var_node)) return;
     out_var->SetDataType(to_type);
   }

@@ -316,7 +305,8 @@ void ConvertToMixedPrecisionPass::ProcessOutputNode(
 }
 
 // Just process special cases.
-bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::OutShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_node = var_node->inputs[0];
   auto* op_desc = op_node->Op();

@@ -343,7 +333,8 @@ bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
   return false;
 }
 
-bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_nodes = var_node->outputs;
   for (auto* op_node : op_nodes) {
     auto* op_desc = op_node->Op();

@@ -391,13 +382,10 @@ bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
   return false;
 }
 
-inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
-    framework::proto::VarType::Type type) {
-  if (type == framework::proto::VarType::FP16 ||
-      type == framework::proto::VarType::FP32 ||
-      type == framework::proto::VarType::BF16)
-    return true;
-  return false;
+inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) {
+  return (type == VarType::FP16) || (type == VarType::FP32) ||
+         (type == VarType::BF16);
 }
 
 void ConvertToMixedPrecisionPass::LoadAndPrepare() {
@@ -405,6 +393,10 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
       inference::Load(&executor_, &scope_, model_file_, params_file_);
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));
+  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
+    auto* graph = main_graph_->GetSubGraph(i);
+    graphes_.push_back(graph);
+  }
 
   // Remove all control var
   IrInferCleanGraphPass pass;

@@ -412,41 +404,45 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);
 
-  vars_appear_multi_in_one_block_.resize(program_desc_->Size());
   FindVarsInMultiBlock();
 }
 
 void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::vector<std::set<std::string>> block_var_names_set(
-      program_desc_->Size());
-  for (size_t i = 0; i < program_desc_->Size(); ++i) {
-    for (auto op : program_desc_->Block(i).AllOps()) {
-      auto in_names = op->InputArgumentNames();
-      block_var_names_set[i].insert(in_names.begin(), in_names.end());
-      auto out_names = op->OutputArgumentNames();
+  std::unordered_set<std::string> all_var_names_set;
+  std::vector<std::unordered_set<std::string>> block_var_names_set(
+      program_desc_->Size());
+  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
+    for (auto* op : program_desc_->Block(idx).AllOps()) {
+      const auto& in_names = op->InputArgumentNames();
+      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
+      const auto& out_names = op->OutputArgumentNames();
+      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
       if (op->HasAttr("sub_block") == false) {
-        for (auto& n : out_names) {
-          if (block_var_names_set[i].count(n)) {
-            vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
+        for (const auto& name : out_names) {
+          if (all_var_names_set.count(name)) {
+            vars_in_multi_block_with_ops_[name].push_back(op->Type());
           }
         }
       }
-      block_var_names_set[i].insert(out_names.begin(), out_names.end());
+      all_var_names_set.insert(block_var_names_set[idx].begin(),
+                               block_var_names_set[idx].end());
     }
   }
 
-  for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
-    for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
-      std::set<std::string> vars_in_multi_block;
-      std::set_intersection(block_var_names_set[i].begin(),
-                            block_var_names_set[i].end(),
-                            block_var_names_set[j].begin(),
-                            block_var_names_set[j].end(),
-                            std::inserter(vars_in_multi_block,
-                                          vars_in_multi_block.begin()));
-      for (auto name : vars_in_multi_block) {
-        vars_in_multi_block_map_.emplace(
-            name, std::make_pair(framework::proto::VarType::FP32, i));
+  CHECK_GT(program_desc_->Size(), 0U);
+  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
+    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
+      std::vector<std::string> vars_in_multi_block;
+      std::set_intersection(block_var_names_set[idx].begin(),
+                            block_var_names_set[idx].end(),
+                            block_var_names_set[jdx].begin(),
+                            block_var_names_set[jdx].end(),
+                            std::back_inserter(vars_in_multi_block));
+      for (const auto& name : vars_in_multi_block) {
+        vars_in_multi_block_with_pair_.emplace(
+            name, std::make_pair(VarType::FP32, idx));
       }
     }
   }
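A side note on the std::set_intersection call introduced above: the algorithm requires both input ranges to be sorted, which is why a standalone sketch of the "variables shared by two blocks" computation is easiest to show with sorted vectors (the pass itself stores names in std::unordered_set, whose iteration order is unspecified). The block contents below are made-up names for illustration:

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical variable names appearing in block 0 and block 1.
  std::vector<std::string> block0{"cond", "hidden", "x"};
  std::vector<std::string> block1{"cond", "out", "x"};
  // std::set_intersection assumes sorted inputs; keep both ranges sorted.
  std::sort(block0.begin(), block0.end());
  std::sort(block1.begin(), block1.end());

  std::vector<std::string> shared;
  std::set_intersection(block0.begin(), block0.end(),
                        block1.begin(), block1.end(),
                        std::back_inserter(shared));

  for (const auto& name : shared) {
    std::cout << name << " appears in both blocks\n";  // prints: cond, x
  }
  return 0;
}
```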
@@ -462,41 +458,34 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
     if (op_type == "fill_constant") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "assign_value") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "eye") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "fill_any_like") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "cast") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "in_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "out_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
     }
 
     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
       if (!in_var->Persistable() &&
-          in_var->GetDataType() == framework::proto::VarType::FP64) {
-        in_var->SetDataType(framework::proto::VarType::FP32);
+          in_var->GetDataType() == VarType::FP64) {
+        in_var->SetDataType(VarType::FP32);
       }
     }

@@ -505,9 +494,8 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
 void ConvertToMixedPrecisionPass::Run() {
   LoadAndPrepare();
 
-  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
-    auto graph = main_graph_->GetSubGraph(i);
-    graphes_.push_back(graph);
+  for (size_t i = 0; i < graphes_.size(); ++i) {
+    auto* graph = graphes_[i];
     VLOG(2) << " -------- handle subgraph " << i << ", has "
             << graph->Nodes().size() << " nodes --------";
@@ -518,19 +506,19 @@ void ConvertToMixedPrecisionPass::Run() {
     // A trick
     PatchForStrangeOp();
-    CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
+    CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
   }
 
   SaveMixedModel();
 }
 
-void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
-  auto graph = graphes_[block_idx];
-  framework::proto::VarType::Type to_type;
+void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
+  auto* graph = graphes_[block_idx];
+  VarType::Type to_type;
   if (mixed_precision_ == phi::DataType::FLOAT16) {
-    to_type = framework::proto::VarType::FP16;
+    to_type = VarType::FP16;
   } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
-    to_type = framework::proto::VarType::BF16;
+    to_type = VarType::BF16;
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
         "mixed_precision currently not supported dtype %d, we now only "

@@ -551,8 +539,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // 1. set input dtype.
     if (op_type == "feed") {
       auto feed_var = op_node->outputs[0]->Var();
-      if (!keep_io_types_ &&
-          feed_var->GetDataType() == framework::proto::VarType::FP32) {
+      if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) {
         feed_var->SetDataType(to_type);
       }
     } else if (op_type == "fetch") {
// same name.
std
::
unordered_map
<
std
::
string
,
framework
::
ir
::
Node
*>
in_name_to_node
;
for
(
auto
*
in
:
op_node
->
inputs
)
{
auto
*
real_node
=
GetRealNode
(
block_idx
,
in
);
if
(
NodeVarHasDtype
(
real_node
))
{
if
(
!
in
->
IsVar
())
continue
;
auto
*
real_node
=
GetRealVarNode
(
block_idx
,
in
);
if
(
VarNodeHasDtype
(
real_node
))
{
in_name_to_node
[
in
->
Name
()]
=
in
;
}
}
for
(
auto
out
:
op_node
->
outputs
)
{
auto
*
real_node
=
GetRealNode
(
block_idx
,
out
);
if
(
NodeVarHasDtype
(
real_node
))
{
for
(
auto
*
out
:
op_node
->
outputs
)
{
if
(
!
out
->
IsVar
())
continue
;
auto
*
real_node
=
GetRealVarNode
(
block_idx
,
out
);
if
(
VarNodeHasDtype
(
real_node
))
{
if
(
in_name_to_node
.
count
(
out
->
Name
()))
real_node
->
Var
()
->
SetDataType
(
in_name_to_node
[
out
->
Name
()]
->
Var
()
->
GetDataType
());
...
...
@@ -591,32 +580,46 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // - add cast op if the input dtype is not fp16/bf16.
    // - set output dtype.
     //
-    // If a var(op's out var) appears multiple times in a block, we should not
+    // If a var(op's out var) appears multiple times in graph, we should not
     // convert to fp16.
     else if (black_list_.count(op_type) == 0 &&  // NOLINT
              !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
       bool support_precision =
           OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
 
-      // if op not has float input, we will not choose the low precision kernel.
+      // If the op has no input and output of float type, we will not choose the
+      // low precision kernel.
       {
-        bool has_float_input{false};
-        for (auto in_node : op_node->inputs) {
-          auto* real_node = GetRealNode(block_idx, in_node);
-          if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP32 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP64 ||
-              real_node->Var()->GetDataType() == proto::VarType::BF16) {
-            has_float_input = true;
+        bool has_float_input_and_output{false};
+        for (auto* in_node : op_node->inputs) {
+          if (!in_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(block_idx, in_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_input_and_output = true;
             break;
           }
         }
-        if (!has_float_input) {
+        for (auto* out_node : op_node->outputs) {
+          if (!out_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(block_idx, out_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_input_and_output = true;
+            break;
+          }
+        }
+        if (!has_float_input_and_output) {
           support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
+          VLOG(2) << " op doesn't has float input and output, just skip.";
         }
       }
-      VLOG(2) << " support low precision " << support_precision;
+      VLOG(2) << "op type: " << op_type
+              << " support low precision: " << support_precision;
 
       if (support_precision) {
         VLOG(2) << " process input nodes:";
@@ -626,8 +629,8 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
         // Just for paddle's terriable case: op's input and output has the same
         // name.
         std::unordered_map<std::string, std::string> names_map;
-        for (auto out_node : op_node->outputs) {
-          for (auto in_node : op_node->inputs) {
+        for (auto* out_node : op_node->outputs) {
+          for (auto* in_node : op_node->inputs) {
             if (out_node->Name() == in_node->Name()) {
               names_map[out_node->Name()] = in_node->Name();
             }
@@ -655,7 +658,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
                          op_node,
                          &suffix_,
                          block_desc,
-                         framework::proto::VarType::FP32,
+                         VarType::FP32,
                          block_idx);
       }
     }
@@ -665,21 +668,19 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // - add cast op if the input dtype is not fp32.
     else {  // NOLINT
       VLOG(3) << "not to run fp16 op_type: " << op_type;
-      auto ins = op_node->inputs;
-      for (auto* in_node : ins) {
+      for (auto* in_node : op_node->inputs) {
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
                     in_node,
                     op_node,
                     to_type,
-                    framework::proto::VarType::FP32,
+                    VarType::FP32,
                     &suffix_,
                     block_desc,
                     &cast_map_);
           VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
-                  << cast_map_[in_node]->Name() << "("
-                  << framework::proto::VarType::FP32 << ")";
+                  << cast_map_[in_node]->Name() << "(" << VarType::FP32 << ")";
         }
       }
     }
@@ -688,31 +689,30 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
   // 4. if output_op's dtype is not compatible to output dtype, then just
   // insert cast.
   for (auto* node : output_nodes) {
-    ir::Node* fetch_op{nullptr};
+    framework::ir::Node* fetch_op{nullptr};
     for (auto* op_node : node->outputs) {
       if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
         fetch_op = op_node;
       }
     }
     CHECK_NOTNULL(fetch_op);
-    auto var = node->Var();
+    auto* var = node->Var();
     if (keep_io_types_ && var->GetDataType() == to_type) {
       // fp16/bf16 -> fp32.
       AddCastOp(graph,
                 node,
                 fetch_op,
                 to_type,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 &suffix_,
                 block_desc,
                 &cast_map_);
-    } else if (!keep_io_types_ &&
-               var->GetDataType() == framework::proto::VarType::FP32) {
+    } else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) {
       // fp32 -> fp16/bf16
       AddCastOp(graph,
                 node,
                 fetch_op,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 to_type,
                 &suffix_,
                 block_desc,
@@ -720,13 +720,15 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     }
   }
 
-  for (auto node : graph->Nodes()) {
-    auto* real_node = GetRealNode(block_idx, node);
-    if (!NodeVarHasDtype(real_node)) continue;
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsVar()) continue;
+    auto* real_node = GetRealVarNode(block_idx, node);
+    if (!VarNodeHasDtype(real_node)) continue;
 
-    if (vars_in_multi_block_map_.count(real_node->Name()) &&
-        vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
-      vars_in_multi_block_map_.at(real_node->Name()).first =
+    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
+        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
+            block_idx) {
+      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
           real_node->Var()->GetDataType();
     }
   }
@@ -757,17 +759,15 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   framework::ProgramDesc mixed_program_desc;
   framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
 
-  paddle::CPUPlace place;
   auto parameters = scope_.LocalVarNames();
   std::sort(parameters.begin(), parameters.end());
 
   std::unordered_set<std::string> weights_should_be_fp32;
   for (auto* node : main_graph_->Nodes()) {
-    if (!(node->IsVar())) continue;
-    if (NodeVarHasDtype(node)) {
+    if (!node->IsVar()) continue;
+    if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
-          node->Var()->GetDataType() ==
-              paddle::framework::proto::VarType::FP32) {
+          node->Var()->GetDataType() == VarType::FP32) {
         VLOG(2) << "weights keep to fp32: " << node->Name();
         weights_should_be_fp32.insert(node->Name());
       }
@@ -777,26 +777,27 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
 #define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                   \
   mixed_tensor.set_type(DTYPE);                                              \
   auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
-  for (int i = 0; i < t->numel(); i++) {                                     \
-    mixed_data[i] = static_cast<dtype>(data[i]);                             \
+  for (int64_t i = 0; i < origin_tensor->numel(); i++) {                     \
+    mixed_data[i] = static_cast<dtype>(origin_data[i]);                      \
   }                                                                          \
-  t->clear();                                                                \
-  paddle::framework::TensorCopySync(mixed_tensor, place, t)
+  origin_tensor->clear();                                                    \
+  paddle::framework::TensorCopySync(                                         \
+      mixed_tensor, platform::CPUPlace(), origin_tensor)
 
   for (const auto& param_name : parameters) {
+    if (weights_should_be_fp32.count(param_name)) continue;
     auto* var = scope_.FindLocalVar(param_name);
     if (var->IsType<phi::DenseTensor>()) {
-      auto* t = var->GetMutable<phi::DenseTensor>();
-      if (t->dtype() != phi::DataType::FLOAT32) continue;
+      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
+      if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue;
       phi::DenseTensor mixed_tensor;
-      mixed_tensor.Resize(t->dims());
-      auto* data = t->mutable_data<float>(platform::CPUPlace());
-      if (mixed_precision_ == phi::DataType::FLOAT16 &&
-          !weights_should_be_fp32.count(param_name)) {
+      mixed_tensor.Resize(origin_tensor->dims());
+      auto* origin_data =
+          origin_tensor->mutable_data<float>(platform::CPUPlace());
+      if (mixed_precision_ == phi::DataType::FLOAT16) {
         CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
                              phi::dtype::float16);
-      } else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
-                 !weights_should_be_fp32.count(param_name)) {
+      } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
         CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
                              phi::dtype::bfloat16);
       }
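The CONVERT_TENSOR_DTYPE macro above is an element-wise narrowing copy of the fp32 weight buffer into the target 16-bit type, followed by a copy back into the original tensor. A tiny standalone illustration of the same cast-and-copy pattern, using double to float as a stand-in so it compiles without Paddle's float16/bfloat16 headers:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Stand-in for the fp32 weight tensor ("origin_tensor" in the macro).
  std::vector<double> origin{0.5, 1.25, -3.75};

  // Stand-in for "mixed_tensor": same number of elements, lower precision.
  std::vector<float> mixed(origin.size());
  for (std::int64_t i = 0; i < static_cast<std::int64_t>(origin.size()); ++i) {
    // The macro does the same with static_cast to float16 / bfloat16.
    mixed[i] = static_cast<float>(origin[i]);
  }

  for (float v : mixed) std::cout << v << '\n';
  return 0;
}
```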
@@ -851,8 +852,8 @@ void AddCastOp(
     framework::ir::Graph* graph,
     framework::ir::Node* node,
     framework::ir::Node* next_op,
-    framework::proto::VarType::Type from_type,
-    framework::proto::VarType::Type to_type,
+    VarType::Type from_type,
+    VarType::Type to_type,
     int* suffix,
     framework::BlockDesc* block_desc,
     std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
@@ -913,14 +914,15 @@ bool OpSupportPrecision(const std::string& op_type,
   return support_precision;
 }
 
-void ConvertToMixedPrecision(const std::string& model_file,
-                             const std::string& params_file,
-                             const std::string& mixed_model_file,
-                             const std::string& mixed_params_file,
-                             phi::DataType mixed_precision,
-                             phi::Backend backend,
-                             bool keep_io_types,
-                             std::unordered_set<std::string> black_list) {
+void ConvertToMixedPrecision(
+    const std::string& model_file,
+    const std::string& params_file,
+    const std::string& mixed_model_file,
+    const std::string& mixed_params_file,
+    phi::DataType mixed_precision,
+    phi::Backend backend,
+    bool keep_io_types,
+    const std::unordered_set<std::string>& black_list) {
   ConvertToMixedPrecisionPass pass(model_file,
                                    params_file,
                                    mixed_model_file,
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h

@@ -51,8 +51,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& mixed_params_file,
                              phi::DataType mixed_precision,
                              phi::Backend backend,
-                             bool keep_io_types = true,
-                             std::unordered_set<std::string> black_list = {});
+                             bool keep_io_types,
+                             const std::unordered_set<std::string>& black_list);
 
 }  // namespace analysis
 }  // namespace inference