Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d373f4ff
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d373f4ff
编写于
7月 21, 2022
作者:
W
Wilber
提交者:
GitHub
7月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix some convert error found in tipc. (#44457)
* fix some error found in tipc. * update
上级
37455714
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
177 addition
and
13 deletion
+177
-13
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
...d/inference/analysis/passes/convert_to_mixed_precision.cc
+177
-13
未找到文件。
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
浏览文件 @
d373f4ff
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include <string>
#include <unordered_set>
#include <unordered_set>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/block_desc.h"
...
@@ -22,6 +23,7 @@
...
@@ -22,6 +23,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
...
@@ -89,6 +91,31 @@ bool OutShouldNotConvert(ir::Node* var_node) {
...
@@ -89,6 +91,31 @@ bool OutShouldNotConvert(ir::Node* var_node) {
return
false
;
return
false
;
}
}
// Get weight names which appear in multiple block (block 0 and block n).
std
::
unordered_set
<
std
::
string
>
GetMultiBlockPersistableNames
(
framework
::
ProgramDesc
*
program_desc
)
{
std
::
unordered_set
<
std
::
string
>
special_weights
;
size_t
block_size
=
program_desc
->
Size
();
std
::
unordered_set
<
std
::
string
>
block_0_weights
;
for
(
auto
var
:
program_desc
->
Block
(
0
).
AllVars
())
{
if
(
var
->
Persistable
())
block_0_weights
.
insert
(
var
->
Name
());
}
for
(
size_t
i
=
1
;
i
<
block_size
;
++
i
)
{
// std::cout << program_desc->MutableBlock(i)->Proto()->DebugString() <<
// std::endl;;
auto
all_ops
=
program_desc
->
Block
(
i
).
AllOps
();
for
(
auto
op
:
all_ops
)
{
for
(
auto
name
:
op
->
InputArgumentNames
())
{
if
(
block_0_weights
.
count
(
name
))
special_weights
.
insert
(
name
);
}
}
}
return
special_weights
;
}
// Just process special cases for weights conversion.
// Just process special cases for weights conversion.
bool
WeightsShouldNotConvert
(
ir
::
Node
*
var_node
)
{
bool
WeightsShouldNotConvert
(
ir
::
Node
*
var_node
)
{
auto
op_nodes
=
var_node
->
outputs
;
auto
op_nodes
=
var_node
->
outputs
;
...
@@ -116,19 +143,139 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
...
@@ -116,19 +143,139 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
}
}
}
}
// If cur_op's next is condition_flow op, then cur op should be fp32. Note, we
// now only convert to mixed in block 0.
for
(
auto
*
op_node
:
op_nodes
)
{
for
(
auto
var
:
op_node
->
outputs
)
{
for
(
auto
next_op
:
var
->
outputs
)
{
if
(
next_op
->
Op
()
->
HasAttr
(
"sub_block"
))
{
return
true
;
}
}
}
}
return
false
;
return
false
;
}
}
inline
bool
IsFloatVarType
(
framework
::
proto
::
VarType
::
Type
type
)
{
inline
bool
IsFloatVarType
(
framework
::
proto
::
VarType
::
Type
type
)
{
if
(
type
==
framework
::
proto
::
VarType
::
FP16
||
if
(
type
==
framework
::
proto
::
VarType
::
FP16
||
type
==
framework
::
proto
::
VarType
::
FP32
||
type
==
framework
::
proto
::
VarType
::
FP32
||
type
==
framework
::
proto
::
VarType
::
BF16
||
type
==
framework
::
proto
::
VarType
::
BF16
)
type
==
framework
::
proto
::
VarType
::
FP64
)
return
true
;
return
true
;
return
false
;
return
false
;
}
}
void
ConvertTensorDtype
(
framework
::
ir
::
Graph
*
graph
,
void
ConvertAllFp64ToFp32
(
framework
::
ir
::
Graph
*
graph
)
{
auto
op_nodes
=
framework
::
ir
::
TopologySortOperations
(
*
graph
);
for
(
auto
*
op_node
:
op_nodes
)
{
if
(
!
op_node
->
IsOp
())
continue
;
auto
op_type
=
op_node
->
Op
()
->
Type
();
if
(
op_type
==
"feed"
||
op_type
==
"fetch"
)
continue
;
if
(
op_type
==
"fill_constant"
)
{
if
(
PADDLE_GET_CONST
(
int
,
op_node
->
Op
()
->
GetAttr
(
"dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP64
))
op_node
->
Op
()
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
));
}
else
if
(
op_type
==
"assign_value"
)
{
if
(
PADDLE_GET_CONST
(
int
,
op_node
->
Op
()
->
GetAttr
(
"dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP64
))
op_node
->
Op
()
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
));
}
else
if
(
op_type
==
"eye"
)
{
if
(
PADDLE_GET_CONST
(
int
,
op_node
->
Op
()
->
GetAttr
(
"dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP64
))
op_node
->
Op
()
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
));
}
else
if
(
op_type
==
"fill_any_like"
)
{
if
(
PADDLE_GET_CONST
(
int
,
op_node
->
Op
()
->
GetAttr
(
"dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP64
))
op_node
->
Op
()
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
));
}
else
if
(
op_type
==
"cast"
)
{
if
(
PADDLE_GET_CONST
(
int
,
op_node
->
Op
()
->
GetAttr
(
"in_dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP64
))
op_node
->
Op
()
->
SetAttr
(
"in_dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
));
if
(
PADDLE_GET_CONST
(
int
,
op_node
->
Op
()
->
GetAttr
(
"out_dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP64
))
op_node
->
Op
()
->
SetAttr
(
"out_dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
));
}
auto
inputs
=
op_node
->
inputs
;
for
(
auto
*
in_node
:
inputs
)
{
if
(
in_node
->
IsCtrlVar
())
continue
;
auto
*
in_var
=
in_node
->
Var
();
if
(
!
in_var
->
Persistable
()
&&
in_var
->
GetDataType
()
==
framework
::
proto
::
VarType
::
FP64
)
{
in_var
->
SetDataType
(
framework
::
proto
::
VarType
::
FP32
);
}
}
}
}
// Handle special ops which contains dtype attribute. e.g., fill_constant,
// assign_value.
void
HandleSpecialOps
(
framework
::
OpDesc
*
op_desc
)
{
if
(
op_desc
->
Type
()
==
"fill_constant"
)
{
if
(
PADDLE_GET_CONST
(
int
,
op_desc
->
GetAttr
(
"dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
))
op_desc
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP16
));
}
else
if
(
op_desc
->
Type
()
==
"assign_value"
)
{
if
(
PADDLE_GET_CONST
(
int
,
op_desc
->
GetAttr
(
"dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
))
op_desc
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP16
));
}
else
if
(
op_desc
->
Type
()
==
"eye"
)
{
if
(
PADDLE_GET_CONST
(
int
,
op_desc
->
GetAttr
(
"dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
))
op_desc
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP16
));
}
else
if
(
op_desc
->
Type
()
==
"fill_any_like"
)
{
if
(
PADDLE_GET_CONST
(
int
,
op_desc
->
GetAttr
(
"dtype"
))
==
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP32
))
op_desc
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
framework
::
proto
::
VarType
::
FP16
));
}
}
// We modify op's input output precision, and we need to fix cast op in_dtype
// and out_dtype attribute.
void
FixCastAttr
(
framework
::
ir
::
Graph
*
graph
)
{
auto
op_nodes
=
framework
::
ir
::
TopologySortOperations
(
*
graph
);
for
(
auto
*
op_node
:
op_nodes
)
{
if
(
!
op_node
->
IsOp
())
continue
;
auto
op_type
=
op_node
->
Op
()
->
Type
();
if
(
op_type
!=
"cast"
)
continue
;
auto
input
=
op_node
->
inputs
[
0
];
auto
output
=
op_node
->
outputs
[
0
];
op_node
->
Op
()
->
SetAttr
(
"in_dtype"
,
static_cast
<
int
>
(
input
->
Var
()
->
GetDataType
()));
op_node
->
Op
()
->
SetAttr
(
"out_dtype"
,
static_cast
<
int
>
(
output
->
Var
()
->
GetDataType
()));
}
}
// If op's output var is condition flow op's input, then the op must be fp32
// precision.
bool
NextOpIncludesConditionFlowOp
(
framework
::
ir
::
Node
*
cur_op_node
)
{
auto
cur_op_outs
=
cur_op_node
->
outputs
;
for
(
auto
out_var
:
cur_op_outs
)
{
for
(
auto
next_op_node
:
out_var
->
outputs
)
{
if
(
next_op_node
->
Op
()
->
HasAttr
(
"sub_block"
))
{
return
true
;
}
}
}
return
false
;
}
void
ConvertTensorDtype
(
framework
::
ProgramDesc
*
program_desc
,
framework
::
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
blacklist
,
const
std
::
unordered_set
<
std
::
string
>&
blacklist
,
bool
keep_io_types
,
bool
keep_io_types
,
phi
::
Backend
backend
,
phi
::
Backend
backend
,
...
@@ -145,13 +292,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
...
@@ -145,13 +292,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
static_cast
<
int
>
(
tensor_dtype
)));
static_cast
<
int
>
(
tensor_dtype
)));
}
}
auto
weight_name_in_multi_block
=
GetMultiBlockPersistableNames
(
program_desc
);
int
num_low_precision
=
0
;
int
num_low_precision
=
0
;
int
suffix
=
0
;
int
suffix
=
0
;
framework
::
BlockDesc
*
block_desc
{
nullptr
};
framework
::
BlockDesc
*
block_desc
{
nullptr
};
std
::
vector
<
framework
::
ir
::
Node
*>
output_nodes
;
std
::
vector
<
framework
::
ir
::
Node
*>
output_nodes
;
std
::
unordered_map
<
framework
::
ir
::
Node
*
,
framework
::
ir
::
Node
*>
cast_map
;
std
::
unordered_map
<
framework
::
ir
::
Node
*
,
framework
::
ir
::
Node
*>
cast_map
;
auto
op_nodes
=
framework
::
ir
::
TopologySortOperations
(
*
graph
);
for
(
auto
*
op_node
:
framework
::
ir
::
TopologySortOperations
(
*
graph
)
)
{
for
(
auto
*
op_node
:
op_nodes
)
{
if
(
!
op_node
->
IsOp
())
continue
;
if
(
!
op_node
->
IsOp
())
continue
;
auto
op_type
=
op_node
->
Op
()
->
Type
();
auto
op_type
=
op_node
->
Op
()
->
Type
();
auto
phi_op_type
=
phi
::
TransToPhiKernelName
(
op_type
);
auto
phi_op_type
=
phi
::
TransToPhiKernelName
(
op_type
);
...
@@ -167,18 +315,29 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
...
@@ -167,18 +315,29 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
auto
*
fetch_var
=
op_node
->
inputs
[
0
];
auto
*
fetch_var
=
op_node
->
inputs
[
0
];
output_nodes
.
push_back
(
fetch_var
);
output_nodes
.
push_back
(
fetch_var
);
continue
;
continue
;
}
else
if
(
op_type
==
"cast"
)
{
continue
;
}
}
// 2. if op support fp16/bf16 and not in blacklist.
// 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16.
// - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16.
// - set output dtype.
// - set output dtype.
else
if
(
blacklist
.
count
(
phi_op_type
)
==
0
)
{
// NOLINT
else
if
(
blacklist
.
count
(
phi_op_type
)
==
0
&&
// NOLINT
!
NextOpIncludesConditionFlowOp
(
op_node
))
{
bool
support_precision
=
bool
support_precision
=
OpSupportPrecision
(
phi_op_type
,
backend
,
tensor_dtype
,
blacklist
);
OpSupportPrecision
(
phi_op_type
,
backend
,
tensor_dtype
,
blacklist
);
VLOG
(
2
)
<<
"phi_op_type "
<<
phi_op_type
<<
" support low precision "
VLOG
(
2
)
<<
"op_type "
<<
op_type
<<
", phi_op_type "
<<
phi_op_type
<<
support_precision
;
<<
" support low precision "
<<
support_precision
<<
", "
<<
reinterpret_cast
<
void
*>
(
op_node
->
Op
()
->
Block
());
for
(
auto
in_node
:
op_node
->
inputs
)
{
if
(
weight_name_in_multi_block
.
count
(
in_node
->
Name
()))
support_precision
=
false
;
}
if
(
support_precision
)
{
if
(
support_precision
)
{
HandleSpecialOps
(
op_node
->
Op
());
++
num_low_precision
;
++
num_low_precision
;
auto
inputs
=
op_node
->
inputs
;
auto
inputs
=
op_node
->
inputs
;
for
(
auto
*
in_node
:
inputs
)
{
for
(
auto
*
in_node
:
inputs
)
{
...
@@ -232,9 +391,8 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
...
@@ -232,9 +391,8 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// 3. check op not support fp16/bf16 or in blacklist.
// 3. check op not support fp16/bf16 or in blacklist.
// - add cast op if the input dtype is not fp32.
// - add cast op if the input dtype is not fp32.
else
{
// NOLINT
else
{
// NOLINT
// trt pass should explicitle add cast op is input is bf16/tf32, etc.
auto
ins
=
op_node
->
inputs
;
if
(
op_node
->
Name
()
==
"tensorrt_engine"
)
continue
;
for
(
auto
*
in_node
:
ins
)
{
for
(
auto
*
in_node
:
op_node
->
inputs
)
{
if
(
in_node
->
IsCtrlVar
())
continue
;
if
(
in_node
->
IsCtrlVar
())
continue
;
auto
*
in_var
=
in_node
->
Var
();
auto
*
in_var
=
in_node
->
Var
();
if
(
in_var
->
GetDataType
()
==
to_type
)
{
if
(
in_var
->
GetDataType
()
==
to_type
)
{
...
@@ -366,8 +524,14 @@ void ConvertToMixedPrecision(const std::string& model_file,
...
@@ -366,8 +524,14 @@ void ConvertToMixedPrecision(const std::string& model_file,
auto
graph
=
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
(
auto
graph
=
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
(
new
framework
::
ir
::
Graph
(
*
program_desc
));
new
framework
::
ir
::
Graph
(
*
program_desc
));
ConvertTensorDtype
(
ConvertAllFp64ToFp32
(
graph
.
get
());
graph
.
get
(),
black_list
,
keep_io_types
,
backend
,
mixed_precision
);
ConvertTensorDtype
(
program_desc
.
get
(),
graph
.
get
(),
black_list
,
keep_io_types
,
backend
,
mixed_precision
);
FixCastAttr
(
graph
.
get
());
framework
::
ProgramDesc
mixed_program_desc
;
framework
::
ProgramDesc
mixed_program_desc
;
framework
::
ir
::
GraphToProgram
(
*
graph
,
&
mixed_program_desc
);
framework
::
ir
::
GraphToProgram
(
*
graph
,
&
mixed_program_desc
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录