Unverified commit a00aebe1
Authored on Nov 15, 2022 by Wilber; committed via GitHub on Nov 15, 2022.
[convert_to_mixed_precision] fallback to fp32 when encounter circle (#47902)
Parent: d4d3d7ed
Showing 1 changed file with 127 additions and 216 deletions.

paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc  (+127, -216)
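The gist of the change: the pass can only reason about precision op by op if the graph is a DAG, and an op whose output reuses one of its input names (e.g. `fused_multi_transformer` updating its KV cache in place) creates a circle. Variables detected in such circles now simply stay fp32. Below is a minimal standalone sketch of the detection idea (plain C++ with a hypothetical `Op` struct, not Paddle's op desc API):

#include <algorithm>
#include <iostream>
#include <iterator>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for an op desc: only its argument names matter here.
struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// A var appearing both as input and output of one op forms a circle, so the
// graph is not a DAG; such vars are the ones the pass keeps in fp32.
std::set<std::string> VarsInCircles(const std::vector<Op>& ops) {
  std::set<std::string> result;
  for (const auto& op : ops) {
    std::set<std::string> in(op.inputs.begin(), op.inputs.end());
    std::set<std::string> out(op.outputs.begin(), op.outputs.end());
    std::set_intersection(in.begin(), in.end(), out.begin(), out.end(),
                          std::inserter(result, result.end()));
  }
  return result;
}

int main() {
  // "cache_kv" feeds back into itself, like fused_multi_transformer's CacheKV.
  std::vector<Op> ops = {{{"x", "cache_kv"}, {"y", "cache_kv"}},
                         {{"y"}, {"z"}}};
  for (const auto& name : VarsInCircles(ops))
    std::cout << name << " must stay fp32\n";
}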
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -40,7 +40,6 @@
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/common/place.h"
-#include "paddle/phi/core/tensor_meta.h"

 namespace paddle {
 namespace inference {
@@ -111,12 +110,10 @@ class ConvertToMixedPrecisionPass {
         black_list_(black_list),
         place_(paddle::CPUPlace()),
         executor_(place_) {
-    // black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
+    VLOG(4) << "black_list has ";
+    for (auto& name : black_list_) {
+      VLOG(4) << " - " << name;
+    }
   }

   void Run();
@@ -145,18 +142,11 @@ class ConvertToMixedPrecisionPass {
   // Just process special cases for weights conversion.
   bool WeightsShouldNotConvert(framework::ir::Node* var_node);

-  // To support multi block, we need to consider a lot of special cases.
   // Return Node* which first appers in block.
-  framework::ir::Node* GetRealVarNode(BlockID block_idx,
-                                      framework::ir::Node* node);
-  void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
-                                        framework::ir::Node* op_node);
+  framework::ir::Node* GetRealVarNode(framework::ir::Node* node);
+
+  // Fallback to fp32 dtype when encounter circle (Not a DAG graph).
+  void ProcessCircleCases();
+
+  // A trick. Patch for strange op, which input name equal to output name, such
+  // as `fused_multi_transformer`
+  void PatchForStrangeOp();

  private:
   std::string model_file_;
@@ -171,35 +161,21 @@ class ConvertToMixedPrecisionPass {
   framework::Executor executor_;
   framework::Scope scope_;

+  std::unordered_map<std::string, framework::ir::Node*> name2node_;
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
-      vars_in_multi_block_with_pair_;
-  std::unordered_map<std::string, std::vector<std::string>>
-      vars_in_multi_block_with_ops_;
   int suffix_{0};

+  std::set<std::string> var_names_in_circles_;
+
   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
   std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
   std::vector<framework::ir::Graph*> graphes_;
 };

 framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
-    BlockID block_idx, framework::ir::Node* var_node) {
+    framework::ir::Node* var_node) {
   CHECK_EQ(var_node->IsVar(), true);
-
-  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
-    auto origin_blockId =
-        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
-    if (block_idx != origin_blockId) {
-      auto* graph = graphes_[origin_blockId];
-      for (auto* node : graph->Nodes()) {
-        if (node->Name() == var_node->Name()) {
-          return node;
-        }
-      }
-    }
-  }
-
+  if (name2node_.count(var_node->Name())) return name2node_[var_node->Name()];
   return var_node;
 }
@@ -212,32 +188,6 @@ inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
          (type == VarType::VOCAB);
 }

-// op1(fp32) -> var1, op2(fp16) -> var1
-// if and only if op1 and op2 both support fp16, we convert op1 and op2's
-// precision.
-inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    BlockID block_idx, framework::ir::Node* op_node) {
-  CHECK_EQ(op_node->IsOp(), true);
-
-  for (auto* var_node : op_node->outputs) {
-    if (!var_node->IsVar()) continue;
-    auto* real_var_node = GetRealVarNode(block_idx, var_node);
-    if (!real_var_node->Var()->Persistable() &&
-        vars_in_multi_block_with_ops_.count(var_node->Name())) {
-      for (const auto& op_type :
-           vars_in_multi_block_with_ops_.at(var_node->Name())) {
-        if (!OpSupportPrecision(
-                op_type, backend_, mixed_precision_, black_list_)) {
-          VLOG(2) << var_node->Name()
-                  << " is multi precision op's out, so we skip convert to fp16";
-          return true;
-        }
-      }
-    }
-  }
-  return false;
-}
-
 void ConvertToMixedPrecisionPass::ProcessInputNode(
     bool support_precision,
     framework::ir::Node* in_node,
@@ -247,18 +197,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
     VarType::Type to_type,
     BlockID block_idx) {
   if (!in_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, in_node);
+  auto* real_node = GetRealVarNode(in_node);
   if (!VarNodeHasDtype(real_node)) return;
   auto* graph = graphes_[block_idx];
-  bool is_main_block = block_idx == 0;
   auto* in_var = real_node->Var();
   auto in_var_type = in_var->GetDataType();
   auto prev_type = in_var_type;
-  bool is_in_multi_block = vars_in_multi_block_with_pair_.count(in_var->Name());
-  if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
-  }
+
   if (support_precision) {
     if (in_var->Persistable() && in_var_type == VarType::FP32) {
       if (WeightsShouldNotConvert(in_node)) return;
@@ -299,7 +244,7 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
 void ConvertToMixedPrecisionPass::ProcessOutputNode(
     BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
   if (!var_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, var_node);
+  auto* real_node = GetRealVarNode(var_node);
   if (!VarNodeHasDtype(real_node)) return;
   auto* out_var = real_node->Var();
   auto prev_type = out_var->GetDataType();
@@ -400,9 +345,17 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
       inference::Load(&executor_, &scope_, model_file_, params_file_);
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));

   for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
     auto* graph = main_graph_->GetSubGraph(i);
     graphes_.push_back(graph);
+    for (auto* node : graph->Nodes()) {
+      if (!node->IsVar()) continue;
+      if (!name2node_.count(node->Name())) {
+        name2node_[node->Name()] = node;
+      }
+    }
   }

   // Remove all control var
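With this change `GetRealVarNode` no longer needs a block id: `LoadAndPrepare` builds `name2node_` once, keeping the first node seen for each var name across all subgraphs, and every later lookup goes through that index. A minimal sketch of the first-wins index (plain C++; the block/name layout is invented for illustration):

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Two blocks sharing var "x"; like name2node_, map::emplace keeps the first
  // occurrence, so block 0's node is the "real" one for every later lookup.
  std::vector<std::vector<std::string>> blocks = {{"x", "y"}, {"x", "z"}};
  std::map<std::string, size_t> name2block;
  for (size_t i = 0; i < blocks.size(); ++i)
    for (const auto& name : blocks[i]) name2block.emplace(name, i);
  std::cout << "x resolves to block " << name2block["x"] << "\n";  // prints 0
}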
@@ -411,82 +364,78 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);

-  FindVarsInMultiBlock();
+  ProcessCircleCases();
 }

-void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::unordered_set<std::string> all_var_names_set;
-  std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
-  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
+// Find var names which in circles.
+void ConvertToMixedPrecisionPass::ProcessCircleCases() {
+  std::vector<std::string> vars_in_circles;
+  for (size_t idx = 0; idx < program_desc_->Size(); ++idx) {
     for (auto* op : program_desc_->Block(idx).AllOps()) {
+      // TODO(inference): batch_norm has circle, but we need to fuse it in conv
+      // op.
+      if (op->Type() == "batch_norm") continue;
       const auto& in_names = op->InputArgumentNames();
-      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
       const auto& out_names = op->OutputArgumentNames();
-      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
-      if (op->HasAttr("sub_block") == false) {
-        for (const auto& name : out_names) {
-          if (all_var_names_set.count(name)) {
-            vars_in_multi_block_with_ops_[name].push_back(op->Type());
-          }
-        }
-      }
-      all_var_names_set.insert(block_var_names_set[idx].begin(),
-                               block_var_names_set[idx].end());
+      std::set<std::string> in_names_set(in_names.begin(), in_names.end());
+      std::set<std::string> out_names_set(out_names.begin(), out_names.end());
+      std::set_intersection(in_names_set.begin(),
+                            in_names_set.end(),
+                            out_names_set.begin(),
+                            out_names_set.end(),
+                            std::back_inserter(vars_in_circles));
     }
   }

-  CHECK_GT(program_desc_->Size(), 0U);
-  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
-    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
-      std::vector<std::string> vars_in_multi_block;
-      std::set_intersection(block_var_names_set[idx].begin(),
-                            block_var_names_set[idx].end(),
-                            block_var_names_set[jdx].begin(),
-                            block_var_names_set[jdx].end(),
-                            std::back_inserter(vars_in_multi_block));
-
-      for (const auto& name : vars_in_multi_block) {
-        vars_in_multi_block_with_pair_.emplace(
-            name, std::make_pair(VarType::Type(), idx));
-      }
-    }
-  }
+  for (auto& name : vars_in_circles) {
+    var_names_in_circles_.insert(name);
+  }
+  for (auto& name : var_names_in_circles_) {
+    LOG(INFO) << name
+              << " in circles, so we will skip process those vars and ops.";
+  }
 }

-void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
-    framework::ir::Graph* graph) {
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
-  for (auto* op_node : op_nodes) {
-    if (!op_node->IsOp()) continue;
-    auto op_type = op_node->Op()->Type();
-    if (op_type == "feed" || op_type == "fetch") continue;
-
-    if (op_type == "fill_constant") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "assign_value") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "eye") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "fill_any_like") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "cast") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
-    }
-
+inline void ProcessConstantOpAttr(framework::ir::Node* op_node,
+                                  VarType::Type from_type,
+                                  VarType::Type to_type) {
+  if (!op_node->IsOp()) return;
+  auto op_type = op_node->Op()->Type();
+  if (op_type == "feed" || op_type == "fetch") return;
+
+  if (op_type == "fill_constant") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "assign_value") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "eye") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "fill_any_like") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "cast") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("in_dtype", static_cast<int>(to_type));
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+  }
+}
+
+void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
+    framework::ir::Graph* graph) {
+  auto op_nodes = framework::ir::TopologySortOperations(*graph);
+  for (auto* op_node : op_nodes) {
+    if (!op_node->IsOp()) continue;
+    ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32);
     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
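`ProcessConstantOpAttr` generalizes the old fp64-to-fp32 special-casing into a reusable from/to rewrite, so the same helper can later retarget fp32 constants to the mixed dtype for ops chosen for low precision, instead of blacklisting `fill_constant` and friends. A hedged sketch of the two-step rewrite on a toy attribute map (the `DType` enum and map stand in for Paddle's attr types):

#include <iostream>
#include <map>
#include <string>

enum class DType { FP16, FP32, FP64, INT32 };  // illustrative only

// Mirrors the guard in ProcessConstantOpAttr: rewrite the attr only when it
// currently equals from_type, so e.g. an INT32 fill_constant is untouched.
void RetargetDtype(std::map<std::string, DType>* attrs, const std::string& key,
                   DType from_type, DType to_type) {
  auto it = attrs->find(key);
  if (it != attrs->end() && it->second == from_type) it->second = to_type;
}

int main() {
  std::map<std::string, DType> fill_constant_attrs{{"dtype", DType::FP64}};
  // Pass 1: the ConvertAllFp64ToFp32 call site.
  RetargetDtype(&fill_constant_attrs, "dtype", DType::FP64, DType::FP32);
  // Pass 2: applied only to ops selected for low precision.
  RetargetDtype(&fill_constant_attrs, "dtype", DType::FP32, DType::FP16);
  std::cout << (fill_constant_attrs["dtype"] == DType::FP16) << "\n";  // 1
}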
@@ -509,9 +458,6 @@ void ConvertToMixedPrecisionPass::Run() {
     ConvertTensorDtype(i);
     FixCastAttr(graph);

+    // A trick
+    PatchForStrangeOp();
+
     CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
   }
@@ -556,28 +502,9 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
       continue;
     }
-    // sub_block op's output dtype should be same as input dtype, if have the
-    // same name.
+    // We can not add cast operator before ops who have sub_block, as in
+    // sub_block we may get a var which may be transformer by cast op.
     else if (op_node->Op()->HasAttr("sub_block")) {  // NOLINT
-      std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
-      for (auto* in : op_node->inputs) {
-        if (!in->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, in);
-        if (VarNodeHasDtype(real_node)) {
-          in_name_to_node[in->Name()] = in;
-        }
-      }
-
-      for (auto* out : op_node->outputs) {
-        if (!out->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, out);
-        if (VarNodeHasDtype(real_node)) {
-          if (in_name_to_node.count(out->Name()))
-            real_node->Var()->SetDataType(
-                in_name_to_node[out->Name()]->Var()->GetDataType());
-        }
-      }
       continue;
     }
@@ -585,65 +512,75 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     //   - cast weight to fp16/bf16.
     //   - add cast op if the input dtype is not fp16/bf16.
     //   - set output dtype.
-    // If a var(op's out var) appears multiple times in graph, we should not
-    // convert to fp16.
-    else if (black_list_.count(op_type) == 0 &&  // NOLINT
-             !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
+    //
+    else if (black_list_.count(op_type) == 0) {  // NOLINT
       bool support_precision =
           OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);

-      // If the op has no input of float type, we will not choose the
-      // low precision kernel.
-      {
-        bool has_float_input{false};
-        for (auto* in_node : op_node->inputs) {
-          if (!in_node->IsVar()) continue;
-          auto* real_node = GetRealVarNode(block_idx, in_node);
-          if (real_node->Var()->GetDataType() == VarType::FP16 ||
-              real_node->Var()->GetDataType() == VarType::FP32 ||
-              real_node->Var()->GetDataType() == VarType::FP64 ||
-              real_node->Var()->GetDataType() == VarType::BF16) {
-            has_float_input = true;
-            break;
-          }
-        }
-        if (!has_float_input) {
-          support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
-        }
-      }
+      // If op's output in circle, we should not convert to fp16.
+      for (auto* out_node : op_node->outputs) {
+        if (var_names_in_circles_.count(out_node->Name())) {
+          support_precision = false;
+          VLOG(2) << " op's output " << out_node->Name()
+                  << " is in circle, we can not support this case, just skip.";
+          break;
+        }
+      }
+
+      // If the op has no input or output of float type, we will not choose the
+      // low precision kernel.
+      if (support_precision) {
+        bool has_float_in_out{false};
+        for (auto* in_node : op_node->inputs) {
+          if (!in_node->IsVar()) continue;
+          if (in_node->Var()->GetType() != VarType::LOD_TENSOR) {
+            support_precision = false;
+            VLOG(2) << " op has tensor array input[" << in_node->Name()
+                    << "], just skip.";
+            break;
+          }
+          auto* real_node = GetRealVarNode(in_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_in_out = true;
+            break;
+          }
+        }
+        for (auto* out_node : op_node->outputs) {
+          if (!out_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(out_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_in_out = true;
+            break;
+          }
+        }
+
+        if (!has_float_in_out) {
+          support_precision = false;
+          VLOG(2) << " op doesn't has float input and output, just skip.";
+        }
+      }

       VLOG(2) << "op type: " << op_type
               << " support low precision: " << support_precision;

       if (support_precision) {
+        ProcessConstantOpAttr(op_node, VarType::FP32, to_type);
         VLOG(2) << " process input nodes:";
         ++num_low_precision;
         auto inputs = op_node->inputs;
+
+        // Just for paddle's terriable case: op's input and output has the same
+        // name.
+        std::unordered_map<std::string, std::string> names_map;
+        for (auto* out_node : op_node->outputs) {
+          for (auto* in_node : op_node->inputs) {
+            if (out_node->Name() == in_node->Name()) {
+              names_map[out_node->Name()] = in_node->Name();
+            }
+          }
+        }
+
+        // Process inputs.
         for (auto* in_node : inputs) {
           ProcessInputNode(
               true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
+          if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
+            names_map[in_node->Name()] = cast_map_[in_node]->Name();
+          }
         }

         VLOG(2) << " process output nodes:";
-        for (auto* out_node : op_node->outputs) {
+        // Process outputs.
+        auto outputs = op_node->outputs;
+        for (auto* out_node : outputs) {
           ProcessOutputNode(block_idx, out_node, to_type);
         }
       } else {
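The `names_map` bookkeeping handles "paddle's terriable case" named in the comment above: when an op's input and output share a name and `ProcessInputNode` reroutes that input through a cast var, the same-named output must follow the new name. A toy sketch of the remapping (the names and the `.cast_fp16.0` suffix are invented; Paddle generates its own suffixes via `suffix_`):

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> inputs{"cache_kv", "x"}, outputs{"cache_kv", "y"};
  // Record output names that collide with an input name.
  std::map<std::string, std::string> names_map;
  for (const auto& out : outputs)
    for (const auto& in : inputs)
      if (out == in) names_map[out] = in;
  // Suppose a cast var replaced the input; the colliding output follows it.
  names_map["cache_kv"] = "cache_kv.cast_fp16.0";
  for (const auto& kv : names_map)
    std::cout << kv.first << " now reads " << kv.second << "\n";
}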
@@ -663,8 +600,10 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     // 3. check op not support fp16/bf16 or in blacklist.
     //    - add cast op if the input dtype is not fp32.
     else {  // NOLINT
-      VLOG(3) << "not to run fp16 op_type: " << op_type;
-      for (auto* in_node : op_node->inputs) {
+      VLOG(3) << "not to run fp16 op_type: " << op_type << ", node input size "
+              << op_node->inputs.size();
+      auto in_nodes = op_node->inputs;
+      for (auto* in_node : in_nodes) {
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
@@ -716,21 +655,6 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     }
   }

-  for (auto* node : graph->Nodes()) {
-    if (!node->IsVar()) continue;
-    auto* real_node = GetRealVarNode(block_idx, node);
-    if (!VarNodeHasDtype(real_node)) continue;
-
-    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
-            block_idx &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).first ==
-            VarType::Type()) {
-      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
-          real_node->Var()->GetDataType();
-    }
-  }
-
   if (num_low_precision)
     LOG(INFO) << "--- detected " << num_low_precision
               << " low precision ops in " << block_idx << " subgraph";
@@ -738,6 +662,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
 // We modify op's input output precision, and we need to fix cast op in_dtype
 // and out_dtype attribute.
+// TODO(inference): we need a cast elimination pass.
 void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
   auto op_nodes = framework::ir::TopologySortOperations(*graph);
   for (auto* op_node : op_nodes) {
@@ -766,7 +691,8 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
     if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
           node->Var()->GetDataType() == VarType::FP32) {
-        VLOG(2) << "weights keep to fp32: " << node->Name();
+        VLOG(2) << "weights keep to fp32: " << node->Name() << ", ptr "
+                << reinterpret_cast<void*>(node->Var());
         weights_should_be_fp32.insert(node->Name());
       }
     }
@@ -808,7 +734,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   std::ostringstream os;
   phi::CPUContext ctx;
   for (const auto& param : parameters) {
-    VLOG(3) << "Serialize param: " << param;
     PADDLE_ENFORCE_NOT_NULL(
         scope_.FindVar(param),
         platform::errors::NotFound(
@@ -829,21 +754,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
       mixed_program_desc.Proto()->SerializeAsString());
   StrToBinary(mixed_params_file_, SerializeParams());
 }

+void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
+  for (auto* graph : graphes_) {
+    for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
+      if (op_node->Name() == "fused_multi_transformer") {
+        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
+        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
+        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
+        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
+          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace

 void AddCastOp(
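`PatchForStrangeOp` is the companion trick to the circle handling: `fused_multi_transformer` updates CacheKV in place, so after conversion its `CacheKVOut` argument names are renamed back to the matching `CacheKV` input names, restoring input-name == output-name. A toy version of that pairwise rename (the `.tmp` suffix is invented for illustration):

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> cache_kv_in{"cache_kv_0", "cache_kv_1"};
  std::vector<std::string> cache_kv_out{"cache_kv_0.tmp", "cache_kv_1.tmp"};
  assert(cache_kv_in.size() == cache_kv_out.size());
  // RenameOutput equivalent: the i-th output takes the i-th input's name.
  for (size_t i = 0; i < cache_kv_in.size(); ++i)
    cache_kv_out[i] = cache_kv_in[i];
  for (const auto& name : cache_kv_out) std::cout << name << "\n";
}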
@@ -893,6 +803,7 @@ void AddCastOp(
   }
   next_op->Op()->Rename(node->Name(), map->at(node)->Name());
   IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
+  IR_NODE_UNLINK(node, next_op);
   IR_NODE_LINK_TO(map->at(node), next_op);
 }
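The added `IR_NODE_UNLINK(node, next_op)` removes the stale direct edge from the original var to its consumer once the cast chain is spliced in; previously the consumer kept both the old edge and the cast edge. A small sketch of the splice on an adjacency map (plain C++, not Paddle's IR macros; node names are invented):

#include <iostream>
#include <map>
#include <set>
#include <string>

int main() {
  // var originally feeds next_op directly.
  std::map<std::string, std::set<std::string>> edges{{"var", {"next_op"}}};
  edges["var"].insert("cast_op");       // IR_NODE_LINK_TO(node, cast input)
  edges["var"].erase("next_op");        // the new IR_NODE_UNLINK(node, next_op)
  edges["cast_var"].insert("next_op");  // IR_NODE_LINK_TO(cast out, next_op)
  for (const auto& e : edges)
    for (const auto& to : e.second)
      std::cout << e.first << " -> " << to << "\n";
}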