Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cb4eea92
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cb4eea92
编写于
7月 13, 2022
作者:
W
Wilber
提交者:
GitHub
7月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix convert error. (#44307)
上级
5a312fb9
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
17 addition
and
4 deletion
+17
-4
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
...d/inference/analysis/passes/convert_to_mixed_precision.cc
+17
-4
未找到文件。
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
浏览文件 @
cb4eea92
...
...
@@ -119,6 +119,15 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
return
false
;
}
inline
bool
IsFloatVarType
(
framework
::
proto
::
VarType
::
Type
type
)
{
if
(
type
==
framework
::
proto
::
VarType
::
FP16
||
type
==
framework
::
proto
::
VarType
::
FP32
||
type
==
framework
::
proto
::
VarType
::
BF16
||
type
==
framework
::
proto
::
VarType
::
FP64
)
return
true
;
return
false
;
}
void
ConvertTensorDtype
(
framework
::
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
blacklist
,
bool
keep_io_types
,
...
...
@@ -146,8 +155,6 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
if
(
!
op_node
->
IsOp
())
continue
;
auto
op_type
=
op_node
->
Op
()
->
Type
();
auto
phi_op_type
=
phi
::
TransToPhiKernelName
(
op_type
);
// LOG(INFO) << "process op " << op_type << ", corresponding phi type is "
// << phi_op_type;
// 1. set input dtype.
if
(
op_type
==
"feed"
)
{
block_desc
=
op_node
->
Op
()
->
Block
();
...
...
@@ -175,12 +182,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
++
num_low_precision
;
auto
inputs
=
op_node
->
inputs
;
for
(
auto
*
in_node
:
inputs
)
{
if
(
in_node
->
IsCtrlVar
())
continue
;
auto
*
in_var
=
in_node
->
Var
();
if
(
in_var
->
Persistable
()
&&
in_var
->
GetDataType
()
==
framework
::
proto
::
VarType
::
FP32
)
{
if
(
WeightsShouldNotConvert
(
in_node
))
continue
;
in_var
->
SetDataType
(
to_type
);
}
else
if
(
!
in_var
->
Persistable
()
&&
IsFloatVarType
(
in_var
->
GetDataType
())
&&
in_var
->
GetDataType
()
!=
to_type
)
{
AddCastOp
(
graph
,
in_node
,
...
...
@@ -193,6 +202,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
}
}
for
(
auto
*
out_node
:
op_node
->
outputs
)
{
if
(
out_node
->
IsCtrlVar
())
continue
;
auto
*
out_var
=
out_node
->
Var
();
if
(
out_var
->
GetDataType
()
==
framework
::
proto
::
VarType
::
FP32
)
{
if
(
OutShouldNotConvert
(
out_node
))
continue
;
...
...
@@ -202,8 +212,9 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
}
else
{
auto
inputs
=
op_node
->
inputs
;
for
(
auto
*
in_node
:
inputs
)
{
if
(
in_node
->
IsCtrlVar
())
continue
;
auto
*
in_var
=
in_node
->
Var
();
if
(
!
in_var
->
Persistable
()
&&
if
(
!
in_var
->
Persistable
()
&&
IsFloatVarType
(
in_var
->
GetDataType
())
&&
in_var
->
GetDataType
()
!=
framework
::
proto
::
VarType
::
FP32
)
{
AddCastOp
(
graph
,
in_node
,
...
...
@@ -224,6 +235,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// trt pass should explicitle add cast op is input is bf16/tf32, etc.
if
(
op_node
->
Name
()
==
"tensorrt_engine"
)
continue
;
for
(
auto
*
in_node
:
op_node
->
inputs
)
{
if
(
in_node
->
IsCtrlVar
())
continue
;
auto
*
in_var
=
in_node
->
Var
();
if
(
in_var
->
GetDataType
()
==
to_type
)
{
AddCastOp
(
graph
,
...
...
@@ -242,6 +254,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// 4. if output_op's dtype is not compatible to output dtype, then just insert
// cast.
for
(
auto
*
node
:
output_nodes
)
{
if
(
node
->
IsCtrlVar
())
continue
;
auto
var
=
node
->
Var
();
if
(
keep_io_types
&&
var
->
GetDataType
()
==
to_type
)
{
// fp16/bf16 -> fp32.
...
...
@@ -381,7 +394,7 @@ void ConvertToMixedPrecision(const std::string& model_file,
std
::
unordered_set
<
std
::
string
>
weights_should_be_fp32
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsVar
(
))
continue
;
if
(
!
(
node
->
IsVar
()
&&
!
node
->
IsCtrlVar
()
))
continue
;
if
(
node
->
Var
()
->
GetType
()
==
paddle
::
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
node
->
Var
()
->
GetType
()
==
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录