Unverified commit 0972d6ac
[Paddle Inference] improve convert_to_mixed_precision (#47333)
Authored by Yuanle Liu on Oct 27, 2022; committed via GitHub on Oct 27, 2022.
Parent: 5429d145
Changes: 2 changed files with 223 additions and 221 deletions.

paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc   +221  -219
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h      +2    -2
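For orientation, a minimal sketch of how the updated entry point might be called after this change. The function signature comes from the diff below; the file paths, backend, and black-list contents are placeholders and not part of the patch.

#include <string>
#include <unordered_set>

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

int main() {
  // Hypothetical inputs; only the signature is taken from this commit.
  std::unordered_set<std::string> black_list{"softmax"};  // example ops to keep in fp32
  paddle::inference::analysis::ConvertToMixedPrecision(
      "model.pdmodel",          // model_file
      "model.pdiparams",        // params_file
      "mixed_model.pdmodel",    // mixed_model_file
      "mixed_model.pdiparams",  // mixed_params_file
      phi::DataType::FLOAT16,   // mixed_precision
      phi::Backend::GPU,        // backend (assumed value)
      true,                     // keep_io_types: no longer defaulted after this change
      black_list);              // black_list: now passed by const reference
  return 0;
}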
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc

@@ -42,13 +42,13 @@
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/tensor_meta.h"

-using namespace paddle::framework;  // NOLINT
-
 namespace paddle {
 namespace inference {
 namespace analysis {

 namespace {
+
+using VarType = framework::proto::VarType;
+
 bool PhiKernelSupportPrecision(
     const std::string& op_type,
     phi::Backend backend,
@@ -73,13 +73,14 @@ bool GpuKernelSupportPrecision(
       phi_op_type, phi::Backend::GPUDNN, data_type, layout);
   if (!res) {
-    auto& all_kernels = OperatorWithKernel::AllOpKernels();
+    auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
     auto it = all_kernels.find(op_type);
     if (it != all_kernels.end()) {
       for (auto& kern_pair : it->second) {
         if (platform::is_gpu_place(kern_pair.first.place_) &&
-            kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
+            kern_pair.first.data_type_ == VarType::FP16) {
           res = true;
           break;
         }
       }
     }
@@ -88,6 +89,8 @@ bool GpuKernelSupportPrecision(
 }

 class ConvertToMixedPrecisionPass {
+  using BlockID = size_t;
+
  public:
   explicit ConvertToMixedPrecisionPass(
       const std::string& model_file,
@@ -97,7 +100,7 @@ class ConvertToMixedPrecisionPass {
       phi::DataType mixed_precision,
       phi::Backend backend,
       bool keep_io_types,
-      std::unordered_set<std::string> black_list)
+      const std::unordered_set<std::string>& black_list)
       : model_file_(model_file),
         params_file_(params_file),
         mixed_model_file_(mixed_model_file),
@@ -107,45 +110,40 @@ class ConvertToMixedPrecisionPass {
         keep_io_types_(keep_io_types),
         black_list_(black_list),
         place_(paddle::CPUPlace()),
-        executor_(place_) {
-    black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
-  }
+        executor_(place_) {}

   void Run();

  private:
   void LoadAndPrepare();
-  inline bool NodeVarHasDtype(framework::ir::Node* node);
+  inline bool VarNodeHasDtype(framework::ir::Node* node);
   void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
   void FixCastAttr(framework::ir::Graph* graph);
   void SaveMixedModel();
-  void ConvertTensorDtype(int block_idx);
+  void ConvertTensorDtype(BlockID block_idx);
   void ProcessInputNode(bool support_precision,
-                        ir::Node* in_node,
-                        ir::Node* op_node,
+                        framework::ir::Node* in_node,
+                        framework::ir::Node* op_node,
                         int* suffix,
                         framework::BlockDesc* block_desc,
-                        framework::proto::VarType::Type to_type,
-                        int block_idx);
-  void ProcessOutputNode(int block_idx,
-                         ir::Node* var_node,
-                         framework::proto::VarType::Type to_type);
-  inline bool IsFloatVarType(framework::proto::VarType::Type type);
-  bool OutShouldNotConvert(ir::Node* var_node);
+                        VarType::Type to_type,
+                        BlockID block_idx);
+  void ProcessOutputNode(BlockID block_idx,
+                         framework::ir::Node* var_node,
+                         VarType::Type to_type);
+  inline bool IsFloatVarType(VarType::Type type);
+  bool OutShouldNotConvert(framework::ir::Node* var_node);
   // Just process special cases for weights conversion.
-  bool WeightsShouldNotConvert(ir::Node* var_node);
+  bool WeightsShouldNotConvert(framework::ir::Node* var_node);
   // To support multi block, we need to consider a lot of special cases.
   // Return Node* which first appers in block.
-  framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
+  framework::ir::Node* GetRealVarNode(BlockID block_idx,
+                                      framework::ir::Node* node);
   void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(int block_idx,
+  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
                                         framework::ir::Node* op_node);

  private:
@@ -167,11 +165,10 @@ class ConvertToMixedPrecisionPass {
   framework::Scope scope_;

   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string,
-                     std::pair<framework::proto::VarType::Type, int>>
-      vars_in_multi_block_map_;
-  std::vector<std::unordered_map<std::string, std::vector<std::string>>>
-      vars_appear_multi_in_one_block_;
+  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
+      vars_in_multi_block_with_pair_;
+  std::unordered_map<std::string, std::vector<std::string>>
+      vars_in_multi_block_with_ops_;
   int suffix_{0};

   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
@@ -179,91 +176,84 @@ class ConvertToMixedPrecisionPass {
   std::vector<framework::ir::Graph*> graphes_;
 };

-framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
-    int block_idx, framework::ir::Node* node) {
-  if (vars_in_multi_block_map_.count(node->Name())) {
-    int var_origin_block_id =
-        vars_in_multi_block_map_.at(node->Name()).second;
-    if (block_idx != var_origin_block_id) {
-      auto graph = graphes_[var_origin_block_id];
-      for (auto nd : graph->Nodes()) {
-        if (nd->Name() == node->Name()) {
-          return nd;
+framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
+    BlockID block_idx, framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+
+  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
+    auto origin_blockId =
+        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
+    if (block_idx != origin_blockId) {
+      auto* graph = graphes_[origin_blockId];
+      for (auto* node : graph->Nodes()) {
+        if (node->Name() == var_node->Name()) {
+          return node;
         }
       }
     }
   }
-  return node;
+  return var_node;
 }

-inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
-    framework::ir::Node* node) {
-  if (node->IsVar() &&
-      (node->Var()->GetType() ==
-           paddle::framework::proto::VarType::SELECTED_ROWS ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
-    return true;
-  }
-  return false;
+inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
+    framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+  auto type = var_node->Var()->GetType();
+  return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
+         (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
+         (type == VarType::VOCAB);
 }

 // op1(fp32) -> var1, op2(fp16) -> var1
 // if and only if op1 and op2 both support fp16, we convert op1 and op2's
 // precision.
 inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    int block_idx, framework::ir::Node* op_node) {
-  bool ret{false};
-  for (auto* out : op_node->outputs) {
-    auto* real_node = GetRealNode(block_idx, out);
-    if (!real_node->Var()->Persistable() &&
-        vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
-      for (auto op_type :
-           vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
-        if (OpSupportPrecision(
+    BlockID block_idx, framework::ir::Node* op_node) {
+  CHECK_EQ(op_node->IsOp(), true);
+
+  for (auto* var_node : op_node->outputs) {
+    if (!var_node->IsVar()) continue;
+    auto* real_var_node = GetRealVarNode(block_idx, var_node);
+    if (!real_var_node->Var()->Persistable() &&
+        vars_in_multi_block_with_ops_.count(var_node->Name())) {
+      for (const auto& op_type :
+           vars_in_multi_block_with_ops_.at(var_node->Name())) {
+        if (!OpSupportPrecision(
                 op_type, backend_, mixed_precision_, black_list_)) {
-          ret = true;
-          VLOG(2) << out->Name()
+          VLOG(2) << var_node->Name()
                   << " is multi precision op's out, so we skip convert to fp16";
-          break;
+          return true;
         }
       }
     }
-    if (ret) break;
   }
-  return ret;
+  return false;
 }

 void ConvertToMixedPrecisionPass::ProcessInputNode(
     bool support_precision,
-    ir::Node* in_node,
-    ir::Node* op_node,
+    framework::ir::Node* in_node,
+    framework::ir::Node* op_node,
     int* suffix,
     framework::BlockDesc* block_desc,
-    framework::proto::VarType::Type to_type,
-    int block_idx) {
-  auto* real_node = GetRealNode(block_idx, in_node);
-  if (!NodeVarHasDtype(real_node)) return;
-  auto graph = graphes_[block_idx];
+    VarType::Type to_type,
+    BlockID block_idx) {
+  if (!in_node->IsVar()) return;
+  auto* real_node = GetRealVarNode(block_idx, in_node);
+  if (!VarNodeHasDtype(real_node)) return;
+  auto* graph = graphes_[block_idx];
   bool is_main_block = block_idx == 0;
   auto* in_var = real_node->Var();
   auto in_var_type = in_var->GetDataType();
   auto prev_type = in_var_type;
-  bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
+  bool is_in_multi_block =
+      vars_in_multi_block_with_pair_.count(in_var->Name());
   if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
+    in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
   }
   if (support_precision) {
-    if (in_var->Persistable() &&
-        in_var_type == framework::proto::VarType::FP32) {
+    if (in_var->Persistable() && in_var_type == VarType::FP32) {
       if (WeightsShouldNotConvert(in_node)) return;
       in_var->SetDataType(to_type);
       in_var_type = to_type;
@@ -300,14 +290,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
 }

 void ConvertToMixedPrecisionPass::ProcessOutputNode(
-    int block_idx,
-    ir::Node* var_node,
-    framework::proto::VarType::Type to_type) {
-  auto* real_node = GetRealNode(block_idx, var_node);
-  if (!NodeVarHasDtype(real_node)) return;
+    BlockID block_idx,
+    framework::ir::Node* var_node,
+    VarType::Type to_type) {
+  if (!var_node->IsVar()) return;
+  auto* real_node = GetRealVarNode(block_idx, var_node);
+  if (!VarNodeHasDtype(real_node)) return;
   auto* out_var = real_node->Var();
   auto prev_type = out_var->GetDataType();
-  if (out_var->GetDataType() == framework::proto::VarType::FP32) {
+  if (out_var->GetDataType() == VarType::FP32) {
     if (OutShouldNotConvert(var_node)) return;
     out_var->SetDataType(to_type);
   }
@@ -316,7 +305,8 @@ void ConvertToMixedPrecisionPass::ProcessOutputNode(
 }

 // Just process special cases.
-bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::OutShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_node = var_node->inputs[0];
   auto* op_desc = op_node->Op();
@@ -343,7 +333,8 @@ bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
   return false;
 }

-bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_nodes = var_node->outputs;
   for (auto* op_node : op_nodes) {
     auto* op_desc = op_node->Op();
@@ -391,13 +382,10 @@ bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
   return false;
 }

-inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
-    framework::proto::VarType::Type type) {
-  if (type == framework::proto::VarType::FP16 ||
-      type == framework::proto::VarType::FP32 ||
-      type == framework::proto::VarType::BF16)
-    return true;
-  return false;
+inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) {
+  return (type == VarType::FP16) || (type == VarType::FP32) ||
+         (type == VarType::BF16);
 }

 void ConvertToMixedPrecisionPass::LoadAndPrepare() {
@@ -405,6 +393,10 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   inference::Load(&executor_, &scope_, model_file_, params_file_);
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));
+
+  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
+    auto* graph = main_graph_->GetSubGraph(i);
+    graphes_.push_back(graph);
+  }

   // Remove all control var
   IrInferCleanGraphPass pass;
@@ -412,41 +404,45 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);

-  vars_appear_multi_in_one_block_.resize(program_desc_->Size());
   FindVarsInMultiBlock();
 }

 void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::vector<std::set<std::string>> block_var_names_set(
-      program_desc_->Size());
-  for (size_t i = 0; i < program_desc_->Size(); ++i) {
-    for (auto op : program_desc_->Block(i).AllOps()) {
-      auto in_names = op->InputArgumentNames();
-      block_var_names_set[i].insert(in_names.begin(), in_names.end());
-      auto out_names = op->OutputArgumentNames();
+  std::unordered_set<std::string> all_var_names_set;
+  std::vector<std::unordered_set<std::string>> block_var_names_set(
+      program_desc_->Size());
+  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
+    for (auto* op : program_desc_->Block(idx).AllOps()) {
+      const auto& in_names = op->InputArgumentNames();
+      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
+      const auto& out_names = op->OutputArgumentNames();
+      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
       if (op->HasAttr("sub_block") == false) {
-        for (auto& n : out_names) {
-          if (block_var_names_set[i].count(n)) {
-            vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
+        for (const auto& name : out_names) {
+          if (all_var_names_set.count(name)) {
+            vars_in_multi_block_with_ops_[name].push_back(op->Type());
           }
         }
       }
-      block_var_names_set[i].insert(out_names.begin(), out_names.end());
+      all_var_names_set.insert(block_var_names_set[idx].begin(),
+                               block_var_names_set[idx].end());
     }
   }

-  for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
-    for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
-      std::set<std::string> vars_in_multi_block;
-      std::set_intersection(
-          block_var_names_set[i].begin(), block_var_names_set[i].end(),
-          block_var_names_set[j].begin(), block_var_names_set[j].end(),
-          std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
+  CHECK_GT(program_desc_->Size(), 0U);
+  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
+    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
+      std::vector<std::string> vars_in_multi_block;
+      std::set_intersection(
+          block_var_names_set[idx].begin(), block_var_names_set[idx].end(),
+          block_var_names_set[jdx].begin(), block_var_names_set[jdx].end(),
+          std::back_inserter(vars_in_multi_block));

-      for (auto name : vars_in_multi_block) {
-        vars_in_multi_block_map_.emplace(
-            name, std::make_pair(framework::proto::VarType::FP32, i));
+      for (const auto& name : vars_in_multi_block) {
+        vars_in_multi_block_with_pair_.emplace(
+            name, std::make_pair(VarType::FP32, idx));
       }
     }
   }
@@ -462,41 +458,34 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
     if (op_type == "fill_constant") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "assign_value") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "eye") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "fill_any_like") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "cast") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "in_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "out_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
     }

     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
-      if (!in_var->Persistable() &&
-          in_var->GetDataType() == framework::proto::VarType::FP64) {
-        in_var->SetDataType(framework::proto::VarType::FP32);
+      if (!in_var->Persistable() &&
+          in_var->GetDataType() == VarType::FP64) {
+        in_var->SetDataType(VarType::FP32);
       }
     }
   }
@@ -505,9 +494,8 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
 void ConvertToMixedPrecisionPass::Run() {
   LoadAndPrepare();

-  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
-    auto graph = main_graph_->GetSubGraph(i);
-    graphes_.push_back(graph);
+  for (size_t i = 0; i < graphes_.size(); ++i) {
+    auto* graph = graphes_[i];
     VLOG(2) << " -------- handle subgraph " << i << ", has "
             << graph->Nodes().size() << " nodes --------";
@@ -518,19 +506,19 @@ void ConvertToMixedPrecisionPass::Run() {
     // A trick
     PatchForStrangeOp();

-    CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
+    CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
   }

   SaveMixedModel();
 }

-void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
-  auto graph = graphes_[block_idx];
-  framework::proto::VarType::Type to_type;
+void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
+  auto* graph = graphes_[block_idx];
+  VarType::Type to_type;
   if (mixed_precision_ == phi::DataType::FLOAT16) {
-    to_type = framework::proto::VarType::FP16;
+    to_type = VarType::FP16;
   } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
-    to_type = framework::proto::VarType::BF16;
+    to_type = VarType::BF16;
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
         "mixed_precision currently not supported dtype %d, we now only "
@@ -551,8 +539,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // 1. set input dtype.
     if (op_type == "feed") {
       auto feed_var = op_node->outputs[0]->Var();
-      if (!keep_io_types_ &&
-          feed_var->GetDataType() == framework::proto::VarType::FP32) {
+      if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) {
         feed_var->SetDataType(to_type);
       }
     } else if (op_type == "fetch") {
@@ -568,15 +555,17 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
       // same name.
       std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
       for (auto* in : op_node->inputs) {
-        auto* real_node = GetRealNode(block_idx, in);
-        if (NodeVarHasDtype(real_node)) {
+        if (!in->IsVar()) continue;
+        auto* real_node = GetRealVarNode(block_idx, in);
+        if (VarNodeHasDtype(real_node)) {
           in_name_to_node[in->Name()] = in;
         }
       }

-      for (auto out : op_node->outputs) {
-        auto* real_node = GetRealNode(block_idx, out);
-        if (NodeVarHasDtype(real_node)) {
+      for (auto* out : op_node->outputs) {
+        if (!out->IsVar()) continue;
+        auto* real_node = GetRealVarNode(block_idx, out);
+        if (VarNodeHasDtype(real_node)) {
           if (in_name_to_node.count(out->Name()))
             real_node->Var()->SetDataType(
                 in_name_to_node[out->Name()]->Var()->GetDataType());
@@ -591,32 +580,46 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // - add cast op if the input dtype is not fp16/bf16.
     // - set output dtype.
    //
-    // If a var(op's out var) appears multiple times in a block, we should not
+    // If a var(op's out var) appears multiple times in graph, we should not
     // convert to fp16.
     else if (black_list_.count(op_type) == 0 &&  // NOLINT
             !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
       bool support_precision =
           OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);

-      // if op not has float input, we will not choose the low precision kernel.
+      // If the op has no input and output of float type, we will not choose the
+      // low precision kernel.
       {
-        bool has_float_input{false};
-        for (auto in_node : op_node->inputs) {
-          auto* real_node = GetRealNode(block_idx, in_node);
-          if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP32 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP64 ||
-              real_node->Var()->GetDataType() == proto::VarType::BF16) {
-            has_float_input = true;
+        bool has_float_input_and_output{false};
+        for (auto* in_node : op_node->inputs) {
+          if (!in_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(block_idx, in_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_input_and_output = true;
             break;
           }
         }
-        if (!has_float_input) {
+        for (auto* out_node : op_node->outputs) {
+          if (!out_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(block_idx, out_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_input_and_output = true;
+            break;
+          }
+        }
+        if (!has_float_input_and_output) {
           support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
+          VLOG(2) << " op doesn't has float input and output, just skip.";
         }
       }
-      VLOG(2) << " support low precision " << support_precision;
+      VLOG(2) << "op type: " << op_type
+              << " support low precision: " << support_precision;

       if (support_precision) {
         VLOG(2) << " process input nodes:";
@@ -626,8 +629,8 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
         // Just for paddle's terriable case: op's input and output has the same
         // name.
         std::unordered_map<std::string, std::string> names_map;
-        for (auto out_node : op_node->outputs) {
-          for (auto in_node : op_node->inputs) {
+        for (auto* out_node : op_node->outputs) {
+          for (auto* in_node : op_node->inputs) {
             if (out_node->Name() == in_node->Name()) {
               names_map[out_node->Name()] = in_node->Name();
             }
@@ -655,7 +658,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
                            op_node,
                            &suffix_,
                            block_desc,
-                           framework::proto::VarType::FP32,
+                           VarType::FP32,
                            block_idx);
         }
       }
@@ -665,21 +668,19 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // - add cast op if the input dtype is not fp32.
     else {  // NOLINT
       VLOG(3) << "not to run fp16 op_type: " << op_type;
-      auto ins = op_node->inputs;
-      for (auto* in_node : ins) {
+      for (auto* in_node : op_node->inputs) {
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
                     in_node,
                     op_node,
                     to_type,
-                    framework::proto::VarType::FP32,
+                    VarType::FP32,
                     &suffix_,
                     block_desc,
                     &cast_map_);
           VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
-                  << cast_map_[in_node]->Name() << "("
-                  << framework::proto::VarType::FP32 << ")";
+                  << cast_map_[in_node]->Name() << "(" << VarType::FP32 << ")";
         }
       }
     }
@@ -688,31 +689,30 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
   // 4. if output_op's dtype is not compatible to output dtype, then just
   // insert cast.
   for (auto* node : output_nodes) {
-    ir::Node* fetch_op{nullptr};
+    framework::ir::Node* fetch_op{nullptr};
     for (auto* op_node : node->outputs) {
       if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
         fetch_op = op_node;
       }
     }
     CHECK_NOTNULL(fetch_op);
-    auto var = node->Var();
+    auto* var = node->Var();
     if (keep_io_types_ && var->GetDataType() == to_type) {
       // fp16/bf16 -> fp32.
       AddCastOp(graph,
                 node,
                 fetch_op,
                 to_type,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 &suffix_,
                 block_desc,
                 &cast_map_);
-    } else if (!keep_io_types_ &&
-               var->GetDataType() == framework::proto::VarType::FP32) {
+    } else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) {
       // fp32 -> fp16/bf16
       AddCastOp(graph,
                 node,
                 fetch_op,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 to_type,
                 &suffix_,
                 block_desc,
@@ -720,13 +720,15 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     }
   }

-  for (auto node : graph->Nodes()) {
-    auto* real_node = GetRealNode(block_idx, node);
-    if (!NodeVarHasDtype(real_node)) continue;
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsVar()) continue;
+    auto* real_node = GetRealVarNode(block_idx, node);
+    if (!VarNodeHasDtype(real_node)) continue;

-    if (vars_in_multi_block_map_.count(real_node->Name()) &&
-        vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
-      vars_in_multi_block_map_.at(real_node->Name()).first =
+    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
+        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
+            block_idx) {
+      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
           real_node->Var()->GetDataType();
     }
   }
@@ -757,17 +759,15 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   framework::ProgramDesc mixed_program_desc;
   framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);

-  paddle::CPUPlace place;
   auto parameters = scope_.LocalVarNames();
   std::sort(parameters.begin(), parameters.end());

   std::unordered_set<std::string> weights_should_be_fp32;
   for (auto* node : main_graph_->Nodes()) {
-    if (!(node->IsVar())) continue;
-    if (NodeVarHasDtype(node)) {
+    if (!node->IsVar()) continue;
+    if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
-          node->Var()->GetDataType() ==
-              paddle::framework::proto::VarType::FP32) {
+          node->Var()->GetDataType() == VarType::FP32) {
         VLOG(2) << "weights keep to fp32: " << node->Name();
         weights_should_be_fp32.insert(node->Name());
       }
@@ -777,26 +777,27 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
 #define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                   \
   mixed_tensor.set_type(DTYPE);                                              \
   auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
-  for (int i = 0; i < t->numel(); i++) {                                     \
-    mixed_data[i] = static_cast<dtype>(data[i]);                             \
+  for (int64_t i = 0; i < origin_tensor->numel(); i++) {                     \
+    mixed_data[i] = static_cast<dtype>(origin_data[i]);                      \
   }                                                                          \
-  t->clear();                                                                \
-  paddle::framework::TensorCopySync(mixed_tensor, place, t)
+  origin_tensor->clear();                                                    \
+  paddle::framework::TensorCopySync(                                         \
+      mixed_tensor, platform::CPUPlace(), origin_tensor)

   for (const auto& param_name : parameters) {
+    if (weights_should_be_fp32.count(param_name)) continue;
     auto* var = scope_.FindLocalVar(param_name);
     if (var->IsType<phi::DenseTensor>()) {
-      auto* t = var->GetMutable<phi::DenseTensor>();
-      if (t->dtype() != phi::DataType::FLOAT32) continue;
+      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
+      if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue;
       phi::DenseTensor mixed_tensor;
-      mixed_tensor.Resize(t->dims());
-      auto* data = t->mutable_data<float>(platform::CPUPlace());
-      if (mixed_precision_ == phi::DataType::FLOAT16 &&
-          !weights_should_be_fp32.count(param_name)) {
+      mixed_tensor.Resize(origin_tensor->dims());
+      auto* origin_data =
+          origin_tensor->mutable_data<float>(platform::CPUPlace());
+      if (mixed_precision_ == phi::DataType::FLOAT16) {
         CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
                              phi::dtype::float16);
-      } else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
-                 !weights_should_be_fp32.count(param_name)) {
+      } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
         CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
                              phi::dtype::bfloat16);
       }
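As a reading aid, here is a rough hand expansion of one FLOAT16 use of the CONVERT_TENSOR_DTYPE macro under the new naming (origin_tensor, origin_data). It is a conceptual sketch reconstructed from the macro above, not code from the patch.

// Sketch: element-wise fp32 -> fp16 weight conversion, then copy back over the original tensor.
phi::DenseTensor mixed_tensor;
mixed_tensor.Resize(origin_tensor->dims());
mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
auto* mixed_data =
    mixed_tensor.mutable_data<phi::dtype::float16>(platform::CPUPlace());
auto* origin_data = origin_tensor->mutable_data<float>(platform::CPUPlace());
// int64_t index: numel() can exceed the range of int for large weights,
// which is why the macro's loop counter was widened in this commit.
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
  mixed_data[i] = static_cast<phi::dtype::float16>(origin_data[i]);
}
origin_tensor->clear();
paddle::framework::TensorCopySync(
    mixed_tensor, platform::CPUPlace(), origin_tensor);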
@@ -851,8 +852,8 @@ void AddCastOp(
     framework::ir::Graph* graph,
     framework::ir::Node* node,
     framework::ir::Node* next_op,
-    framework::proto::VarType::Type from_type,
-    framework::proto::VarType::Type to_type,
+    VarType::Type from_type,
+    VarType::Type to_type,
     int* suffix,
     framework::BlockDesc* block_desc,
     std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
@@ -913,14 +914,15 @@ bool OpSupportPrecision(const std::string& op_type,
   return support_precision;
 }

 void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& params_file,
                              const std::string& mixed_model_file,
                              const std::string& mixed_params_file,
                              phi::DataType mixed_precision,
                              phi::Backend backend,
                              bool keep_io_types,
-                             std::unordered_set<std::string> black_list) {
+                             const std::unordered_set<std::string>& black_list) {
   ConvertToMixedPrecisionPass pass(model_file,
                                    params_file,
                                    mixed_model_file,
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h

@@ -51,8 +51,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& mixed_params_file,
                              phi::DataType mixed_precision,
                              phi::Backend backend,
-                             bool keep_io_types = true,
-                             std::unordered_set<std::string> black_list = {});
+                             bool keep_io_types,
+                             const std::unordered_set<std::string>& black_list);

 }  // namespace analysis
 }  // namespace inference