Commit a00aebe1
Authored by Wilber on Nov 15, 2022; committed via GitHub on Nov 15, 2022
Parent: d4d3d7ed

[convert_to_mixed_precision] fallback to fp32 when encounter circle (#47902)
Showing 1 changed file with 127 additions and 216 deletions.

paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc (+127 −216)
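The change below drops the old multi-block bookkeeping (FindVarsInMultiBlock, VarIsMultiPrecisionOpsOut, PatchForStrangeOp) and replaces it with a simpler rule: if a variable name appears both among an op's inputs and among its outputs (the in-place pattern of ops such as fused_multi_transformer's CacheKV), the graph is not a DAG, and ops touching such "circle" variables are skipped so they fall back to fp32. As a rough standalone illustration of the per-op name-intersection test used by the new ProcessCircleCases (the Op struct here is a hypothetical stand-in, not Paddle's OpDesc API):

    // Minimal sketch, assuming a hypothetical Op record with input/output names.
    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <set>
    #include <string>
    #include <vector>

    struct Op {  // hypothetical stand-in for an op description
      std::vector<std::string> inputs;
      std::vector<std::string> outputs;
    };

    std::vector<std::string> VarsInCircle(const Op& op) {
      std::set<std::string> in(op.inputs.begin(), op.inputs.end());
      std::set<std::string> out(op.outputs.begin(), op.outputs.end());
      std::vector<std::string> circle_vars;
      // Same idea as the pass: intersect the input and output name sets.
      std::set_intersection(in.begin(), in.end(),
                            out.begin(), out.end(),
                            std::back_inserter(circle_vars));
      return circle_vars;
    }

    int main() {
      Op op{{"X", "CacheKV"}, {"Out", "CacheKV"}};  // in-place style op
      for (const auto& name : VarsInCircle(op))
        std::cout << name << " is in a circle, keep fp32\n";  // prints CacheKV
    }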
@@ -40,7 +40,6 @@
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/common/place.h"
-#include "paddle/phi/core/tensor_meta.h"
 
 namespace paddle {
 namespace inference {
@@ -111,12 +110,10 @@ class ConvertToMixedPrecisionPass {
         black_list_(black_list),
         place_(paddle::CPUPlace()),
         executor_(place_) {
-    // black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
+    VLOG(4) << "black_list has ";
+    for (auto& name : black_list_) {
+      VLOG(4) << " - " << name;
+    }
   }
 
   void Run();
@@ -145,18 +142,11 @@ class ConvertToMixedPrecisionPass {
   // Just process special cases for weights conversion.
   bool WeightsShouldNotConvert(framework::ir::Node* var_node);
-  // To support multi block, we need to consider a lot of special cases.
   // Return Node* which first appers in block.
-  framework::ir::Node* GetRealVarNode(BlockID block_idx,
-                                      framework::ir::Node* node);
-  void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
-                                        framework::ir::Node* op_node);
-
-  // A trick. Patch for strange op, which input name equal to output name, such
-  // as `fused_multi_transformer`
-  void PatchForStrangeOp();
+  framework::ir::Node* GetRealVarNode(framework::ir::Node* node);
+
+ private:
+  // Fallback to fp32 dtype when encounter circle (Not a DAG graph).
+  void ProcessCircleCases();
 
  private:
   std::string model_file_;
@@ -171,35 +161,21 @@ class ConvertToMixedPrecisionPass {
   framework::Executor executor_;
   framework::Scope scope_;
 
+  std::unordered_map<std::string, framework::ir::Node*> name2node_;
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
-      vars_in_multi_block_with_pair_;
-  std::unordered_map<std::string, std::vector<std::string>>
-      vars_in_multi_block_with_ops_;
   int suffix_{0};
 
+  std::set<std::string> var_names_in_circles_;
+
   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
   std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
   std::vector<framework::ir::Graph*> graphes_;
 };
 
 framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
-    BlockID block_idx, framework::ir::Node* var_node) {
+    framework::ir::Node* var_node) {
   CHECK_EQ(var_node->IsVar(), true);
-
-  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
-    auto origin_blockId =
-        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
-    if (block_idx != origin_blockId) {
-      auto* graph = graphes_[origin_blockId];
-      for (auto* node : graph->Nodes()) {
-        if (node->Name() == var_node->Name()) {
-          return node;
-        }
-      }
-    }
-  }
-
+  if (name2node_.count(var_node->Name())) return name2node_[var_node->Name()];
   return var_node;
 }
@@ -212,32 +188,6 @@ inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
          (type == VarType::VOCAB);
 }
 
-// op1(fp32) -> var1, op2(fp16) -> var1
-// if and only if op1 and op2 both support fp16, we convert op1 and op2's
-// precision.
-inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    BlockID block_idx, framework::ir::Node* op_node) {
-  CHECK_EQ(op_node->IsOp(), true);
-
-  for (auto* var_node : op_node->outputs) {
-    if (!var_node->IsVar()) continue;
-    auto* real_var_node = GetRealVarNode(block_idx, var_node);
-    if (!real_var_node->Var()->Persistable() &&
-        vars_in_multi_block_with_ops_.count(var_node->Name())) {
-      for (const auto& op_type :
-           vars_in_multi_block_with_ops_.at(var_node->Name())) {
-        if (!OpSupportPrecision(
-                op_type, backend_, mixed_precision_, black_list_)) {
-          VLOG(2) << var_node->Name()
-                  << " is multi precision op's out, so we skip convert to fp16";
-          return true;
-        }
-      }
-    }
-  }
-  return false;
-}
-
 void ConvertToMixedPrecisionPass::ProcessInputNode(
     bool support_precision,
     framework::ir::Node* in_node,
@@ -247,18 +197,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
     VarType::Type to_type,
     BlockID block_idx) {
   if (!in_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, in_node);
+  auto* real_node = GetRealVarNode(in_node);
   if (!VarNodeHasDtype(real_node)) return;
 
   auto* graph = graphes_[block_idx];
-  bool is_main_block = block_idx == 0;
   auto* in_var = real_node->Var();
   auto in_var_type = in_var->GetDataType();
   auto prev_type = in_var_type;
-  bool is_in_multi_block =
-      vars_in_multi_block_with_pair_.count(in_var->Name());
-  if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
-  }
 
   if (support_precision) {
     if (in_var->Persistable() && in_var_type == VarType::FP32) {
       if (WeightsShouldNotConvert(in_node)) return;
@@ -299,7 +244,7 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
 void ConvertToMixedPrecisionPass::ProcessOutputNode(
     BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
   if (!var_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, var_node);
+  auto* real_node = GetRealVarNode(var_node);
   if (!VarNodeHasDtype(real_node)) return;
   auto* out_var = real_node->Var();
   auto prev_type = out_var->GetDataType();
@@ -400,9 +345,17 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   inference::Load(&executor_, &scope_, model_file_, params_file_);
 
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));
 
   for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
     auto* graph = main_graph_->GetSubGraph(i);
     graphes_.push_back(graph);
+
+    for (auto* node : graph->Nodes()) {
+      if (!node->IsVar()) continue;
+      if (!name2node_.count(node->Name())) {
+        name2node_[node->Name()] = node;
+      }
+    }
   }
 
   // Remove all control var
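With this hunk, LoadAndPrepare records the first node seen for every variable name across all subgraphs in name2node_, so the new single-argument GetRealVarNode above reduces to one hash lookup instead of rescanning the originating block's graph. A minimal sketch of the "first appearance wins" index, with a hypothetical Node type standing in for framework::ir::Node:

    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Node { std::string name; };  // hypothetical stand-in

    class FirstAppearanceIndex {
     public:
      void AddGraph(const std::vector<Node*>& nodes) {
        for (Node* n : nodes)
          name2node_.emplace(n->name, n);  // emplace keeps the first occurrence
      }
      Node* GetRealVarNode(Node* node) const {
        auto it = name2node_.find(node->name);
        return it == name2node_.end() ? node : it->second;
      }
     private:
      std::unordered_map<std::string, Node*> name2node_;
    };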
@@ -411,46 +364,68 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);
 
-  FindVarsInMultiBlock();
+  ProcessCircleCases();
 }
 
-void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::unordered_set<std::string> all_var_names_set;
-  std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
-  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
-    for (auto* op : program_desc_->Block(idx).AllOps()) {
-      const auto& in_names = op->InputArgumentNames();
-      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
-      const auto& out_names = op->OutputArgumentNames();
-      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
-      if (op->HasAttr("sub_block") == false) {
-        for (const auto& name : out_names) {
-          if (all_var_names_set.count(name)) {
-            vars_in_multi_block_with_ops_[name].push_back(op->Type());
-          }
-        }
-      }
-      all_var_names_set.insert(block_var_names_set[idx].begin(),
-                               block_var_names_set[idx].end());
-    }
-  }
-
-  CHECK_GT(program_desc_->Size(), 0U);
-  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
-    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
-      std::vector<std::string> vars_in_multi_block;
-      std::set_intersection(block_var_names_set[idx].begin(),
-                            block_var_names_set[idx].end(),
-                            block_var_names_set[jdx].begin(),
-                            block_var_names_set[jdx].end(),
-                            std::back_inserter(vars_in_multi_block));
-
-      for (const auto& name : vars_in_multi_block) {
-        vars_in_multi_block_with_pair_.emplace(
-            name, std::make_pair(VarType::Type(), idx));
-      }
-    }
-  }
-}
+// Find var names which in circles.
+void ConvertToMixedPrecisionPass::ProcessCircleCases() {
+  std::vector<std::string> vars_in_circles;
+  for (size_t idx = 0; idx < program_desc_->Size(); ++idx) {
+    for (auto* op : program_desc_->Block(idx).AllOps()) {
+      // TODO(inference): batch_norm has circle, but we need to fuse it in conv
+      // op.
+      if (op->Type() == "batch_norm") continue;
+      const auto& in_names = op->InputArgumentNames();
+      const auto& out_names = op->OutputArgumentNames();
+      std::set<std::string> in_names_set(in_names.begin(), in_names.end());
+      std::set<std::string> out_names_set(out_names.begin(), out_names.end());
+      std::set_intersection(in_names_set.begin(),
+                            in_names_set.end(),
+                            out_names_set.begin(),
+                            out_names_set.end(),
+                            std::back_inserter(vars_in_circles));
+    }
+  }
+
+  for (auto& name : vars_in_circles) {
+    var_names_in_circles_.insert(name);
+  }
+  for (auto& name : var_names_in_circles_) {
+    LOG(INFO) << name
+              << " in circles, so we will skip process those vars and ops.";
+  }
+}
+
+inline void ProcessConstantOpAttr(framework::ir::Node* op_node,
+                                  VarType::Type from_type,
+                                  VarType::Type to_type) {
+  if (!op_node->IsOp()) return;
+  auto op_type = op_node->Op()->Type();
+  if (op_type == "feed" || op_type == "fetch") return;
+
+  if (op_type == "fill_constant") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "assign_value") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "eye") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "fill_any_like") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "cast") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("in_dtype", static_cast<int>(to_type));
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+  }
+}
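The new ProcessConstantOpAttr helper factors the previously duplicated dtype-attribute rewrites (fill_constant, assign_value, eye, fill_any_like, cast) into one routine parameterised by from_type/to_type; the pass calls it with FP64 to FP32 in ConvertAllFp64ToFp32 and with FP32 to the target precision when an op is converted. A small sketch of the same rewrite pattern over a plain attribute map (AttrMap and DType here are assumptions, not Paddle's OpDesc interface):

    #include <map>
    #include <string>

    enum class DType { FP16, FP32, FP64 };
    using AttrMap = std::map<std::string, int>;  // hypothetical stand-in for op attributes

    void RewriteDtypeAttr(AttrMap* attrs, const std::string& key,
                          DType from_type, DType to_type) {
      auto it = attrs->find(key);
      // Only touch the attribute when it currently encodes from_type.
      if (it != attrs->end() && it->second == static_cast<int>(from_type))
        it->second = static_cast<int>(to_type);
    }

    // Usage mirroring the two call sites in the pass:
    //   RewriteDtypeAttr(&attrs, "dtype", DType::FP64, DType::FP32);  // fp64 -> fp32
    //   RewriteDtypeAttr(&attrs, "dtype", DType::FP32, DType::FP16);  // fp32 -> fp16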
@@ -460,33 +435,7 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
   for (auto* op_node : op_nodes) {
     if (!op_node->IsOp()) continue;
     auto op_type = op_node->Op()->Type();
-    if (op_type == "feed" || op_type == "fetch") continue;
-
-    if (op_type == "fill_constant") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "assign_value") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "eye") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "fill_any_like") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "cast") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
-    }
+    ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32);
 
     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
@@ -509,9 +458,6 @@ void ConvertToMixedPrecisionPass::Run() {
     ConvertTensorDtype(i);
     FixCastAttr(graph);
 
-    // A trick
-    PatchForStrangeOp();
-
     CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
   }
@@ -556,28 +502,9 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
       continue;
     }
 
+    // We can not add cast operator before ops who have sub_block, as in
+    // sub_block we may get a var which may be transformer by cast op.
     else if (op_node->Op()->HasAttr("sub_block")) {  // NOLINT
-      // sub_block op's output dtype should be same as input dtype, if have the
-      // same name.
-      std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
-      for (auto* in : op_node->inputs) {
-        if (!in->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, in);
-        if (VarNodeHasDtype(real_node)) {
-          in_name_to_node[in->Name()] = in;
-        }
-      }
-
-      for (auto* out : op_node->outputs) {
-        if (!out->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, out);
-        if (VarNodeHasDtype(real_node)) {
-          if (in_name_to_node.count(out->Name()))
-            real_node->Var()->SetDataType(
-                in_name_to_node[out->Name()]->Var()->GetDataType());
-        }
-      }
       continue;
     }
@@ -585,65 +512,75 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     //      - cast weight to fp16/bf16.
     //      - add cast op if the input dtype is not fp16/bf16.
     //      - set output dtype.
-    //
-    // If a var(op's out var) appears multiple times in graph, we should not
-    // convert to fp16.
-    else if (black_list_.count(op_type) == 0 &&  // NOLINT
-             !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
+    else if (black_list_.count(op_type) == 0) {  // NOLINT
       bool support_precision =
           OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
 
-      // If the op has no input of float type, we will not choose the
+      // If op's output in circle, we should not convert to fp16.
+      for (auto* out_node : op_node->outputs) {
+        if (var_names_in_circles_.count(out_node->Name())) {
+          support_precision = false;
+          VLOG(2) << " op's output " << out_node->Name()
+                  << " is in circle, we can not support this case, just skip.";
+          break;
+        }
+      }
+
+      // If the op has no input or output of float type, we will not choose the
       // low precision kernel.
-      {
-        bool has_float_input{false};
+      if (support_precision) {
+        bool has_float_in_out{false};
         for (auto* in_node : op_node->inputs) {
           if (!in_node->IsVar()) continue;
-          auto* real_node = GetRealVarNode(block_idx, in_node);
+          if (in_node->Var()->GetType() != VarType::LOD_TENSOR) {
+            support_precision = false;
+            VLOG(2) << " op has tensor array input[" << in_node->Name()
+                    << "], just skip.";
+            break;
+          }
+          auto* real_node = GetRealVarNode(in_node);
           if (real_node->Var()->GetDataType() == VarType::FP16 ||
               real_node->Var()->GetDataType() == VarType::FP32 ||
               real_node->Var()->GetDataType() == VarType::FP64 ||
               real_node->Var()->GetDataType() == VarType::BF16) {
-            has_float_input = true;
+            has_float_in_out = true;
             break;
           }
         }
-        if (!has_float_input) {
+        for (auto* out_node : op_node->outputs) {
+          if (!out_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(out_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_in_out = true;
+            break;
+          }
+        }
+
+        if (!has_float_in_out) {
           support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
+          VLOG(2) << " op doesn't has float input and output, just skip.";
         }
       }
+
       VLOG(2) << "op type: " << op_type
               << " support low precision: " << support_precision;
 
       if (support_precision) {
+        ProcessConstantOpAttr(op_node, VarType::FP32, to_type);
         VLOG(2) << " process input nodes:";
         ++num_low_precision;
         auto inputs = op_node->inputs;
+
+        // Just for paddle's terriable case: op's input and output has the same
+        // name.
+        std::unordered_map<std::string, std::string> names_map;
+        for (auto* out_node : op_node->outputs) {
+          for (auto* in_node : op_node->inputs) {
+            if (out_node->Name() == in_node->Name()) {
+              names_map[out_node->Name()] = in_node->Name();
+            }
+          }
+        }
+
+        // Process inputs.
         for (auto* in_node : inputs) {
           ProcessInputNode(
               true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
+          if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
+            names_map[in_node->Name()] = cast_map_[in_node]->Name();
+          }
         }
 
         VLOG(2) << " process output nodes:";
-        for (auto* out_node : op_node->outputs) {
+        // Process outputs.
+        auto outputs = op_node->outputs;
+        for (auto* out_node : outputs) {
           ProcessOutputNode(block_idx, out_node, to_type);
         }
       } else {
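The names_map bookkeeping above replaces the removed PatchForStrangeOp: for ops whose output reuses an input's variable name, it remembers the name the input ends up with after a cast op is inserted, so the aliased output can be kept consistent with it. A rough sketch of that bookkeeping in isolation (the cast_renames map below is a hypothetical stand-in for the pass's cast_map_):

    #include <map>
    #include <string>
    #include <vector>

    std::map<std::string, std::string> OutputRenames(
        const std::vector<std::string>& in_names,
        const std::vector<std::string>& out_names,
        const std::map<std::string, std::string>& cast_renames) {
      std::map<std::string, std::string> names_map;
      for (const auto& out : out_names)
        for (const auto& in : in_names)
          if (out == in) names_map[out] = in;  // output aliases an input
      // If the aliased input was rerouted through a cast, follow the rename.
      for (auto& kv : names_map) {
        auto it = cast_renames.find(kv.second);
        if (it != cast_renames.end()) kv.second = it->second;
      }
      return names_map;
    }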
@@ -663,8 +600,10 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     // 3. check op not support fp16/bf16 or in blacklist.
     //      - add cast op if the input dtype is not fp32.
     else {  // NOLINT
-      VLOG(3) << "not to run fp16 op_type: " << op_type;
-      for (auto* in_node : op_node->inputs) {
+      VLOG(3) << "not to run fp16 op_type: " << op_type << ", node input size "
+              << op_node->inputs.size();
+      auto in_nodes = op_node->inputs;
+      for (auto* in_node : in_nodes) {
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
@@ -716,21 +655,6 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     }
   }
 
-  for (auto* node : graph->Nodes()) {
-    if (!node->IsVar()) continue;
-    auto* real_node = GetRealVarNode(block_idx, node);
-    if (!VarNodeHasDtype(real_node)) continue;
-
-    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
-            block_idx &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).first ==
-            VarType::Type()) {
-      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
-          real_node->Var()->GetDataType();
-    }
-  }
-
   if (num_low_precision)
     LOG(INFO) << "--- detected " << num_low_precision
               << " low precision ops in " << block_idx << " subgraph";
@@ -738,6 +662,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
 
 // We modify op's input output precision, and we need to fix cast op in_dtype
 // and out_dtype attribute.
+// TODO(inference): we need a cast elimination pass.
 void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
   auto op_nodes = framework::ir::TopologySortOperations(*graph);
   for (auto* op_node : op_nodes) {
@@ -766,7 +691,8 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
     if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
          node->Var()->GetDataType() == VarType::FP32) {
-        VLOG(2) << "weights keep to fp32: " << node->Name();
+        VLOG(2) << "weights keep to fp32: " << node->Name() << ", ptr "
+                << reinterpret_cast<void*>(node->Var());
         weights_should_be_fp32.insert(node->Name());
       }
     }
@@ -808,7 +734,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   std::ostringstream os;
   phi::CPUContext ctx;
   for (const auto& param : parameters) {
-    VLOG(3) << "Serialize param: " << param;
     PADDLE_ENFORCE_NOT_NULL(
         scope_.FindVar(param),
         platform::errors::NotFound(
@@ -829,21 +754,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
       mixed_program_desc.Proto()->SerializeAsString());
   StrToBinary(mixed_params_file_, SerializeParams());
 }
-
-void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
-  for (auto* graph : graphes_) {
-    for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
-      if (op_node->Name() == "fused_multi_transformer") {
-        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
-        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
-        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
-        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
-          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
-        }
-      }
-    }
-  }
-}
 
 }  // namespace
 
 void AddCastOp(
@@ -893,6 +803,7 @@ void AddCastOp(
   }
   next_op->Op()->Rename(node->Name(), map->at(node)->Name());
   IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
+  IR_NODE_UNLINK(node, next_op);
   IR_NODE_LINK_TO(map->at(node), next_op);
 }
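AddCastOp now also unlinks the original var node from the consuming op: once the chain var -> cast -> casted var -> op is wired up, leaving the old direct edge in place would keep a stale dependency in the graph. A toy illustration of that relink on a plain adjacency-list graph (the Node type and helper names here are assumptions, not Paddle's IR_NODE_* macros):

    #include <algorithm>
    #include <vector>

    struct Node {
      std::vector<Node*> inputs, outputs;
    };

    void Unlink(Node* a, Node* b) {
      a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
                       a->outputs.end());
      b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
                      b->inputs.end());
    }

    void Link(Node* a, Node* b) {
      a->outputs.push_back(b);
      b->inputs.push_back(a);
    }

    // var currently feeds op directly; insert cast_op producing casted_var.
    void RouteThroughCast(Node* var, Node* cast_op, Node* casted_var, Node* op) {
      Link(var, cast_op);
      Link(cast_op, casted_var);
      Unlink(var, op);       // counterpart of the added IR_NODE_UNLINK(node, next_op)
      Link(casted_var, op);  // counterpart of IR_NODE_LINK_TO(map->at(node), next_op)
    }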