Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d373f4ff
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d373f4ff
编写于
7月 21, 2022
作者:
W
Wilber
提交者:
GitHub
7月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix some convert error found in tipc. (#44457)
* fix some error found in tipc. * update
上级
37455714
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
177 addition
and
13 deletion
+177
-13
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
...d/inference/analysis/passes/convert_to_mixed_precision.cc
+177
-13
未找到文件。
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
浏览文件 @
d373f4ff
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/block_desc.h"
...
...
@@ -22,6 +23,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
...
...
@@ -89,6 +91,31 @@ bool OutShouldNotConvert(ir::Node* var_node) {
return
false
;
}
// Collects the names of persistable variables (weights) declared in block 0
// that are also read as an op input in some other block (block 1..n).
//
// program_desc: the program to inspect; only read through Size(), Block(),
//               AllVars() and AllOps().
// Returns: the set of such variable names. The set is empty when the program
//          has only one block, because the sub-block scan starts at index 1.
std::unordered_set<std::string> GetMultiBlockPersistableNames(
    framework::ProgramDesc* program_desc) {
  std::unordered_set<std::string> special_weights;
  const size_t block_size = program_desc->Size();

  // Persistable variables of the root block are the candidate weights.
  std::unordered_set<std::string> block_0_weights;
  for (const auto* var : program_desc->Block(0).AllVars()) {
    if (var->Persistable()) block_0_weights.insert(var->Name());
  }

  // Any candidate that shows up as an op input outside block 0 is shared
  // between blocks.
  for (size_t i = 1; i < block_size; ++i) {
    for (const auto* op : program_desc->Block(i).AllOps()) {
      // Bind by const reference: the names are only looked up and inserted,
      // so copying each std::string per iteration is unnecessary.
      for (const auto& name : op->InputArgumentNames()) {
        if (block_0_weights.count(name)) special_weights.insert(name);
      }
    }
  }
  return special_weights;
}
// Just process special cases for weights conversion.
bool
WeightsShouldNotConvert
(
ir
::
Node
*
var_node
)
{
auto
op_nodes
=
var_node
->
outputs
;
...
...
@@ -116,19 +143,139 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
}
}
// If cur_op's next is condition_flow op, then cur op should be fp32. Note, we
// now only convert to mixed in block 0.
for
(
auto
*
op_node
:
op_nodes
)
{
for
(
auto
var
:
op_node
->
outputs
)
{
for
(
auto
next_op
:
var
->
outputs
)
{
if
(
next_op
->
Op
()
->
HasAttr
(
"sub_block"
))
{
return
true
;
}
}
}
}
return
false
;
}
// Returns true when `type` is FP16, FP32 or BF16, and false for every other
// var type.
//
// NOTE(review): FP64 is intentionally not listed here; this file's
// ConvertAllFp64ToFp32 rewrites FP64 to FP32 separately -- confirm that this
// matches the intended post-change condition.
inline bool IsFloatVarType(framework::proto::VarType::Type type) {
  return type == framework::proto::VarType::FP16 ||
         type == framework::proto::VarType::FP32 ||
         type == framework::proto::VarType::BF16;
}
void
ConvertTensorDtype
(
framework
::
ir
::
Graph
*
graph
,
// Rewrites FP64 to FP32 for every op returned by TopologySortOperations:
//  - the "dtype" attribute of fill_constant / assign_value / eye /
//    fill_any_like ops,
//  - the "in_dtype" and "out_dtype" attributes of cast ops,
//  - the data type of each input variable that is neither a control var nor
//    persistable.
// feed and fetch ops are skipped entirely (attributes and inputs alike).
//
// graph: graph whose op attributes and variable descs are modified in place.
void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
  auto op_nodes = framework::ir::TopologySortOperations(*graph);
  for (auto* op_node : op_nodes) {
    if (!op_node->IsOp()) continue;

    auto* op_desc = op_node->Op();
    const std::string& op_type = op_desc->Type();
    if (op_type == "feed" || op_type == "fetch") continue;

    // Switches one integer dtype attribute from FP64 to FP32; any other
    // value is left untouched.
    auto demote_fp64_attr = [op_desc](const char* attr_name) {
      if (PADDLE_GET_CONST(int, op_desc->GetAttr(attr_name)) ==
          static_cast<int>(framework::proto::VarType::FP64)) {
        op_desc->SetAttr(attr_name,
                         static_cast<int>(framework::proto::VarType::FP32));
      }
    };

    const bool has_dtype_attr =
        op_type == "fill_constant" || op_type == "assign_value" ||
        op_type == "eye" || op_type == "fill_any_like";
    if (has_dtype_attr) {
      demote_fp64_attr("dtype");
    } else if (op_type == "cast") {
      demote_fp64_attr("in_dtype");
      demote_fp64_attr("out_dtype");
    }

    // Only the variable descs are edited below, never the node list itself,
    // so the inputs can be walked in place without copying the vector.
    for (auto* in_node : op_node->inputs) {
      if (in_node->IsCtrlVar()) continue;
      auto* in_var = in_node->Var();
      if (!in_var->Persistable() &&
          in_var->GetDataType() == framework::proto::VarType::FP64) {
        in_var->SetDataType(framework::proto::VarType::FP32);
      }
    }
  }
}
// Handle special ops which contain a dtype attribute: fill_constant,
// assign_value, eye and fill_any_like.
//
// When such an op has "dtype" == FP32, the attribute is rewritten to
// `low_precision_type`. Ops of any other type, and ops whose "dtype" is not
// FP32, are left unchanged.
//
// op_desc:            the op description to update in place.
// low_precision_type: the type written into "dtype". Defaults to FP16, which
//                     is what this function previously hard-coded, so
//                     existing one-argument calls behave as before. Callers
//                     converting to another low-precision type (e.g. BF16)
//                     can pass it explicitly.
void HandleSpecialOps(
    framework::OpDesc* op_desc,
    framework::proto::VarType::Type low_precision_type =
        framework::proto::VarType::FP16) {
  const std::string& op_type = op_desc->Type();
  const bool has_dtype_attr =
      op_type == "fill_constant" || op_type == "assign_value" ||
      op_type == "eye" || op_type == "fill_any_like";
  if (!has_dtype_attr) return;

  if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
      static_cast<int>(framework::proto::VarType::FP32)) {
    op_desc->SetAttr("dtype", static_cast<int>(low_precision_type));
  }
}
// We modify op's input/output precision, so the in_dtype and out_dtype
// attributes of every cast op have to be brought back in sync with the data
// types of the variables actually connected to it.
//
// For each cast op returned by TopologySortOperations, "in_dtype" is set from
// its first non-control input variable and "out_dtype" from its first
// non-control output variable. A cast op without such an input or output is
// skipped instead of indexing element 0 unconditionally.
//
// graph: graph whose cast op attributes are updated in place.
void FixCastAttr(framework::ir::Graph* graph) {
  auto op_nodes = framework::ir::TopologySortOperations(*graph);
  for (auto* op_node : op_nodes) {
    if (!op_node->IsOp()) continue;
    if (op_node->Op()->Type() != "cast") continue;

    // Skip control vars: elsewhere in this file Var() is only called after
    // an IsCtrlVar() check, so do the same here.
    auto in_it = op_node->inputs.begin();
    while (in_it != op_node->inputs.end() && (*in_it)->IsCtrlVar()) ++in_it;
    auto out_it = op_node->outputs.begin();
    while (out_it != op_node->outputs.end() && (*out_it)->IsCtrlVar()) ++out_it;
    if (in_it == op_node->inputs.end() || out_it == op_node->outputs.end()) {
      continue;
    }

    op_node->Op()->SetAttr(
        "in_dtype", static_cast<int>((*in_it)->Var()->GetDataType()));
    op_node->Op()->SetAttr(
        "out_dtype", static_cast<int>((*out_it)->Var()->GetDataType()));
  }
}
// Reports whether any op that consumes one of cur_op_node's output variables
// carries a "sub_block" attribute (i.e. is a condition-flow op). Such a
// producer op must stay in fp32 precision.
//
// cur_op_node: the op node whose direct successors are inspected.
// Returns: true as soon as one successor with "sub_block" is found,
//          false when there is none.
bool NextOpIncludesConditionFlowOp(framework::ir::Node* cur_op_node) {
  for (auto* produced_var : cur_op_node->outputs) {
    for (auto* consumer_op : produced_var->outputs) {
      const bool owns_sub_block = consumer_op->Op()->HasAttr("sub_block");
      if (owns_sub_block) return true;
    }
  }
  return false;
}
void
ConvertTensorDtype
(
framework
::
ProgramDesc
*
program_desc
,
framework
::
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
blacklist
,
bool
keep_io_types
,
phi
::
Backend
backend
,
...
...
@@ -145,13 +292,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
static_cast
<
int
>
(
tensor_dtype
)));
}
auto
weight_name_in_multi_block
=
GetMultiBlockPersistableNames
(
program_desc
);
int
num_low_precision
=
0
;
int
suffix
=
0
;
framework
::
BlockDesc
*
block_desc
{
nullptr
};
std
::
vector
<
framework
::
ir
::
Node
*>
output_nodes
;
std
::
unordered_map
<
framework
::
ir
::
Node
*
,
framework
::
ir
::
Node
*>
cast_map
;
for
(
auto
*
op_node
:
framework
::
ir
::
TopologySortOperations
(
*
graph
)
)
{
auto
op_nodes
=
framework
::
ir
::
TopologySortOperations
(
*
graph
);
for
(
auto
*
op_node
:
op_nodes
)
{
if
(
!
op_node
->
IsOp
())
continue
;
auto
op_type
=
op_node
->
Op
()
->
Type
();
auto
phi_op_type
=
phi
::
TransToPhiKernelName
(
op_type
);
...
...
@@ -167,18 +315,29 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
auto
*
fetch_var
=
op_node
->
inputs
[
0
];
output_nodes
.
push_back
(
fetch_var
);
continue
;
}
else
if
(
op_type
==
"cast"
)
{
continue
;
}
// 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16.
// - set output dtype.
else
if
(
blacklist
.
count
(
phi_op_type
)
==
0
)
{
// NOLINT
else
if
(
blacklist
.
count
(
phi_op_type
)
==
0
&&
// NOLINT
!
NextOpIncludesConditionFlowOp
(
op_node
))
{
bool
support_precision
=
OpSupportPrecision
(
phi_op_type
,
backend
,
tensor_dtype
,
blacklist
);
VLOG
(
2
)
<<
"phi_op_type "
<<
phi_op_type
<<
" support low precision "
<<
support_precision
;
VLOG
(
2
)
<<
"op_type "
<<
op_type
<<
", phi_op_type "
<<
phi_op_type
<<
" support low precision "
<<
support_precision
<<
", "
<<
reinterpret_cast
<
void
*>
(
op_node
->
Op
()
->
Block
());
for
(
auto
in_node
:
op_node
->
inputs
)
{
if
(
weight_name_in_multi_block
.
count
(
in_node
->
Name
()))
support_precision
=
false
;
}
if
(
support_precision
)
{
HandleSpecialOps
(
op_node
->
Op
());
++
num_low_precision
;
auto
inputs
=
op_node
->
inputs
;
for
(
auto
*
in_node
:
inputs
)
{
...
...
@@ -232,9 +391,8 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// 3. check op not support fp16/bf16 or in blacklist.
// - add cast op if the input dtype is not fp32.
else
{
// NOLINT
// trt pass should explicitly add cast op if input is bf16/tf32, etc.
if
(
op_node
->
Name
()
==
"tensorrt_engine"
)
continue
;
for
(
auto
*
in_node
:
op_node
->
inputs
)
{
auto
ins
=
op_node
->
inputs
;
for
(
auto
*
in_node
:
ins
)
{
if
(
in_node
->
IsCtrlVar
())
continue
;
auto
*
in_var
=
in_node
->
Var
();
if
(
in_var
->
GetDataType
()
==
to_type
)
{
...
...
@@ -366,8 +524,14 @@ void ConvertToMixedPrecision(const std::string& model_file,
auto
graph
=
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
(
new
framework
::
ir
::
Graph
(
*
program_desc
));
ConvertTensorDtype
(
graph
.
get
(),
black_list
,
keep_io_types
,
backend
,
mixed_precision
);
ConvertAllFp64ToFp32
(
graph
.
get
());
ConvertTensorDtype
(
program_desc
.
get
(),
graph
.
get
(),
black_list
,
keep_io_types
,
backend
,
mixed_precision
);
FixCastAttr
(
graph
.
get
());
framework
::
ProgramDesc
mixed_program_desc
;
framework
::
ir
::
GraphToProgram
(
*
graph
,
&
mixed_program_desc
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录