Commit a00aebe1 (unverified)
Authored Nov 15, 2022 by Wilber; committed by GitHub on Nov 15, 2022

[convert_to_mixed_precision] fallback to fp32 when encounter circle (#47902)

Parent: d4d3d7ed

Showing 1 changed file with 127 additions and 216 deletions:
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc (+127, -216)
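Before the diff itself, a note on the term "circle" in the commit title: a variable that appears among both the inputs and the outputs of the ops in a block means the graph is not a DAG, and the new ProcessCircleCases() in this patch keeps every op touching such a variable in fp32. Below is a minimal, self-contained sketch of that detection rule; it is plain standard C++ rather than Paddle's framework::ir API, and FakeOp is a hypothetical stand-in for an op description.

// Sketch only: mirrors the set_intersection idea of ProcessCircleCases(),
// not Paddle code. FakeOp is a hypothetical stand-in for framework::OpDesc.
#include <algorithm>
#include <iostream>
#include <iterator>
#include <set>
#include <string>
#include <vector>

struct FakeOp {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

int main() {
  // The second op both reads and writes "x", so "x" is "in a circle".
  std::vector<FakeOp> ops = {{{"x", "w"}, {"y"}}, {{"y", "x"}, {"x"}}};

  std::set<std::string> names_in_circles;
  for (const auto& op : ops) {
    std::set<std::string> in(op.inputs.begin(), op.inputs.end());
    std::set<std::string> out(op.outputs.begin(), op.outputs.end());
    // A name appearing on both sides of one op breaks the DAG property.
    std::set_intersection(
        in.begin(), in.end(), out.begin(), out.end(),
        std::inserter(names_in_circles, names_in_circles.end()));
  }
  for (const auto& name : names_in_circles)
    std::cout << name << " is in a circle, keep fp32\n";  // prints: x
  return 0;
}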
@@ -40,7 +40,6 @@
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/common/place.h"
-#include "paddle/phi/core/tensor_meta.h"
 
 namespace paddle {
 namespace inference {
@@ -111,12 +110,10 @@ class ConvertToMixedPrecisionPass {
         black_list_(black_list),
         place_(paddle::CPUPlace()),
         executor_(place_) {
-    // black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
+    VLOG(4) << "black_list has ";
+    for (auto& name : black_list_) {
+      VLOG(4) << " - " << name;
+    }
   }
 
   void Run();
@@ -145,18 +142,11 @@ class ConvertToMixedPrecisionPass {
   // Just process special cases for weights conversion.
   bool WeightsShouldNotConvert(framework::ir::Node* var_node);
 
-  // To support multi block, we need to consider a lot of special cases.
-  // Return Node* which first appers in block.
-  framework::ir::Node* GetRealVarNode(BlockID block_idx,
-                                      framework::ir::Node* node);
-  void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
-                                        framework::ir::Node* op_node);
+  framework::ir::Node* GetRealVarNode(framework::ir::Node* node);
 
  private:
-  // A trick. Patch for strange op, which input name equal to output name, such
-  // as `fused_multi_transformer`
-  void PatchForStrangeOp();
+  // Fallback to fp32 dtype when encounter circle (Not a DAG graph).
+  void ProcessCircleCases();
 
  private:
   std::string model_file_;
@@ -171,35 +161,21 @@ class ConvertToMixedPrecisionPass {
   framework::Executor executor_;
   framework::Scope scope_;
 
   std::unordered_map<std::string, framework::ir::Node*> name2node_;
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
-      vars_in_multi_block_with_pair_;
-  std::unordered_map<std::string, std::vector<std::string>>
-      vars_in_multi_block_with_ops_;
   int suffix_{0};
+  std::set<std::string> var_names_in_circles_;
 
   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
   std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
   std::vector<framework::ir::Graph*> graphes_;
 };
 
 framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
-    BlockID block_idx, framework::ir::Node* var_node) {
+    framework::ir::Node* var_node) {
   CHECK_EQ(var_node->IsVar(), true);
-  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
-    auto origin_blockId =
-        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
-    if (block_idx != origin_blockId) {
-      auto* graph = graphes_[origin_blockId];
-      for (auto* node : graph->Nodes()) {
-        if (node->Name() == var_node->Name()) {
-          return node;
-        }
-      }
-    }
-  }
   if (name2node_.count(var_node->Name())) return name2node_[var_node->Name()];
   return var_node;
 }
@@ -212,32 +188,6 @@ inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
          (type == VarType::VOCAB);
 }
 
-// op1(fp32) -> var1, op2(fp16) -> var1
-// if and only if op1 and op2 both support fp16, we convert op1 and op2's
-// precision.
-inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    BlockID block_idx, framework::ir::Node* op_node) {
-  CHECK_EQ(op_node->IsOp(), true);
-  for (auto* var_node : op_node->outputs) {
-    if (!var_node->IsVar()) continue;
-    auto* real_var_node = GetRealVarNode(block_idx, var_node);
-    if (!real_var_node->Var()->Persistable() &&
-        vars_in_multi_block_with_ops_.count(var_node->Name())) {
-      for (const auto& op_type :
-           vars_in_multi_block_with_ops_.at(var_node->Name())) {
-        if (!OpSupportPrecision(
-                op_type, backend_, mixed_precision_, black_list_)) {
-          VLOG(2) << var_node->Name()
-                  << " is multi precision op's out, so we skip convert to fp16";
-          return true;
-        }
-      }
-    }
-  }
-  return false;
-}
-
 void ConvertToMixedPrecisionPass::ProcessInputNode(
     bool support_precision,
     framework::ir::Node* in_node,
@@ -247,18 +197,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
     VarType::Type to_type,
     BlockID block_idx) {
   if (!in_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, in_node);
+  auto* real_node = GetRealVarNode(in_node);
   if (!VarNodeHasDtype(real_node)) return;
 
   auto* graph = graphes_[block_idx];
-  bool is_main_block = block_idx == 0;
   auto* in_var = real_node->Var();
   auto in_var_type = in_var->GetDataType();
   auto prev_type = in_var_type;
-  bool is_in_multi_block = vars_in_multi_block_with_pair_.count(in_var->Name());
-  if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
-  }
 
   if (support_precision) {
     if (in_var->Persistable() && in_var_type == VarType::FP32) {
       if (WeightsShouldNotConvert(in_node)) return;
@@ -299,7 +244,7 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
 void ConvertToMixedPrecisionPass::ProcessOutputNode(
     BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
   if (!var_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, var_node);
+  auto* real_node = GetRealVarNode(var_node);
   if (!VarNodeHasDtype(real_node)) return;
   auto* out_var = real_node->Var();
   auto prev_type = out_var->GetDataType();
@@ -400,9 +345,17 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   inference::Load(&executor_, &scope_, model_file_, params_file_);
 
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));
   for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
     auto* graph = main_graph_->GetSubGraph(i);
     graphes_.push_back(graph);
+    for (auto* node : graph->Nodes()) {
+      if (!node->IsVar()) continue;
+      if (!name2node_.count(node->Name())) {
+        name2node_[node->Name()] = node;
+      }
+    }
   }
 
   // Remove all control var
@@ -411,82 +364,78 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);
 
-  FindVarsInMultiBlock();
+  ProcessCircleCases();
 }
 
-void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::unordered_set<std::string> all_var_names_set;
-  std::vector<std::set<std::string>> block_var_names_set(
-      program_desc_->Size());
-  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
+// Find var names which in circles.
+void ConvertToMixedPrecisionPass::ProcessCircleCases() {
+  std::vector<std::string> vars_in_circles;
+  for (size_t idx = 0; idx < program_desc_->Size(); ++idx) {
     for (auto* op : program_desc_->Block(idx).AllOps()) {
+      // TODO(inference): batch_norm has circle, but we need to fuse it in conv
+      // op.
+      if (op->Type() == "batch_norm") continue;
       const auto& in_names = op->InputArgumentNames();
-      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
       const auto& out_names = op->OutputArgumentNames();
-      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
-      if (op->HasAttr("sub_block") == false) {
-        for (const auto& name : out_names) {
-          if (all_var_names_set.count(name)) {
-            vars_in_multi_block_with_ops_[name].push_back(op->Type());
-          }
-        }
-      }
-      all_var_names_set.insert(block_var_names_set[idx].begin(),
-                               block_var_names_set[idx].end());
+      std::set<std::string> in_names_set(in_names.begin(), in_names.end());
+      std::set<std::string> out_names_set(out_names.begin(), out_names.end());
+      std::set_intersection(in_names_set.begin(),
+                            in_names_set.end(),
+                            out_names_set.begin(),
+                            out_names_set.end(),
+                            std::back_inserter(vars_in_circles));
     }
   }
-  CHECK_GT(program_desc_->Size(), 0U);
-  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
-    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
-      std::vector<std::string> vars_in_multi_block;
-      std::set_intersection(block_var_names_set[idx].begin(),
-                            block_var_names_set[idx].end(),
-                            block_var_names_set[jdx].begin(),
-                            block_var_names_set[jdx].end(),
-                            std::back_inserter(vars_in_multi_block));
-      for (const auto& name : vars_in_multi_block) {
-        vars_in_multi_block_with_pair_.emplace(
-            name, std::make_pair(VarType::Type(), idx));
-      }
-    }
-  }
-}
+  for (auto& name : vars_in_circles) {
+    var_names_in_circles_.insert(name);
+  }
+  for (auto& name : var_names_in_circles_) {
+    LOG(INFO) << name
+              << " in circles, so we will skip process those vars and ops.";
+  }
+}
 
-void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
-    framework::ir::Graph* graph) {
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
-  for (auto* op_node : op_nodes) {
-    if (!op_node->IsOp()) continue;
+inline void ProcessConstantOpAttr(framework::ir::Node* op_node,
+                                  VarType::Type from_type,
+                                  VarType::Type to_type) {
+  if (!op_node->IsOp()) return;
   auto op_type = op_node->Op()->Type();
-  if (op_type == "feed" || op_type == "fetch") continue;
+  if (op_type == "feed" || op_type == "fetch") return;
 
   if (op_type == "fill_constant") {
     if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-        static_cast<int>(VarType::FP64))
-      op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
   } else if (op_type == "assign_value") {
     if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-        static_cast<int>(VarType::FP64))
-      op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
   } else if (op_type == "eye") {
     if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-        static_cast<int>(VarType::FP64))
-      op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
  } else if (op_type == "fill_any_like") {
     if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-        static_cast<int>(VarType::FP64))
-      op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
   } else if (op_type == "cast") {
     if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-        static_cast<int>(VarType::FP64))
-      op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("in_dtype", static_cast<int>(to_type));
     if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-        static_cast<int>(VarType::FP64))
-      op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
   }
 }
 
+void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
+    framework::ir::Graph* graph) {
+  auto op_nodes = framework::ir::TopologySortOperations(*graph);
+  for (auto* op_node : op_nodes) {
+    if (!op_node->IsOp()) continue;
+    auto op_type = op_node->Op()->Type();
+    ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32);
     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
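An editorial aside on the hunk above: the old ConvertAllFp64ToFp32 hard-coded the FP64-to-FP32 rewrite of dtype attributes for constant-producing ops, while the new free function ProcessConstantOpAttr takes the pair (from_type, to_type), so the fp16 conversion path can reuse it later (it is called again with VarType::FP32 and to_type). A minimal sketch of that parameterization, in plain C++ with a hypothetical AttrMap instead of Paddle's OpDesc attribute API:

// Sketch only, not Paddle code: illustrates the (from_type, to_type)
// parameterization of the dtype-attribute rewrite.
#include <iostream>
#include <map>
#include <string>

enum class VarType { FP16, FP32, FP64 };

// Hypothetical stand-in for an op's attribute map.
using AttrMap = std::map<std::string, VarType>;

void RewriteConstantOpAttr(const std::string& op_type, AttrMap* attrs,
                           VarType from_type, VarType to_type) {
  if (op_type == "cast") {
    // cast carries both in_dtype and out_dtype.
    for (const char* key : {"in_dtype", "out_dtype"})
      if (attrs->at(key) == from_type) (*attrs)[key] = to_type;
  } else if (op_type == "fill_constant" || op_type == "assign_value" ||
             op_type == "eye" || op_type == "fill_any_like") {
    if (attrs->at("dtype") == from_type) (*attrs)["dtype"] = to_type;
  }
}

int main() {
  AttrMap attrs = {{"dtype", VarType::FP64}};
  // Same helper serves fp64->fp32 cleanup and the later fp32->fp16 pass.
  RewriteConstantOpAttr("fill_constant", &attrs, VarType::FP64, VarType::FP32);
  std::cout << (attrs["dtype"] == VarType::FP32) << "\n";  // prints 1
  return 0;
}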
@@ -509,9 +458,6 @@ void ConvertToMixedPrecisionPass::Run() {
     ConvertTensorDtype(i);
     FixCastAttr(graph);
-    // A trick
-    PatchForStrangeOp();
-    CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
   }
@@ -556,28 +502,9 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
       continue;
     }
     // We can not add cast operator before ops who have sub_block, as in
     // sub_block we may get a var which may be transformer by cast op.
     else if (op_node->Op()->HasAttr("sub_block")) {  // NOLINT
-      // sub_block op's output dtype should be same as input dtype, if have the
-      // same name.
-      std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
-      for (auto* in : op_node->inputs) {
-        if (!in->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, in);
-        if (VarNodeHasDtype(real_node)) {
-          in_name_to_node[in->Name()] = in;
-        }
-      }
-      for (auto* out : op_node->outputs) {
-        if (!out->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, out);
-        if (VarNodeHasDtype(real_node)) {
-          if (in_name_to_node.count(out->Name()))
-            real_node->Var()->SetDataType(
-                in_name_to_node[out->Name()]->Var()->GetDataType());
-        }
-      }
       continue;
     }
@@ -585,65 +512,75 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     // - cast weight to fp16/bf16.
     // - add cast op if the input dtype is not fp16/bf16.
     // - set output dtype.
     //
-    // If a var(op's out var) appears multiple times in graph, we should not
-    // convert to fp16.
-    else if (black_list_.count(op_type) == 0 &&  // NOLINT
-             !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
+    else if (black_list_.count(op_type) == 0) {  // NOLINT
       bool support_precision =
           OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
 
-      // If the op has no input of float type, we will not choose the
+      // If op's output in circle, we should not convert to fp16.
+      for (auto* out_node : op_node->outputs) {
+        if (var_names_in_circles_.count(out_node->Name())) {
+          support_precision = false;
+          VLOG(2) << " op's output " << out_node->Name()
+                  << " is in circle, we can not support this case, just skip.";
+          break;
+        }
+      }
+
+      // If the op has no input or output of float type, we will not choose the
       // low precision kernel.
-      {
-        bool has_float_input{false};
+      if (support_precision) {
+        bool has_float_in_out{false};
         for (auto* in_node : op_node->inputs) {
           if (!in_node->IsVar()) continue;
-          auto* real_node = GetRealVarNode(block_idx, in_node);
+          if (in_node->Var()->GetType() != VarType::LOD_TENSOR) {
+            support_precision = false;
+            VLOG(2) << " op has tensor array input[" << in_node->Name()
+                    << "], just skip.";
+            break;
+          }
+          auto* real_node = GetRealVarNode(in_node);
           if (real_node->Var()->GetDataType() == VarType::FP16 ||
               real_node->Var()->GetDataType() == VarType::FP32 ||
               real_node->Var()->GetDataType() == VarType::FP64 ||
               real_node->Var()->GetDataType() == VarType::BF16) {
-            has_float_input = true;
+            has_float_in_out = true;
             break;
           }
         }
+        for (auto* out_node : op_node->outputs) {
+          if (!out_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(out_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_in_out = true;
+            break;
+          }
+        }
-        if (!has_float_input) {
+        if (!has_float_in_out) {
           support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
+          VLOG(2) << " op doesn't has float input and output, just skip.";
        }
       }
       VLOG(2) << "op type: " << op_type
               << " support low precision: " << support_precision;
 
       if (support_precision) {
+        ProcessConstantOpAttr(op_node, VarType::FP32, to_type);
         VLOG(2) << " process input nodes:";
         ++num_low_precision;
         auto inputs = op_node->inputs;
 
         // Just for paddle's terriable case: op's input and output has the same
         // name.
         std::unordered_map<std::string, std::string> names_map;
         for (auto* out_node : op_node->outputs) {
           for (auto* in_node : op_node->inputs) {
             if (out_node->Name() == in_node->Name()) {
               names_map[out_node->Name()] = in_node->Name();
             }
           }
         }
 
         // Process inputs.
         for (auto* in_node : inputs) {
           ProcessInputNode(
               true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
           if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
             names_map[in_node->Name()] = cast_map_[in_node]->Name();
           }
         }
 
         VLOG(2) << " process output nodes:";
         // Process outputs.
-        for (auto* out_node : op_node->outputs) {
+        auto outputs = op_node->outputs;
+        for (auto* out_node : outputs) {
           ProcessOutputNode(block_idx, out_node, to_type);
         }
       } else {
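To summarize the decision the rewritten branch above makes: an op is lowered to fp16/bf16 only if its kernel supports the target precision, none of its outputs is a circle variable, no input is a tensor array, and at least one input or output is a floating-point tensor; otherwise it falls back to fp32. A compact sketch of that predicate in plain C++ (Var and Op here are hypothetical simplifications, not framework::ir types):

// Sketch only, not Paddle code.
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Var {  // hypothetical simplification of a var node
  std::string name;
  bool is_tensor_array = false;
  bool is_float = true;
};
struct Op {  // hypothetical simplification of an op node
  std::vector<Var> inputs, outputs;
  bool kernel_has_low_precision = true;
};

bool ShouldConvert(const Op& op,
                   const std::set<std::string>& names_in_circles) {
  if (!op.kernel_has_low_precision) return false;
  for (const auto& out : op.outputs)
    if (names_in_circles.count(out.name)) return false;  // circle -> keep fp32
  bool has_float_in_out = false;
  for (const auto& in : op.inputs) {
    if (in.is_tensor_array) return false;                // tensor array -> skip
    has_float_in_out = has_float_in_out || in.is_float;
  }
  for (const auto& out : op.outputs)
    has_float_in_out = has_float_in_out || out.is_float;
  return has_float_in_out;  // nothing float -> nothing to lower
}

int main() {
  Op op{{{"x"}, {"w"}}, {{"y"}}, true};
  std::cout << ShouldConvert(op, {"y"}) << "\n";  // 0: output "y" is in a circle
  std::cout << ShouldConvert(op, {}) << "\n";     // 1: safe to convert
  return 0;
}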
@@ -663,8 +600,10 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     // 3. check op not support fp16/bf16 or in blacklist.
     // - add cast op if the input dtype is not fp32.
     else {  // NOLINT
-      VLOG(3) << "not to run fp16 op_type: " << op_type;
-      for (auto* in_node : op_node->inputs) {
+      VLOG(3) << "not to run fp16 op_type: " << op_type << ", node input size "
+              << op_node->inputs.size();
+      auto in_nodes = op_node->inputs;
+      for (auto* in_node : in_nodes) {
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
@@ -716,21 +655,6 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     }
   }
 
-  for (auto* node : graph->Nodes()) {
-    if (!node->IsVar()) continue;
-    auto* real_node = GetRealVarNode(block_idx, node);
-    if (!VarNodeHasDtype(real_node)) continue;
-    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
-            block_idx &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).first ==
-            VarType::Type()) {
-      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
-          real_node->Var()->GetDataType();
-    }
-  }
-
   if (num_low_precision)
     LOG(INFO) << "--- detected " << num_low_precision
               << " low precision ops in " << block_idx << " subgraph";
@@ -738,6 +662,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
 // We modify op's input output precision, and we need to fix cast op in_dtype
 // and out_dtype attribute.
+// TODO(inference): we need a cast elimination pass.
 void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
   auto op_nodes = framework::ir::TopologySortOperations(*graph);
   for (auto* op_node : op_nodes) {
@@ -766,7 +691,8 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
     if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
           node->Var()->GetDataType() == VarType::FP32) {
-        VLOG(2) << "weights keep to fp32: " << node->Name();
+        VLOG(2) << "weights keep to fp32: " << node->Name() << ", ptr "
+                << reinterpret_cast<void*>(node->Var());
         weights_should_be_fp32.insert(node->Name());
       }
     }
@@ -808,7 +734,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
     std::ostringstream os;
     phi::CPUContext ctx;
     for (const auto& param : parameters) {
-      VLOG(3) << "Serialize param: " << param;
       PADDLE_ENFORCE_NOT_NULL(
           scope_.FindVar(param),
           platform::errors::NotFound(
@@ -829,21 +754,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
               mixed_program_desc.Proto()->SerializeAsString());
   StrToBinary(mixed_params_file_, SerializeParams());
 }
 
-void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
-  for (auto* graph : graphes_) {
-    for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
-      if (op_node->Name() == "fused_multi_transformer") {
-        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
-        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
-        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
-        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
-          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
-        }
-      }
-    }
-  }
-}
-
 }  // namespace
 
 void AddCastOp(
@@ -893,6 +803,7 @@ void AddCastOp(
   }
   next_op->Op()->Rename(node->Name(), map->at(node)->Name());
   IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
+  IR_NODE_UNLINK(node, next_op);
   IR_NODE_LINK_TO(map->at(node), next_op);
 }
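A closing editorial note on the PatchForStrangeOp removal a few hunks up: fused_multi_transformer updates its KV cache in place, so each CacheKVOut slot has to keep the same variable name as the corresponding CacheKV input; the removed patch enforced this after the fact by renaming outputs, and the conversion loop now covers the same input-name-equals-output-name situation through its names_map bookkeeping. A minimal sketch of the rename-by-index idea, in plain C++ with hypothetical name lists rather than Paddle's OpDesc:

// Sketch only, not Paddle code.
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical argument lists of one fused_multi_transformer op.
  std::vector<std::string> cache_kv_inputs = {"cache_kv_0", "cache_kv_1"};
  std::vector<std::string> cache_kv_outputs = {"cache_kv_out_0",
                                               "cache_kv_out_1"};

  // Same invariant the removed patch enforced with CHECK_EQ + RenameOutput.
  assert(cache_kv_inputs.size() == cache_kv_outputs.size());
  for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
    cache_kv_outputs[i] = cache_kv_inputs[i];  // RenameOutput(out[i], in[i])
  }
  for (const auto& name : cache_kv_outputs) std::cout << name << "\n";
  return 0;
}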