PaddlePaddle / Paddle
Unverified commit 0972d6ac
Authored on Oct 27, 2022 by Yuanle Liu
Committed by GitHub on Oct 27, 2022
[Paddle Inference] improve convert_to_mixed_precision (#47333)
Parent: 5429d145
Showing 2 changed files with 223 additions and 221 deletions (+223 −221):

  paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc  (+221 −219)
  paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h   (+2 −2)
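Before the diffs, a minimal caller-side sketch of the updated entry point, for context. This is not part of the commit: it assumes a Paddle source build that contains this pass, the model/parameter paths are made-up placeholders, and "softmax" is just an example black-list entry. Note that after this change `keep_io_types` and `black_list` must always be passed explicitly, because their default values are removed from the header.

// Hypothetical usage sketch (not part of this commit); paths are placeholders.
#include <string>
#include <unordered_set>

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

int main() {
  const std::string model_file = "/path/to/model.pdmodel";           // placeholder
  const std::string params_file = "/path/to/model.pdiparams";        // placeholder
  const std::string mixed_model_file = "/path/to/mixed.pdmodel";     // placeholder
  const std::string mixed_params_file = "/path/to/mixed.pdiparams";  // placeholder

  // Convert an fp32 inference model to fp16, checking kernel support on the
  // GPU backend; ops listed in black_list are kept in fp32.
  paddle::inference::analysis::ConvertToMixedPrecision(
      model_file,
      params_file,
      mixed_model_file,
      mixed_params_file,
      phi::DataType::FLOAT16,
      phi::Backend::GPU,
      /*keep_io_types=*/true,
      /*black_list=*/std::unordered_set<std::string>{"softmax"});
  return 0;
}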
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -42,13 +42,13 @@
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/tensor_meta.h"
 
-using namespace paddle::framework;  // NOLINT
-
 namespace paddle {
 namespace inference {
 namespace analysis {
 namespace {
 
+using VarType = framework::proto::VarType;
+
 bool PhiKernelSupportPrecision(
     const std::string& op_type,
     phi::Backend backend,
@@ -73,13 +73,14 @@ bool GpuKernelSupportPrecision(
       phi_op_type, phi::Backend::GPUDNN, data_type, layout);
   if (!res) {
-    auto& all_kernels = OperatorWithKernel::AllOpKernels();
+    auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
     auto it = all_kernels.find(op_type);
     if (it != all_kernels.end()) {
       for (auto& kern_pair : it->second) {
         if (platform::is_gpu_place(kern_pair.first.place_) &&
-            kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
+            kern_pair.first.data_type_ == VarType::FP16) {
           res = true;
+          break;
         }
       }
     }
   }
@@ -88,6 +89,8 @@ bool GpuKernelSupportPrecision(
 }
 
 class ConvertToMixedPrecisionPass {
+  using BlockID = size_t;
+
  public:
   explicit ConvertToMixedPrecisionPass(
       const std::string& model_file,
@@ -97,7 +100,7 @@ class ConvertToMixedPrecisionPass {
       phi::DataType mixed_precision,
       phi::Backend backend,
       bool keep_io_types,
-      std::unordered_set<std::string> black_list)
+      const std::unordered_set<std::string>& black_list)
       : model_file_(model_file),
         params_file_(params_file),
         mixed_model_file_(mixed_model_file),
@@ -107,45 +110,40 @@ class ConvertToMixedPrecisionPass {
         keep_io_types_(keep_io_types),
         black_list_(black_list),
         place_(paddle::CPUPlace()),
-        executor_(place_) {
-    black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
-  }
+        executor_(place_) {}
+
   void Run();
 
  private:
   void LoadAndPrepare();
-  inline bool NodeVarHasDtype(framework::ir::Node* node);
+  inline bool VarNodeHasDtype(framework::ir::Node* node);
   void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
   void FixCastAttr(framework::ir::Graph* graph);
   void SaveMixedModel();
-  void ConvertTensorDtype(int block_idx);
+  void ConvertTensorDtype(BlockID block_idx);
   void ProcessInputNode(bool support_precision,
-                        ir::Node* in_node,
-                        ir::Node* op_node,
+                        framework::ir::Node* in_node,
+                        framework::ir::Node* op_node,
                         int* suffix,
                         framework::BlockDesc* block_desc,
-                        framework::proto::VarType::Type to_type,
-                        int block_idx);
-  void ProcessOutputNode(int block_idx,
-                         ir::Node* var_node,
-                         framework::proto::VarType::Type to_type);
-  inline bool IsFloatVarType(framework::proto::VarType::Type type);
-  bool OutShouldNotConvert(ir::Node* var_node);
+                        VarType::Type to_type,
+                        BlockID block_idx);
+  void ProcessOutputNode(BlockID block_idx,
+                         framework::ir::Node* var_node,
+                         VarType::Type to_type);
+  inline bool IsFloatVarType(VarType::Type type);
+  bool OutShouldNotConvert(framework::ir::Node* var_node);
   // Just process special cases for weights conversion.
-  bool WeightsShouldNotConvert(ir::Node* var_node);
+  bool WeightsShouldNotConvert(framework::ir::Node* var_node);
   // To support multi block, we need to consider a lot of special cases.
   // Return Node* which first appers in block.
-  framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
+  framework::ir::Node* GetRealVarNode(BlockID block_idx,
+                                      framework::ir::Node* node);
   void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(int block_idx,
+  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
                                         framework::ir::Node* op_node);
 
  private:
@@ -167,11 +165,10 @@ class ConvertToMixedPrecisionPass {
   framework::Scope scope_;
 
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string,
-                     std::pair<framework::proto::VarType::Type, int>>
-      vars_in_multi_block_map_;
-  std::vector<std::unordered_map<std::string, std::vector<std::string>>>
-      vars_appear_multi_in_one_block_;
+  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
+      vars_in_multi_block_with_pair_;
+  std::unordered_map<std::string, std::vector<std::string>>
+      vars_in_multi_block_with_ops_;
   int suffix_{0};
 
   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
@@ -179,91 +176,84 @@ class ConvertToMixedPrecisionPass {
   std::vector<framework::ir::Graph*> graphes_;
 };
 
-framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
-    int block_idx, framework::ir::Node* node) {
-  if (vars_in_multi_block_map_.count(node->Name())) {
-    int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
-    if (block_idx != var_origin_block_id) {
-      auto graph = graphes_[var_origin_block_id];
-      for (auto nd : graph->Nodes()) {
-        if (nd->Name() == node->Name()) {
-          return nd;
+framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
+    BlockID block_idx, framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+
+  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
+    auto origin_blockId =
+        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
+    if (block_idx != origin_blockId) {
+      auto* graph = graphes_[origin_blockId];
+      for (auto* node : graph->Nodes()) {
+        if (node->Name() == var_node->Name()) {
+          return node;
         }
       }
     }
   }
 
-  return node;
+  return var_node;
 }
 
-inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
-    framework::ir::Node* node) {
-  if (node->IsVar() &&
-      (node->Var()->GetType() ==
-           paddle::framework::proto::VarType::SELECTED_ROWS ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
-    return true;
-  }
-
-  return false;
+inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
+    framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+  auto type = var_node->Var()->GetType();
+  return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
+         (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
+         (type == VarType::VOCAB);
 }
 
 // op1(fp32) -> var1, op2(fp16) -> var1
 // if and only if op1 and op2 both support fp16, we convert op1 and op2's
 // precision.
 inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    int block_idx, framework::ir::Node* op_node) {
+    BlockID block_idx, framework::ir::Node* op_node) {
   CHECK_EQ(op_node->IsOp(), true);
-  bool ret{false};
 
-  for (auto* out : op_node->outputs) {
-    auto* real_node = GetRealNode(block_idx, out);
-    if (!real_node->Var()->Persistable() &&
-        vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
-      for (auto op_type :
-           vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
-        if (OpSupportPrecision(
+  for (auto* var_node : op_node->outputs) {
+    if (!var_node->IsVar()) continue;
+    auto* real_var_node = GetRealVarNode(block_idx, var_node);
+    if (!real_var_node->Var()->Persistable() &&
+        vars_in_multi_block_with_ops_.count(var_node->Name())) {
+      for (const auto& op_type :
+           vars_in_multi_block_with_ops_.at(var_node->Name())) {
+        if (!OpSupportPrecision(
                 op_type, backend_, mixed_precision_, black_list_)) {
-          ret = true;
-          VLOG(2) << out->Name()
+          VLOG(2) << var_node->Name()
                   << " is multi precision op's out, so we skip convert to fp16";
-          break;
+          return true;
         }
       }
     }
-    if (ret) break;
   }
-  return ret;
+  return false;
 }
 
 void ConvertToMixedPrecisionPass::ProcessInputNode(
     bool support_precision,
-    ir::Node* in_node,
-    ir::Node* op_node,
+    framework::ir::Node* in_node,
+    framework::ir::Node* op_node,
     int* suffix,
     framework::BlockDesc* block_desc,
-    framework::proto::VarType::Type to_type,
-    int block_idx) {
-  auto* real_node = GetRealNode(block_idx, in_node);
-  if (!NodeVarHasDtype(real_node)) return;
-  auto graph = graphes_[block_idx];
+    VarType::Type to_type,
+    BlockID block_idx) {
+  if (!in_node->IsVar()) return;
+  auto* real_node = GetRealVarNode(block_idx, in_node);
+  if (!VarNodeHasDtype(real_node)) return;
+  auto* graph = graphes_[block_idx];
   bool is_main_block = block_idx == 0;
   auto* in_var = real_node->Var();
   auto in_var_type = in_var->GetDataType();
   auto prev_type = in_var_type;
-  bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
+  bool is_in_multi_block =
+      vars_in_multi_block_with_pair_.count(in_var->Name());
   if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
+    in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
   }
   if (support_precision) {
-    if (in_var->Persistable() &&
-        in_var_type == framework::proto::VarType::FP32) {
+    if (in_var->Persistable() && in_var_type == VarType::FP32) {
       if (WeightsShouldNotConvert(in_node)) return;
       in_var->SetDataType(to_type);
       in_var_type = to_type;
@@ -300,14 +290,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
 }
 
 void ConvertToMixedPrecisionPass::ProcessOutputNode(
-    int block_idx,
-    ir::Node* var_node,
-    framework::proto::VarType::Type to_type) {
-  auto* real_node = GetRealNode(block_idx, var_node);
-  if (!NodeVarHasDtype(real_node)) return;
+    BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
+  if (!var_node->IsVar()) return;
+  auto* real_node = GetRealVarNode(block_idx, var_node);
+  if (!VarNodeHasDtype(real_node)) return;
   auto* out_var = real_node->Var();
   auto prev_type = out_var->GetDataType();
-  if (out_var->GetDataType() == framework::proto::VarType::FP32) {
+  if (out_var->GetDataType() == VarType::FP32) {
     if (OutShouldNotConvert(var_node)) return;
     out_var->SetDataType(to_type);
   }
@@ -316,7 +305,8 @@ void ConvertToMixedPrecisionPass::ProcessOutputNode(
 }
 
 // Just process special cases.
-bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::OutShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_node = var_node->inputs[0];
   auto* op_desc = op_node->Op();
@@ -343,7 +333,8 @@ bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
   return false;
 }
 
-bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_nodes = var_node->outputs;
   for (auto* op_node : op_nodes) {
     auto* op_desc = op_node->Op();
@@ -391,13 +382,10 @@ bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
   return false;
 }
 
-inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
-    framework::proto::VarType::Type type) {
-  if (type == framework::proto::VarType::FP16 ||
-      type == framework::proto::VarType::FP32 ||
-      type == framework::proto::VarType::BF16)
-    return true;
-  return false;
+inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) {
+  return (type == VarType::FP16) || (type == VarType::FP32) ||
+         (type == VarType::BF16);
 }
 
 void ConvertToMixedPrecisionPass::LoadAndPrepare() {
@@ -405,6 +393,10 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
       inference::Load(&executor_, &scope_, model_file_, params_file_);
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));
+  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
+    auto* graph = main_graph_->GetSubGraph(i);
+    graphes_.push_back(graph);
+  }
 
   // Remove all control var
   IrInferCleanGraphPass pass;
@@ -412,41 +404,45 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);
 
-  vars_appear_multi_in_one_block_.resize(program_desc_->Size());
   FindVarsInMultiBlock();
 }
 
 void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
-  for (size_t i = 0; i < program_desc_->Size(); ++i) {
-    for (auto op : program_desc_->Block(i).AllOps()) {
-      auto in_names = op->InputArgumentNames();
-      block_var_names_set[i].insert(in_names.begin(), in_names.end());
-      auto out_names = op->OutputArgumentNames();
-      if (op->HasAttr("sub_block") == false) {
-        for (auto& n : out_names) {
-          if (block_var_names_set[i].count(n)) {
-            vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
-          }
-        }
-      }
-      block_var_names_set[i].insert(out_names.begin(), out_names.end());
-    }
-  }
-
-  for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
-    for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
-      std::set<std::string> vars_in_multi_block;
-      std::set_intersection(block_var_names_set[i].begin(),
-                            block_var_names_set[i].end(),
-                            block_var_names_set[j].begin(),
-                            block_var_names_set[j].end(),
-                            std::inserter(vars_in_multi_block,
-                                          vars_in_multi_block.begin()));
-
-      for (auto name : vars_in_multi_block) {
-        vars_in_multi_block_map_.emplace(
-            name, std::make_pair(framework::proto::VarType::FP32, i));
+  std::unordered_set<std::string> all_var_names_set;
+  std::vector<std::unordered_set<std::string>> block_var_names_set(
+      program_desc_->Size());
+  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
+    for (auto* op : program_desc_->Block(idx).AllOps()) {
+      const auto& in_names = op->InputArgumentNames();
+      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
+      const auto& out_names = op->OutputArgumentNames();
+      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
+      if (op->HasAttr("sub_block") == false) {
+        for (const auto& name : out_names) {
+          if (all_var_names_set.count(name)) {
+            vars_in_multi_block_with_ops_[name].push_back(op->Type());
+          }
+        }
+      }
+      all_var_names_set.insert(block_var_names_set[idx].begin(),
+                               block_var_names_set[idx].end());
+    }
+  }
+
+  CHECK_GT(program_desc_->Size(), 0U);
+  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
+    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
+      std::vector<std::string> vars_in_multi_block;
+      std::set_intersection(block_var_names_set[idx].begin(),
+                            block_var_names_set[idx].end(),
+                            block_var_names_set[jdx].begin(),
+                            block_var_names_set[jdx].end(),
+                            std::back_inserter(vars_in_multi_block));
+
+      for (const auto& name : vars_in_multi_block) {
+        vars_in_multi_block_with_pair_.emplace(
+            name, std::make_pair(VarType::FP32, idx));
       }
     }
   }
@@ -462,41 +458,34 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
     if (op_type == "fill_constant") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "assign_value") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "eye") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "fill_any_like") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "cast") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "in_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "out_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
     }
 
     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
       if (!in_var->Persistable() &&
-          in_var->GetDataType() == framework::proto::VarType::FP64) {
-        in_var->SetDataType(framework::proto::VarType::FP32);
+          in_var->GetDataType() == VarType::FP64) {
+        in_var->SetDataType(VarType::FP32);
       }
     }
   }
@@ -505,9 +494,8 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
 void ConvertToMixedPrecisionPass::Run() {
   LoadAndPrepare();
 
-  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
-    auto graph = main_graph_->GetSubGraph(i);
-    graphes_.push_back(graph);
+  for (size_t i = 0; i < graphes_.size(); ++i) {
+    auto* graph = graphes_[i];
     VLOG(2) << " -------- handle subgraph " << i << ", has "
             << graph->Nodes().size() << " nodes --------";
@@ -518,19 +506,19 @@ void ConvertToMixedPrecisionPass::Run() {
     // A trick
     PatchForStrangeOp();
-    CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
+    CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
   }
 
   SaveMixedModel();
 }
 
-void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
-  auto graph = graphes_[block_idx];
-  framework::proto::VarType::Type to_type;
+void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
+  auto* graph = graphes_[block_idx];
+  VarType::Type to_type;
   if (mixed_precision_ == phi::DataType::FLOAT16) {
-    to_type = framework::proto::VarType::FP16;
+    to_type = VarType::FP16;
   } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
-    to_type = framework::proto::VarType::BF16;
+    to_type = VarType::BF16;
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
         "mixed_precision currently not supported dtype %d, we now only "
@@ -551,8 +539,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // 1. set input dtype.
     if (op_type == "feed") {
       auto feed_var = op_node->outputs[0]->Var();
-      if (!keep_io_types_ &&
-          feed_var->GetDataType() == framework::proto::VarType::FP32) {
+      if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) {
         feed_var->SetDataType(to_type);
       }
     } else if (op_type == "fetch") {
@@ -568,15 +555,17 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
       // same name.
       std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
       for (auto* in : op_node->inputs) {
-        auto* real_node = GetRealNode(block_idx, in);
-        if (NodeVarHasDtype(real_node)) {
+        if (!in->IsVar()) continue;
+        auto* real_node = GetRealVarNode(block_idx, in);
+        if (VarNodeHasDtype(real_node)) {
           in_name_to_node[in->Name()] = in;
         }
       }
 
-      for (auto out : op_node->outputs) {
-        auto* real_node = GetRealNode(block_idx, out);
-        if (NodeVarHasDtype(real_node)) {
+      for (auto* out : op_node->outputs) {
+        if (!out->IsVar()) continue;
+        auto* real_node = GetRealVarNode(block_idx, out);
+        if (VarNodeHasDtype(real_node)) {
           if (in_name_to_node.count(out->Name()))
             real_node->Var()->SetDataType(
                 in_name_to_node[out->Name()]->Var()->GetDataType());
@@ -591,32 +580,46 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // - add cast op if the input dtype is not fp16/bf16.
     // - set output dtype.
     //
-    // If a var(op's out var) appears multiple times in a block, we should not
+    // If a var(op's out var) appears multiple times in graph, we should not
    // convert to fp16.
     else if (black_list_.count(op_type) == 0 &&  // NOLINT
              !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
       bool support_precision =
           OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
 
-      // if op not has float input, we will not choose the low precision kernel.
+      // If the op has no input and output of float type, we will not choose the
+      // low precision kernel.
       {
-        bool has_float_input{false};
-        for (auto in_node : op_node->inputs) {
-          auto* real_node = GetRealNode(block_idx, in_node);
-          if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP32 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP64 ||
-              real_node->Var()->GetDataType() == proto::VarType::BF16) {
-            has_float_input = true;
+        bool has_float_input_and_output{false};
+        for (auto* in_node : op_node->inputs) {
+          if (!in_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(block_idx, in_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_input_and_output = true;
             break;
           }
         }
-        if (!has_float_input) {
+        for (auto* out_node : op_node->outputs) {
+          if (!out_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(block_idx, out_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_input_and_output = true;
+            break;
+          }
+        }
+        if (!has_float_input_and_output) {
           support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
+          VLOG(2) << " op doesn't has float input and output, just skip.";
         }
       }
-      VLOG(2) << " support low precision " << support_precision;
+      VLOG(2) << "op type: " << op_type
+              << " support low precision: " << support_precision;
 
       if (support_precision) {
         VLOG(2) << " process input nodes:";
@@ -626,8 +629,8 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
         // Just for paddle's terriable case: op's input and output has the same
         // name.
         std::unordered_map<std::string, std::string> names_map;
-        for (auto out_node : op_node->outputs) {
-          for (auto in_node : op_node->inputs) {
+        for (auto* out_node : op_node->outputs) {
+          for (auto* in_node : op_node->inputs) {
             if (out_node->Name() == in_node->Name()) {
               names_map[out_node->Name()] = in_node->Name();
             }
@@ -655,7 +658,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
                          op_node,
                          &suffix_,
                          block_desc,
-                         framework::proto::VarType::FP32,
+                         VarType::FP32,
                          block_idx);
       }
     }
@@ -665,21 +668,19 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // - add cast op if the input dtype is not fp32.
     else {  // NOLINT
       VLOG(3) << "not to run fp16 op_type: " << op_type;
-      auto ins = op_node->inputs;
-      for (auto* in_node : ins) {
+      for (auto* in_node : op_node->inputs) {
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
                     in_node,
                     op_node,
                     to_type,
-                    framework::proto::VarType::FP32,
+                    VarType::FP32,
                     &suffix_,
                     block_desc,
                     &cast_map_);
           VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
                   << cast_map_[in_node]->Name() << "("
-                  << framework::proto::VarType::FP32 << ")";
+                  << VarType::FP32 << ")";
         }
       }
     }
@@ -688,31 +689,30 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
   // 4. if output_op's dtype is not compatible to output dtype, then just
   // insert cast.
   for (auto* node : output_nodes) {
-    ir::Node* fetch_op{nullptr};
+    framework::ir::Node* fetch_op{nullptr};
     for (auto* op_node : node->outputs) {
       if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
         fetch_op = op_node;
       }
     }
     CHECK_NOTNULL(fetch_op);
-    auto var = node->Var();
+    auto* var = node->Var();
     if (keep_io_types_ && var->GetDataType() == to_type) {
       // fp16/bf16 -> fp32.
       AddCastOp(graph,
                 node,
                 fetch_op,
                 to_type,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 &suffix_,
                 block_desc,
                 &cast_map_);
-    } else if (!keep_io_types_ &&
-               var->GetDataType() == framework::proto::VarType::FP32) {
+    } else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) {
       // fp32 -> fp16/bf16
       AddCastOp(graph,
                 node,
                 fetch_op,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 to_type,
                 &suffix_,
                 block_desc,
@@ -720,13 +720,15 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     }
   }
 
-  for (auto node : graph->Nodes()) {
-    auto* real_node = GetRealNode(block_idx, node);
-    if (!NodeVarHasDtype(real_node)) continue;
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsVar()) continue;
+    auto* real_node = GetRealVarNode(block_idx, node);
+    if (!VarNodeHasDtype(real_node)) continue;
 
-    if (vars_in_multi_block_map_.count(real_node->Name()) &&
-        vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
-      vars_in_multi_block_map_.at(real_node->Name()).first =
+    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
+        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
+            block_idx) {
+      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
           real_node->Var()->GetDataType();
     }
   }
@@ -757,17 +759,15 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   framework::ProgramDesc mixed_program_desc;
   framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
 
-  paddle::CPUPlace place;
   auto parameters = scope_.LocalVarNames();
   std::sort(parameters.begin(), parameters.end());
 
   std::unordered_set<std::string> weights_should_be_fp32;
   for (auto* node : main_graph_->Nodes()) {
-    if (!(node->IsVar())) continue;
-    if (NodeVarHasDtype(node)) {
+    if (!node->IsVar()) continue;
+    if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
           node->Var()->GetDataType() ==
-              paddle::framework::proto::VarType::FP32) {
+              VarType::FP32) {
         VLOG(2) << "weights keep to fp32: " << node->Name();
         weights_should_be_fp32.insert(node->Name());
       }
@@ -777,26 +777,27 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
 #define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                   \
   mixed_tensor.set_type(DTYPE);                                              \
   auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
-  for (int i = 0; i < t->numel(); i++) {                                     \
-    mixed_data[i] = static_cast<dtype>(data[i]);                             \
+  for (int64_t i = 0; i < origin_tensor->numel(); i++) {                     \
+    mixed_data[i] = static_cast<dtype>(origin_data[i]);                      \
   }                                                                          \
-  t->clear();                                                                \
-  paddle::framework::TensorCopySync(mixed_tensor, place, t)
+  origin_tensor->clear();                                                    \
+  paddle::framework::TensorCopySync(                                         \
+      mixed_tensor, platform::CPUPlace(), origin_tensor)
 
   for (const auto& param_name : parameters) {
+    if (weights_should_be_fp32.count(param_name)) continue;
     auto* var = scope_.FindLocalVar(param_name);
     if (var->IsType<phi::DenseTensor>()) {
-      auto* t = var->GetMutable<phi::DenseTensor>();
-      if (t->dtype() != phi::DataType::FLOAT32) continue;
+      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
+      if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue;
       phi::DenseTensor mixed_tensor;
-      mixed_tensor.Resize(t->dims());
-      auto* data = t->mutable_data<float>(platform::CPUPlace());
-      if (mixed_precision_ == phi::DataType::FLOAT16 &&
-          !weights_should_be_fp32.count(param_name)) {
+      mixed_tensor.Resize(origin_tensor->dims());
+      auto* origin_data =
+          origin_tensor->mutable_data<float>(platform::CPUPlace());
+      if (mixed_precision_ == phi::DataType::FLOAT16) {
        CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
                              phi::dtype::float16);
-      } else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
-                 !weights_should_be_fp32.count(param_name)) {
+      } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
         CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
                              phi::dtype::bfloat16);
       }
@@ -851,8 +852,8 @@ void AddCastOp(
     framework::ir::Graph* graph,
     framework::ir::Node* node,
     framework::ir::Node* next_op,
-    framework::proto::VarType::Type from_type,
-    framework::proto::VarType::Type to_type,
+    VarType::Type from_type,
+    VarType::Type to_type,
     int* suffix,
     framework::BlockDesc* block_desc,
     std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
@@ -913,14 +914,15 @@ bool OpSupportPrecision(const std::string& op_type,
   return support_precision;
 }
 
-void ConvertToMixedPrecision(const std::string& model_file,
-                             const std::string& params_file,
-                             const std::string& mixed_model_file,
-                             const std::string& mixed_params_file,
-                             phi::DataType mixed_precision,
-                             phi::Backend backend,
-                             bool keep_io_types,
-                             std::unordered_set<std::string> black_list) {
+void ConvertToMixedPrecision(
+    const std::string& model_file,
+    const std::string& params_file,
+    const std::string& mixed_model_file,
+    const std::string& mixed_params_file,
+    phi::DataType mixed_precision,
+    phi::Backend backend,
+    bool keep_io_types,
+    const std::unordered_set<std::string>& black_list) {
   ConvertToMixedPrecisionPass pass(model_file,
                                    params_file,
                                    mixed_model_file,
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h

@@ -51,8 +51,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& mixed_params_file,
                              phi::DataType mixed_precision,
                              phi::Backend backend,
-                             bool keep_io_types = true,
-                             std::unordered_set<std::string> black_list = {});
+                             bool keep_io_types,
+                             const std::unordered_set<std::string>& black_list);
 
 }  // namespace analysis
 }  // namespace inference
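One practical consequence of the header change above: call sites that relied on the old defaulted trailing arguments (`keep_io_types = true`, `black_list = {}`) must now spell both out. A hypothetical before/after sketch (the concrete call sites are not part of this diff):

// Before: relied on the defaults removed by this commit.
//   ConvertToMixedPrecision(model, params, mixed_model, mixed_params,
//                           phi::DataType::FLOAT16, phi::Backend::GPU);
//
// After: both trailing arguments are passed explicitly.
//   ConvertToMixedPrecision(model, params, mixed_model, mixed_params,
//                           phi::DataType::FLOAT16, phi::Backend::GPU,
//                           /*keep_io_types=*/true, /*black_list=*/{});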