Commit de6f15b6 (unverified): reconstruct code for convert_fp16 (#46428) (#47087)
Authored Oct 18, 2022 by Wilber; committed via GitHub on Oct 18, 2022. Parent: 2cc8797e

3 changed files, with 505 additions and 527 deletions:
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc  +498 -493
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h  +1 -1
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc  +6 -33
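The public entry point, ConvertToMixedPrecision, keeps its signature in this refactor (see the final hunk of convert_to_mixed_precision.cc below); only its body changes, delegating to the new ConvertToMixedPrecisionPass class. A minimal usage sketch follows; the main() wrapper and the file paths are illustrative assumptions, not part of this commit.

    // Illustrative sketch only: convert a saved fp32 inference model to fp16.
    // Paths and the surrounding program are assumptions for demonstration.
    #include <string>
    #include <unordered_set>
    #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

    int main() {
      std::unordered_set<std::string> black_list;  // extra ops forced to stay fp32
      paddle::inference::analysis::ConvertToMixedPrecision(
          "model/inference.pdmodel",    // model_file
          "model/inference.pdiparams",  // params_file
          "mixed/inference.pdmodel",    // mixed_model_file
          "mixed/inference.pdiparams",  // mixed_params_file
          phi::DataType::FLOAT16,       // mixed_precision (fp16 or bf16)
          phi::Backend::GPU,            // backend used for kernel lookup
          true /* keep_io_types */,     // keep feed/fetch vars in fp32
          black_list);
      return 0;
    }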
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc (view file @ de6f15b6)
@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <iterator>
+#include <memory>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -31,9 +32,14 @@
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/inference/analysis/argument.h"
+#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
 #include "paddle/fluid/inference/io.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/common/layout.h"
+#include "paddle/phi/common/place.h"
 #include "paddle/phi/core/tensor_meta.h"

 using namespace paddle::framework;  // NOLINT
@@ -43,160 +49,6 @@ namespace inference {
 namespace analysis {
 namespace {

-inline std::string SerializeParams(framework::Scope* scope,
-                                   const std::vector<std::string>& params) {
-  std::ostringstream os;
-  phi::CPUContext ctx;
-  for (const auto& param : params) {
-    VLOG(3) << "Serialize param: " << param;
-    PADDLE_ENFORCE_NOT_NULL(
-        scope->FindVar(param),
-        platform::errors::NotFound("Block should already have a '%s' variable",
-                                   param));
-    auto* tensor = scope->FindVar(param)->GetMutable<framework::LoDTensor>();
-    framework::SerializeToStream(os, *tensor, ctx);
-  }
-  return os.str();
-}
-
-inline void StrToBinary(const std::string& path, const std::string& str) {
-  std::ofstream file(path.c_str(), std::ios::binary);
-  file.write(str.c_str(), str.size());
-  file.close();
-}
-
-inline bool NodeVarHasDtype(framework::ir::Node* node) {
-  if (node->IsCtrlVar()) return false;
-  if (node->IsVar() &&
-      (node->Var()->GetType() ==
-           paddle::framework::proto::VarType::SELECTED_ROWS ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
-    return true;
-  }
-  return false;
-}
-
-// Return Node* which first appers in block.
-framework::ir::Node* GetRealNode(
-    const std::vector<framework::ir::Graph*>& graphes,
-    int block_idx,
-    framework::ir::Node* node,
-    std::unordered_map<std::string,
-                       std::pair<framework::proto::VarType::Type, int>>*
-        vars_in_multi_block_map) {
-  if (vars_in_multi_block_map->count(node->Name())) {
-    int var_origin_block_id = vars_in_multi_block_map->at(node->Name()).second;
-    if (block_idx != var_origin_block_id) {
-      auto graph = graphes[var_origin_block_id];
-      for (auto nd : graph->Nodes()) {
-        if (nd->Name() == node->Name()) {
-          return nd;
-        }
-      }
-    }
-  }
-  return node;
-}
-
-inline bool VarIsMultiOpsOut(
-    const std::vector<framework::ir::Graph*>& graphes,
-    int block_idx,
-    framework::ir::Node* op_node,
-    std::unordered_map<std::string,
-                       std::pair<framework::proto::VarType::Type, int>>*
-        vars_in_multi_block_map,
-    const std::vector<std::set<std::string>>& vars_appear_multi_in_one_block) {
-  CHECK_EQ(op_node->IsOp(), true);
-  for (auto* out : op_node->outputs) {
-    if (out->IsCtrlVar()) continue;
-    auto* real_node =
-        GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
-    if (!real_node->Var()->Persistable() &&
-        vars_appear_multi_in_one_block[block_idx].count(out->Name())) {
-      VLOG(2) << out->Name()
-              << " is multi op's out, so we skip convert to fp16";
-      return true;
-    }
-  }
-  return false;
-}
-
-void SaveMixedModel(
-    framework::ir::Graph* graph,
-    framework::Scope* scope,
-    framework::ProgramDesc* mixed_program_desc,
-    const std::string& mixed_model_file,
-    const std::string& mixed_params_file,
-    phi::DataType mixed_precision,
-    const std::unordered_map<std::string,
-                             std::pair<framework::proto::VarType::Type, int>>&
-        vars_in_multi_block_map) {
-  paddle::CPUPlace place;
-  auto parameters = scope->LocalVarNames();
-  std::sort(parameters.begin(), parameters.end());
-
-  std::unordered_set<std::string> weights_should_be_fp32;
-  for (auto* node : graph->Nodes()) {
-    if (!(node->IsVar() && !node->IsCtrlVar())) continue;
-    if (NodeVarHasDtype(node)) {
-      if (node->Var()->Persistable() &&
-          node->Var()->GetDataType() ==
-              paddle::framework::proto::VarType::FP32) {
-        VLOG(2) << "weights keep to fp32: " << node->Name();
-        weights_should_be_fp32.insert(node->Name());
-      }
-    }
-  }
-
-  for (const auto& param_name : parameters) {
-    auto* var = scope->FindLocalVar(param_name);
-    if (var->IsType<framework::LoDTensor>() ||
-        var->IsType<framework::Tensor>()) {
-      auto* t = var->GetMutable<framework::LoDTensor>();
-      if (t->dtype() != phi::DataType::FLOAT32) continue;
-      framework::Tensor mixed_tensor;
-      mixed_tensor.Resize(t->dims());
-      auto* data = t->mutable_data<float>(platform::CPUPlace());
-      if (mixed_precision == phi::DataType::FLOAT16 &&
-          !weights_should_be_fp32.count(param_name)) {
-        mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
-        auto* mixed_data =
-            mixed_tensor.mutable_data<float16>(platform::CPUPlace());
-        for (int i = 0; i < t->numel(); i++) {
-          mixed_data[i] = static_cast<float16>(data[i]);
-        }
-        t->clear();
-        paddle::framework::TensorCopySync(mixed_tensor, place, t);
-      } else if (mixed_precision == phi::DataType::BFLOAT16 &&
-                 !weights_should_be_fp32.count(param_name)) {
-        mixed_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
-        auto* mixed_data =
-            mixed_tensor.mutable_data<bfloat16>(platform::CPUPlace());
-        for (int i = 0; i < t->numel(); i++) {
-          mixed_data[i] = static_cast<bfloat16>(data[i]);
-        }
-        t->clear();
-        paddle::framework::TensorCopySync(mixed_tensor, place, t);
-      }
-    }
-  }
-
-  StrToBinary(mixed_model_file,
-              mixed_program_desc->Proto()->SerializeAsString());
-  StrToBinary(mixed_params_file, SerializeParams(scope, parameters));
-}
-
 bool PhiKernelSupportPrecision(
     const std::string& op_type,
     phi::Backend backend,
@@ -235,8 +87,236 @@ bool GpuKernelSupportPrecision(
   return res;
 }

+class ConvertToMixedPrecisionPass {
+ public:
+  explicit ConvertToMixedPrecisionPass(
+      const std::string& model_file,
+      const std::string& params_file,
+      const std::string& mixed_model_file,
+      const std::string& mixed_params_file,
+      phi::DataType mixed_precision,
+      phi::Backend backend,
+      bool keep_io_types,
+      std::unordered_set<std::string> black_list)
+      : model_file_(model_file),
+        params_file_(params_file),
+        mixed_model_file_(mixed_model_file),
+        mixed_params_file_(mixed_params_file),
+        mixed_precision_(mixed_precision),
+        backend_(backend),
+        keep_io_types_(keep_io_types),
+        black_list_(black_list),
+        place_(paddle::CPUPlace()),
+        executor_(place_) {
+    black_list_.insert("assign");
+    black_list_.insert("fill_constant");
+    black_list_.insert("assign_value");
+    black_list_.insert("eye");
+    black_list_.insert("fill_any_like");
+    black_list_.insert("fill_constant_batch_size_like");
+  }
+  void Run();
+
+ private:
+  void LoadAndPrepare();
+  inline bool NodeVarHasDtype(framework::ir::Node* node);
+  void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
+  void FixCastAttr(framework::ir::Graph* graph);
+  void SaveMixedModel();
+  void ConvertTensorDtype(int block_idx);
+  void ProcessInputNode(bool support_precision,
+                        ir::Node* in_node,
+                        ir::Node* op_node,
+                        int* suffix,
+                        framework::BlockDesc* block_desc,
+                        framework::proto::VarType::Type to_type,
+                        int block_idx);
+  void ProcessOutputNode(int block_idx,
+                         ir::Node* var_node,
+                         framework::proto::VarType::Type to_type);
+  inline bool IsFloatVarType(framework::proto::VarType::Type type);
+
+  bool OutShouldNotConvert(ir::Node* var_node);
+  // Just process special cases for weights conversion.
+  bool WeightsShouldNotConvert(ir::Node* var_node);
+
+  // To support multi block, we need to consider a lot of special cases.
+  // Return Node* which first appers in block.
+  framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
+  void FindVarsInMultiBlock();
+  inline bool VarIsMultiPrecisionOpsOut(int block_idx,
+                                        framework::ir::Node* op_node);
+
+ private:
+  // A trick. Patch for strange op, which input name equal to output name, such
+  // as `fused_multi_transformer`
+  void PatchForStrangeOp();
+
+ private:
+  std::string model_file_;
+  std::string params_file_;
+  std::string mixed_model_file_;
+  std::string mixed_params_file_;
+  phi::DataType mixed_precision_;
+  phi::Backend backend_;
+  bool keep_io_types_;
+  std::unordered_set<std::string> black_list_;
+  paddle::CPUPlace place_;
+  framework::Executor executor_;
+  framework::Scope scope_;
+
+  std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
+  std::unordered_map<std::string,
+                     std::pair<framework::proto::VarType::Type, int>>
+      vars_in_multi_block_map_;
+  std::vector<std::unordered_map<std::string, std::vector<std::string>>>
+      vars_appear_multi_in_one_block_;
+  int suffix_{0};
+
+  std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
+  std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
+  std::vector<framework::ir::Graph*> graphes_;
+};
+
+framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
+    int block_idx, framework::ir::Node* node) {
+  if (vars_in_multi_block_map_.count(node->Name())) {
+    int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
+    if (block_idx != var_origin_block_id) {
+      auto graph = graphes_[var_origin_block_id];
+      for (auto nd : graph->Nodes()) {
+        if (nd->Name() == node->Name()) {
+          return nd;
+        }
+      }
+    }
+  }
+  return node;
+}
+
+inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
+    framework::ir::Node* node) {
+  if (node->IsVar() &&
+      (node->Var()->GetType() ==
+           paddle::framework::proto::VarType::SELECTED_ROWS ||
+       node->Var()->GetType() ==
+           paddle::framework::proto::VarType::LOD_TENSOR ||
+       node->Var()->GetType() ==
+           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
+       node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
+       node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
+    return true;
+  }
+  return false;
+}
+
+// op1(fp32) -> var1, op2(fp16) -> var1
+// if and only if op1 and op2 both support fp16, we convert op1 and op2's
+// precision.
+inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
+    int block_idx, framework::ir::Node* op_node) {
+  CHECK_EQ(op_node->IsOp(), true);
+  bool ret{false};
+
+  for (auto* out : op_node->outputs) {
+    auto* real_node = GetRealNode(block_idx, out);
+    if (!real_node->Var()->Persistable() &&
+        vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
+      for (auto op_type :
+           vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
+        if (OpSupportPrecision(
+                op_type, backend_, mixed_precision_, black_list_)) {
+          ret = true;
+          VLOG(2) << out->Name()
+                  << " is multi precision op's out, so we skip convert to "
+                     "fp16";
+          break;
+        }
+      }
+    }
+    if (ret) break;
+  }
+  return ret;
+}
+
+void ConvertToMixedPrecisionPass::ProcessInputNode(
+    bool support_precision,
+    ir::Node* in_node,
+    ir::Node* op_node,
+    int* suffix,
+    framework::BlockDesc* block_desc,
+    framework::proto::VarType::Type to_type,
+    int block_idx) {
+  auto* real_node = GetRealNode(block_idx, in_node);
+  if (!NodeVarHasDtype(real_node)) return;
+  auto graph = graphes_[block_idx];
+  bool is_main_block = block_idx == 0;
+  auto* in_var = real_node->Var();
+  auto in_var_type = in_var->GetDataType();
+  auto prev_type = in_var_type;
+  bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
+
+  if (!is_main_block && is_in_multi_block) {
+    in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
+  }
+  if (support_precision) {
+    if (in_var->Persistable() &&
+        in_var_type == framework::proto::VarType::FP32) {
+      if (WeightsShouldNotConvert(in_node)) return;
+      in_var->SetDataType(to_type);
+      in_var_type = to_type;
+      VLOG(3) << " in_node name " << in_var->Name() << " from " << prev_type
+              << " to " << to_type;
+    } else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
+               in_var_type != to_type) {
+      AddCastOp(graph,
+                in_node,
+                op_node,
+                in_var_type,
+                to_type,
+                suffix,
+                block_desc,
+                &cast_map_);
+      VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
+              << ") to " << cast_map_[in_node]->Name() << "(" << to_type
+              << ")";
+    }
+  } else {
+    if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
+        in_var_type != to_type) {
+      AddCastOp(graph,
+                in_node,
+                op_node,
+                in_var_type,
+                to_type,
+                suffix,
+                block_desc,
+                &cast_map_);
+      VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
+              << ") to " << cast_map_[in_node]->Name() << "(" << to_type
+              << ")";
+    }
+  }
+}
+
+void ConvertToMixedPrecisionPass::ProcessOutputNode(
+    int block_idx,
+    ir::Node* var_node,
+    framework::proto::VarType::Type to_type) {
+  auto* real_node = GetRealNode(block_idx, var_node);
+  if (!NodeVarHasDtype(real_node)) return;
+  auto* out_var = real_node->Var();
+  auto prev_type = out_var->GetDataType();
+  if (out_var->GetDataType() == framework::proto::VarType::FP32) {
+    if (OutShouldNotConvert(var_node)) return;
+    out_var->SetDataType(to_type);
+  }
+  VLOG(3) << " out_node name " << var_node->Name() << " from dtype "
+          << prev_type << " to " << out_var->GetDataType();
+}
+
 // Just process special cases.
-bool OutShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
   auto op_node = var_node->inputs[0];
   auto* op_desc = op_node->Op();
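For orientation: the low-precision gate that VarIsMultiPrecisionOpsOut and, later, ConvertTensorDtype rely on (OpSupportPrecision, defined further down in this file) reduces to "not blacklisted, and a kernel is registered for the requested precision on the chosen backend". Below is a framework-free restatement of that rule as a sketch; the callback stands in for GpuKernelSupportPrecision / PhiKernelSupportPrecision and is not Paddle API.

    // Sketch: the decision OpSupportPrecision encodes, with the kernel registry
    // abstracted behind a callback. Not Paddle API; for illustration only.
    #include <functional>
    #include <string>
    #include <unordered_set>

    bool SupportsLowPrecision(
        const std::string& op_type,
        const std::unordered_set<std::string>& black_list,
        const std::function<bool(const std::string&)>& kernel_has_low_precision) {
      if (black_list.count(op_type)) return false;  // blacklist always wins
      return kernel_has_low_precision(op_type);     // otherwise ask the registry
    }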
@@ -262,28 +342,8 @@ bool OutShouldNotConvert(ir::Node* var_node) {
   return false;
 }

-void ProcessOutputNode(
-    const std::vector<framework::ir::Graph*>& graphes,
-    int block_idx,
-    ir::Node* var_node,
-    framework::proto::VarType::Type to_type,
-    std::unordered_map<std::string,
-                       std::pair<framework::proto::VarType::Type, int>>*
-        vars_in_multi_block_map) {
-  auto* real_node =
-      GetRealNode(graphes, block_idx, var_node, vars_in_multi_block_map);
-  if (!NodeVarHasDtype(real_node)) return;
-  auto* out_var = real_node->Var();
-  if (out_var->GetDataType() == framework::proto::VarType::FP32) {
-    if (OutShouldNotConvert(var_node)) return;
-    out_var->SetDataType(to_type);
-  }
-  VLOG(3) << " out_node name " << var_node->Name() << " data_type "
-          << out_var->GetDataType();
-}
-
 // Just process special cases for weights conversion.
-bool WeightsShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
   auto op_nodes = var_node->outputs;
   for (auto* op_node : op_nodes) {
     auto* op_desc = op_node->Op();
@@ -331,72 +391,69 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
   return false;
 }

-inline bool IsFloatVarType(framework::proto::VarType::Type type) {
+inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
+    framework::proto::VarType::Type type) {
   if (type == framework::proto::VarType::FP16 ||
       type == framework::proto::VarType::FP32 ||
       type == framework::proto::VarType::BF16)
     return true;
   return false;
 }

-void ProcessInputNode(
-    bool support_precision,
-    std::vector<framework::ir::Graph*> graphes,
-    ir::Node* in_node,
-    ir::Node* op_node,
-    int* suffix,
-    framework::BlockDesc* block_desc,
-    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* cast_map,
-    framework::proto::VarType::Type to_type,
-    int block_idx,
-    std::unordered_map<std::string,
-                       std::pair<framework::proto::VarType::Type, int>>*
-        vars_in_multi_block_map) {
-  auto* real_node =
-      GetRealNode(graphes, block_idx, in_node, vars_in_multi_block_map);
-  if (!NodeVarHasDtype(real_node)) return;
-  auto graph = graphes[block_idx];
-  bool is_main_block = block_idx == 0;
-  auto* in_var = real_node->Var();
-  auto in_var_type = in_var->GetDataType();
-  bool is_in_multi_block = vars_in_multi_block_map->count(in_var->Name());
-  if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_map->at(in_var->Name()).first;
-  }
-  if (support_precision) {
-    if (in_var->Persistable() &&
-        in_var_type == framework::proto::VarType::FP32) {
-      if (WeightsShouldNotConvert(in_node)) return;
-      in_var->SetDataType(to_type);
-      in_var_type = to_type;
-    } else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
-               in_var_type != to_type) {
-      AddCastOp(graph,
-                in_node,
-                op_node,
-                in_var_type,
-                to_type,
-                suffix,
-                block_desc,
-                cast_map);
-    }
-  } else {
-    if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
-        in_var_type != to_type) {
-      AddCastOp(graph,
-                in_node,
-                op_node,
-                in_var_type,
-                to_type,
-                suffix,
-                block_desc,
-                cast_map);
-    }
-  }
-  VLOG(3) << " in_node name " << in_var->Name() << " data_type "
-          << in_var_type;
-}
+void ConvertToMixedPrecisionPass::LoadAndPrepare() {
+  program_desc_ =
+      inference::Load(&executor_, &scope_, model_file_, params_file_);
+  main_graph_ = std::unique_ptr<framework::ir::Graph>(
+      new framework::ir::Graph(*program_desc_));
+
+  // Remove all control var
+  IrInferCleanGraphPass pass;
+  Argument arg;
+  arg.SetMainGraphNotOwned(main_graph_.get());
+  pass.Run(&arg);
+
+  vars_appear_multi_in_one_block_.resize(program_desc_->Size());
+  FindVarsInMultiBlock();
+}
+
+void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
+  std::vector<std::set<std::string>> block_var_names_set(
+      program_desc_->Size());
+  for (size_t i = 0; i < program_desc_->Size(); ++i) {
+    for (auto op : program_desc_->Block(i).AllOps()) {
+      auto in_names = op->InputArgumentNames();
+      block_var_names_set[i].insert(in_names.begin(), in_names.end());
+      auto out_names = op->OutputArgumentNames();
+      if (op->HasAttr("sub_block") == false) {
+        for (auto& n : out_names) {
+          if (block_var_names_set[i].count(n)) {
+            vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
+          }
+        }
+      }
+      block_var_names_set[i].insert(out_names.begin(), out_names.end());
+    }
+  }
+
+  for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
+    for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
+      std::set<std::string> vars_in_multi_block;
+      std::set_intersection(block_var_names_set[i].begin(),
+                            block_var_names_set[i].end(),
+                            block_var_names_set[j].begin(),
+                            block_var_names_set[j].end(),
+                            std::inserter(vars_in_multi_block,
+                                          vars_in_multi_block.begin()));
+
+      for (auto name : vars_in_multi_block) {
+        vars_in_multi_block_map_.emplace(
+            name, std::make_pair(framework::proto::VarType::FP32, i));
+      }
+    }
+  }
+}

-void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
+void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
+    framework::ir::Graph* graph) {
   auto op_nodes = framework::ir::TopologySortOperations(*graph);
   for (auto* op_node : op_nodes) {
     if (!op_node->IsOp()) continue;
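FindVarsInMultiBlock above gathers, per block, every variable name an op reads or writes, then intersects the per-block name sets pairwise with std::set_intersection to find variables that cross block boundaries. A self-contained sketch of that intersection step, with invented block contents:

    // Sketch of the pairwise set intersection used to find variables shared
    // between two blocks. Standalone; the variable names are invented.
    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <set>
    #include <string>

    int main() {
      std::set<std::string> block0 = {"x", "hidden", "cond", "out"};
      std::set<std::string> block1 = {"hidden", "cond", "tmp_0"};

      std::set<std::string> shared;
      std::set_intersection(block0.begin(), block0.end(),
                            block1.begin(), block1.end(),
                            std::inserter(shared, shared.begin()));

      for (const auto& name : shared) {
        std::cout << name << " appears in both blocks\n";  // cond, hidden
      }
      return 0;
    }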
@@ -436,7 +493,6 @@ void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
-      if (in_node->IsCtrlVar()) continue;
       auto* in_var = in_node->Var();
       if (!in_var->Persistable() &&
           in_var->GetDataType() == framework::proto::VarType::FP64) {
@@ -446,158 +502,47 @@ void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
   }
 }

-// Handle special ops which contains dtype attribute. e.g., fill_constant,
-// assign_value.
-void HandleSpecialOps(framework::OpDesc* op_desc) {
-  if (op_desc->Type() == "fill_constant") {
-    if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
-        static_cast<int>(framework::proto::VarType::FP32))
-      op_desc->SetAttr("dtype",
-                       static_cast<int>(framework::proto::VarType::FP16));
-  } else if (op_desc->Type() == "assign_value") {
-    if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
-        static_cast<int>(framework::proto::VarType::FP32))
-      op_desc->SetAttr("dtype",
-                       static_cast<int>(framework::proto::VarType::FP16));
-  } else if (op_desc->Type() == "eye") {
-    if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
-        static_cast<int>(framework::proto::VarType::FP32))
-      op_desc->SetAttr("dtype",
-                       static_cast<int>(framework::proto::VarType::FP16));
-  } else if (op_desc->Type() == "fill_any_like") {
-    if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
-        static_cast<int>(framework::proto::VarType::FP32))
-      op_desc->SetAttr("dtype",
-                       static_cast<int>(framework::proto::VarType::FP16));
-  } else if (op_desc->Type() == "fill_constant_batch_size_like") {
-    if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
-        static_cast<int>(framework::proto::VarType::FP32))
-      op_desc->SetAttr("dtype",
-                       static_cast<int>(framework::proto::VarType::FP16));
-  }
-}
-
-// We modify op's input output precision, and we need to fix cast op in_dtype
-// and out_dtype attribute.
-void FixCastAttr(framework::ir::Graph* graph) {
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
-  for (auto* op_node : op_nodes) {
-    if (!op_node->IsOp()) continue;
-    auto op_type = op_node->Op()->Type();
-    if (op_type != "cast") continue;
-
-    auto input = op_node->inputs[0];
-    auto output = op_node->outputs[0];
-    op_node->Op()->SetAttr("in_dtype",
-                           static_cast<int>(input->Var()->GetDataType()));
-    op_node->Op()->SetAttr("out_dtype",
-                           static_cast<int>(output->Var()->GetDataType()));
-  }
-}
-
-void FindVarsInMultiBlock(
-    framework::ProgramDesc* program_desc,
-    std::unordered_map<std::string,
-                       std::pair<framework::proto::VarType::Type, int>>*
-        vars_in_multi_block_map,
-    std::vector<std::set<std::string>>* vars_appear_multi_in_one_block) {
-  std::vector<std::set<std::string>> block_var_names_set(program_desc->Size());
-  for (size_t i = 0; i < program_desc->Size(); ++i) {
-    for (auto op : program_desc->Block(i).AllOps()) {
-      auto in_names = op->InputArgumentNames();
-      block_var_names_set[i].insert(in_names.begin(), in_names.end());
-      auto out_names = op->OutputArgumentNames();
-      if (op->HasAttr("sub_block") == false) {
-        for (auto& n : out_names) {
-          if (block_var_names_set[i].count(n)) {
-            (*vars_appear_multi_in_one_block)[i].insert(n);
-          }
-        }
-      }
-      block_var_names_set[i].insert(out_names.begin(), out_names.end());
-    }
-  }
-
-  for (size_t i = 0; i < program_desc->Size() - 1; ++i) {
-    for (size_t j = i + 1; j < program_desc->Size(); ++j) {
-      std::set<std::string> vars_in_multi_block;
-      std::set_intersection(block_var_names_set[i].begin(),
-                            block_var_names_set[i].end(),
-                            block_var_names_set[j].begin(),
-                            block_var_names_set[j].end(),
-                            std::inserter(vars_in_multi_block,
-                                          vars_in_multi_block.begin()));
-
-      for (auto name : vars_in_multi_block) {
-        vars_in_multi_block_map->emplace(
-            name, std::make_pair(framework::proto::VarType::FP32, i));
-      }
-    }
-  }
-}
-
-bool OpInOutHasTensorArray(
-    std::vector<framework::ir::Graph*> graphes,
-    int block_idx,
-    framework::ir::Node* op_node,
-    std::unordered_map<std::string,
-                       std::pair<framework::proto::VarType::Type, int>>*
-        vars_in_multi_block_map) {
-  CHECK_EQ(op_node->IsOp(), true);
-  for (auto in : op_node->inputs) {
-    auto* real_node =
-        GetRealNode(graphes, block_idx, in, vars_in_multi_block_map);
-    if (!NodeVarHasDtype(real_node)) continue;
-    if (real_node->Var()->GetType() ==
-        framework::proto::VarType::LOD_TENSOR_ARRAY)
-      return true;
-  }
-  for (auto out : op_node->outputs) {
-    auto* real_node =
-        GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
-    if (!NodeVarHasDtype(real_node)) continue;
-    if (real_node->Var()->GetType() ==
-        framework::proto::VarType::LOD_TENSOR_ARRAY)
-      return true;
-  }
-  return false;
-}
-
-void ConvertTensorDtype(
-    framework::ProgramDesc* program_desc,
-    std::vector<framework::ir::Graph*> graphes,
-    const std::unordered_set<std::string>& blacklist,
-    bool keep_io_types,
-    phi::Backend backend,
-    phi::DataType tensor_dtype,
-    int block_idx,
-    std::unordered_map<std::string,
-                       std::pair<framework::proto::VarType::Type, int>>*
-        vars_in_multi_block_map,
-    const std::vector<std::set<std::string>>& vars_appear_multi_in_one_block) {
-  auto graph = graphes[block_idx];
+void ConvertToMixedPrecisionPass::Run() {
+  LoadAndPrepare();
+
+  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
+    auto graph = main_graph_->GetSubGraph(i);
+    graphes_.push_back(graph);
+    VLOG(2) << " -------- handle subgraph " << i << ", has "
+            << graph->Nodes().size() << " nodes --------";
+
+    ConvertAllFp64ToFp32(graph);
+    ConvertTensorDtype(i);
+    FixCastAttr(graph);
+
+    CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
+  }
+
+  // A trick
+  PatchForStrangeOp();
+
+  SaveMixedModel();
+}
+
+void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
+  auto graph = graphes_[block_idx];
   framework::proto::VarType::Type to_type;
-  if (tensor_dtype == phi::DataType::FLOAT16) {
+  if (mixed_precision_ == phi::DataType::FLOAT16) {
     to_type = framework::proto::VarType::FP16;
-  } else if (tensor_dtype == phi::DataType::BFLOAT16) {
+  } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
     to_type = framework::proto::VarType::BF16;
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
         "mixed_precision currently not supported dtype %d, we now only "
         "support fp16 and bf16.",
-        static_cast<int>(tensor_dtype)));
+        static_cast<int>(mixed_precision_)));
   }

-  auto* block_desc =
-      framework::ir::TopologySortOperations(*graph)[0]->Op()->Block();
+  auto op_nodes = framework::ir::TopologySortOperations(*graph);
+  auto* block_desc = op_nodes[0]->Op()->Block();
   int num_low_precision = 0;
-  int suffix = 0;
   std::vector<framework::ir::Node*> output_nodes;
-  std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map;
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
   for (auto* op_node : op_nodes) {
     if (!op_node->IsOp()) continue;
     auto op_type = op_node->Op()->Type();
@@ -606,7 +551,7 @@ void ConvertTensorDtype(
     // 1. set input dtype.
     if (op_type == "feed") {
       auto feed_var = op_node->outputs[0]->Var();
-      if (!keep_io_types &&
+      if (!keep_io_types_ &&
           feed_var->GetDataType() == framework::proto::VarType::FP32) {
         feed_var->SetDataType(to_type);
       }
@@ -623,16 +568,14 @@ void ConvertTensorDtype(
       // same name.
       std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
       for (auto* in : op_node->inputs) {
-        auto* real_node =
-            GetRealNode(graphes, block_idx, in, vars_in_multi_block_map);
+        auto* real_node = GetRealNode(block_idx, in);
         if (NodeVarHasDtype(real_node)) {
           in_name_to_node[in->Name()] = in;
         }
       }

       for (auto out : op_node->outputs) {
-        auto* real_node =
-            GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
+        auto* real_node = GetRealNode(block_idx, out);
         if (NodeVarHasDtype(real_node)) {
           if (in_name_to_node.count(out->Name()))
             real_node->Var()->SetDataType(
@@ -643,23 +586,6 @@ void ConvertTensorDtype(
       continue;
     }

-    // A strange case found in multi block.
-    else if (op_type == "assign" &&  // NOLINT
-             op_node->inputs[0]->Name() == op_node->outputs[0]->Name()) {
-      VLOG(2) << " in out are same, continue";
-      continue;
-    }
-
-    // Handle tensor array.
-    else if (OpInOutHasTensorArray(  // NOLINT
-                 graphes, block_idx, op_node, vars_in_multi_block_map)) {
-      VLOG(2) << " in or out has tensor array, continue";
-      continue;
-    }
-
     // 2. if op support fp16/bf16 and not in blacklist.
     //    - cast weight to fp16/bf16.
     //    - add cast op if the input dtype is not fp16/bf16.
@@ -667,22 +593,16 @@ void ConvertTensorDtype(
     //
    // If a var(op's out var) appears multiple times in a block, we should not
    // convert to fp16.
-    else if (blacklist.count(op_type) == 0 &&  // NOLINT
-             !VarIsMultiOpsOut(graphes,
-                               block_idx,
-                               op_node,
-                               vars_in_multi_block_map,
-                               vars_appear_multi_in_one_block)) {
+    else if (black_list_.count(op_type) == 0 &&  // NOLINT
+             !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
       bool support_precision =
-          OpSupportPrecision(op_type, backend, tensor_dtype, blacklist);
-      VLOG(2) << " support low precision " << support_precision;
+          OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);

      // if op not has float input, we will not choose the low precision kernel.
       {
         bool has_float_input{false};
         for (auto in_node : op_node->inputs) {
-          auto* real_node =
-              GetRealNode(graphes, block_idx, in_node, vars_in_multi_block_map);
+          auto* real_node = GetRealNode(block_idx, in_node);
           if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
               real_node->Var()->GetDataType() == proto::VarType::FP32 ||
               real_node->Var()->GetDataType() == proto::VarType::FP64 ||
@@ -696,42 +616,47 @@ void ConvertTensorDtype(
           VLOG(2) << " op doesn't has float input, just skip.";
         }
       }
+      VLOG(2) << " support low precision " << support_precision;

       if (support_precision) {
-        HandleSpecialOps(op_node->Op());
+        VLOG(2) << " process input nodes:";
         ++num_low_precision;
         auto inputs = op_node->inputs;
+
+        // Just for paddle's terriable case: op's input and output has the same
+        // name.
+        std::unordered_map<std::string, std::string> names_map;
+        for (auto out_node : op_node->outputs) {
+          for (auto in_node : op_node->inputs) {
+            if (out_node->Name() == in_node->Name()) {
+              names_map[out_node->Name()] = in_node->Name();
+            }
+          }
+        }
+
         // Process inputs.
         for (auto* in_node : inputs) {
-          ProcessInputNode(true,
-                           graphes,
-                           in_node,
-                           op_node,
-                           &suffix,
-                           block_desc,
-                           &cast_map,
-                           to_type,
-                           block_idx,
-                           vars_in_multi_block_map);
+          ProcessInputNode(
+              true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
+          if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
+            names_map[in_node->Name()] = cast_map_[in_node]->Name();
+          }
         }
+        VLOG(2) << " process output nodes:";
         // Process outputs.
         for (auto* out_node : op_node->outputs) {
-          ProcessOutputNode(
-              graphes, block_idx, out_node, to_type, vars_in_multi_block_map);
+          ProcessOutputNode(block_idx, out_node, to_type);
         }
       } else {
         auto inputs = op_node->inputs;
         for (auto* in_node : inputs) {
           ProcessInputNode(false,
-                           graphes,
                            in_node,
                            op_node,
-                           &suffix,
+                           &suffix_,
                            block_desc,
-                           &cast_map,
                            framework::proto::VarType::FP32,
-                           block_idx,
-                           vars_in_multi_block_map);
+                           block_idx);
         }
       }
     }
@@ -739,9 +664,9 @@ void ConvertTensorDtype(
     // 3. check op not support fp16/bf16 or in blacklist.
     //    - add cast op if the input dtype is not fp32.
     else {  // NOLINT
+      VLOG(3) << "not to run fp16 op_type: " << op_type;
       auto ins = op_node->inputs;
       for (auto* in_node : ins) {
-        if (in_node->IsCtrlVar()) continue;
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
@@ -749,9 +674,12 @@ void ConvertTensorDtype(
                     op_node,
                     to_type,
                     framework::proto::VarType::FP32,
-                    &suffix,
+                    &suffix_,
                     block_desc,
-                    &cast_map);
+                    &cast_map_);
+          VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
+                  << cast_map_[in_node]->Name() << "("
+                  << framework::proto::VarType::FP32 << ")";
         }
       }
     }
@@ -760,40 +688,45 @@ void ConvertTensorDtype(
   // 4. if output_op's dtype is not compatible to output dtype, then just
   // insert cast.
   for (auto* node : output_nodes) {
+    if (node->IsCtrlVar()) continue;
+    ir::Node* fetch_op{nullptr};
+    for (auto* op_node : node->outputs) {
+      if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
+        fetch_op = op_node;
+      }
+    }
+    CHECK_NOTNULL(fetch_op);
     auto var = node->Var();
-    if (keep_io_types && var->GetDataType() == to_type) {
+    if (keep_io_types_ && var->GetDataType() == to_type) {
       // fp16/bf16 -> fp32.
       AddCastOp(graph,
                 node,
-                node->outputs[0],
+                fetch_op,
                 to_type,
                 framework::proto::VarType::FP32,
-                &suffix,
+                &suffix_,
                 block_desc,
-                &cast_map);
-    } else if (!keep_io_types &&
+                &cast_map_);
+    } else if (!keep_io_types_ &&
                var->GetDataType() == framework::proto::VarType::FP32) {
       // fp32 -> fp16/bf16
       AddCastOp(graph,
                 node,
-                node->outputs[0],
+                fetch_op,
                 framework::proto::VarType::FP32,
                 to_type,
-                &suffix,
+                &suffix_,
                 block_desc,
-                &cast_map);
+                &cast_map_);
     }
   }

   for (auto node : graph->Nodes()) {
-    auto* real_node =
-        GetRealNode(graphes, block_idx, node, vars_in_multi_block_map);
+    auto* real_node = GetRealNode(block_idx, node);
     if (!NodeVarHasDtype(real_node)) continue;

-    if (vars_in_multi_block_map->count(real_node->Name()) &&
-        vars_in_multi_block_map->at(real_node->Name()).second == block_idx) {
-      vars_in_multi_block_map->at(real_node->Name()).first =
+    if (vars_in_multi_block_map_.count(real_node->Name()) &&
+        vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
+      vars_in_multi_block_map_.at(real_node->Name()).first =
           real_node->Var()->GetDataType();
     }
   }
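The fetch-handling loop just above decides, per output variable, whether to cast back to fp32 (keep_io_types_ is true and the variable already became low precision) or to cast an fp32 output down (keep_io_types_ is false). The same decision, isolated as a small sketch with invented types:

    // Sketch of the output-boundary cast decision made before each fetch op.
    // Types and names are illustrative, not Paddle's.
    enum class DType { FP32, LOW };  // LOW stands for fp16 or bf16

    struct CastPlan {
      bool insert_cast;
      DType from;
      DType to;
    };

    CastPlan PlanFetchCast(DType var_dtype, bool keep_io_types) {
      if (keep_io_types && var_dtype == DType::LOW)
        return {true, DType::LOW, DType::FP32};  // restore fp32 outputs
      if (!keep_io_types && var_dtype == DType::FP32)
        return {true, DType::FP32, DType::LOW};  // let outputs go low precision
      return {false, var_dtype, var_dtype};      // nothing to do
    }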
@@ -802,24 +735,118 @@ void ConvertTensorDtype(
     LOG(INFO) << "---  detected " << num_low_precision
               << " low precision ops in " << block_idx << " subgraph";
 }

-}  // namespace
-
-bool OpSupportPrecision(const std::string& op_type,
-                        phi::Backend backend,
-                        phi::DataType precision,
-                        const std::unordered_set<std::string>& blacklist) {
-  auto phi_op_type = phi::TransToPhiKernelName(op_type);
-  bool support_precision = false;
-  if (blacklist.count(op_type) == 0) {
-    if (backend == phi::Backend::GPU)
-      support_precision = GpuKernelSupportPrecision(op_type, precision);
-    else
-      support_precision =
-          PhiKernelSupportPrecision(phi_op_type, backend, precision);
-  }
-  return support_precision;
-}
-
+// We modify op's input output precision, and we need to fix cast op in_dtype
+// and out_dtype attribute.
+void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
+  auto op_nodes = framework::ir::TopologySortOperations(*graph);
+  for (auto* op_node : op_nodes) {
+    if (!op_node->IsOp()) continue;
+    auto op_type = op_node->Op()->Type();
+    if (op_type != "cast") continue;
+
+    auto input = op_node->inputs[0];
+    auto output = op_node->outputs[0];
+    op_node->Op()->SetAttr("in_dtype",
+                           static_cast<int>(input->Var()->GetDataType()));
+    op_node->Op()->SetAttr("out_dtype",
+                           static_cast<int>(output->Var()->GetDataType()));
+  }
+}
+
+void ConvertToMixedPrecisionPass::SaveMixedModel() {
+  framework::ProgramDesc mixed_program_desc;
+  framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
+
+  paddle::CPUPlace place;
+  auto parameters = scope_.LocalVarNames();
+  std::sort(parameters.begin(), parameters.end());
+
+  std::unordered_set<std::string> weights_should_be_fp32;
+  for (auto* node : main_graph_->Nodes()) {
+    if (!(node->IsVar())) continue;
+    if (NodeVarHasDtype(node)) {
+      if (node->Var()->Persistable() &&
+          node->Var()->GetDataType() ==
+              paddle::framework::proto::VarType::FP32) {
+        VLOG(2) << "weights keep to fp32: " << node->Name();
+        weights_should_be_fp32.insert(node->Name());
+      }
+    }
+  }
+
+#define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                   \
+  mixed_tensor.set_type(DTYPE);                                              \
+  auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
+  for (int i = 0; i < t->numel(); i++) {                                     \
+    mixed_data[i] = static_cast<dtype>(data[i]);                             \
+  }                                                                          \
+  t->clear();                                                                \
+  paddle::framework::TensorCopySync(mixed_tensor, place, t)
+
+  for (const auto& param_name : parameters) {
+    auto* var = scope_.FindLocalVar(param_name);
+    if (var->IsType<phi::DenseTensor>()) {
+      auto* t = var->GetMutable<phi::DenseTensor>();
+      if (t->dtype() != phi::DataType::FLOAT32) continue;
+      phi::DenseTensor mixed_tensor;
+      mixed_tensor.Resize(t->dims());
+      auto* data = t->mutable_data<float>(platform::CPUPlace());
+      if (mixed_precision_ == phi::DataType::FLOAT16 &&
+          !weights_should_be_fp32.count(param_name)) {
+        CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
+                             phi::dtype::float16);
+      } else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
+                 !weights_should_be_fp32.count(param_name)) {
+        CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
+                             phi::dtype::bfloat16);
+      }
+    }
+  }
+
+#undef CONVERT_TENSOR_DTYPE
+
+  auto SerializeParams = [&]() -> std::string {
+    std::ostringstream os;
+    phi::CPUContext ctx;
+    for (const auto& param : parameters) {
+      VLOG(3) << "Serialize param: " << param;
+      PADDLE_ENFORCE_NOT_NULL(
+          scope_.FindVar(param),
+          platform::errors::NotFound(
+              "Block should already have a '%s' variable", param));
+      auto* tensor = scope_.FindVar(param)->GetMutable<framework::LoDTensor>();
+      framework::SerializeToStream(os, *tensor, ctx);
+    }
+    return os.str();
+  };
+
+  auto StrToBinary = [](const std::string& path, const std::string& str) {
+    std::ofstream file(path.c_str(), std::ios::binary);
+    file.write(str.c_str(), str.size());
+    file.close();
+  };
+
+  StrToBinary(mixed_model_file_,
+              mixed_program_desc.Proto()->SerializeAsString());
+  StrToBinary(mixed_params_file_, SerializeParams());
+}
+
+void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
+  for (auto* graph : graphes_) {
+    for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
+      if (op_node->Name() == "fused_multi_transformer") {
+        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
+        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
+        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
+        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
+          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace

 void AddCastOp(
     framework::ir::Graph* graph,
     framework::ir::Node* node,
@@ -865,11 +892,27 @@ void AddCastOp(
     IR_NODE_LINK_TO(cast_op_node, cast_output_node);
     (*map)[node] = cast_output_node;
   }
-  next_op->Op()->Rename(node->Name(), map->at(node)->Name());
+  next_op->Op()->RenameInput(node->Name(), map->at(node)->Name());
   IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
   IR_NODE_LINK_TO(map->at(node), next_op);
 }

+bool OpSupportPrecision(const std::string& op_type,
+                        phi::Backend backend,
+                        phi::DataType precision,
+                        const std::unordered_set<std::string>& blacklist) {
+  auto phi_op_type = phi::TransToPhiKernelName(op_type);
+  bool support_precision = false;
+  if (blacklist.count(op_type) == 0) {
+    if (backend == phi::Backend::GPU)
+      support_precision = GpuKernelSupportPrecision(op_type, precision);
+    else
+      support_precision =
+          PhiKernelSupportPrecision(phi_op_type, backend, precision);
+  }
+  return support_precision;
+}
+
 void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& params_file,
                              const std::string& mixed_model_file,
@@ -878,53 +921,15 @@ void ConvertToMixedPrecision(const std::string& model_file,
                              phi::Backend backend,
                              bool keep_io_types,
                              std::unordered_set<std::string> black_list) {
-  paddle::CPUPlace place;
-  framework::Executor executor(place);
-  framework::Scope scope;
-  auto program_desc =
-      inference::Load(&executor, &scope, model_file, params_file);
-  auto main_graph = std::unique_ptr<framework::ir::Graph>(
-      new framework::ir::Graph(*program_desc));
-
-  std::unordered_map<std::string,
-                     std::pair<framework::proto::VarType::Type, int>>
-      vars_in_multi_block_map;
-  std::vector<std::set<std::string>> vars_appear_multi_in_one_block(
-      program_desc->Size());
-  FindVarsInMultiBlock(program_desc.get(),
-                       &vars_in_multi_block_map,
-                       &vars_appear_multi_in_one_block);
-
-  std::vector<framework::ir::Graph*> graphes;
-  for (size_t i = 0; i < main_graph->SubGraphsSize(); ++i) {
-    auto graph = main_graph->GetSubGraph(i);
-    graphes.push_back(graph);
-    VLOG(2) << " -------- handle subgraph " << i << ", has "
-            << graph->Nodes().size() << " nodes --------";
-
-    ConvertAllFp64ToFp32(graph);
-    ConvertTensorDtype(program_desc.get(),
-                       graphes,
-                       black_list,
-                       keep_io_types,
-                       backend,
-                       mixed_precision,
-                       i,
-                       &vars_in_multi_block_map,
-                       vars_appear_multi_in_one_block);
-    FixCastAttr(graph);
-  }
-
-  framework::ProgramDesc mixed_program_desc;
-  framework::ir::GraphToProgram(*main_graph, &mixed_program_desc);
-
-  SaveMixedModel(main_graph.get(),
-                 &scope,
-                 &mixed_program_desc,
-                 mixed_model_file,
-                 mixed_params_file,
-                 mixed_precision,
-                 vars_in_multi_block_map);
+  ConvertToMixedPrecisionPass pass(model_file,
+                                   params_file,
+                                   mixed_model_file,
+                                   mixed_params_file,
+                                   mixed_precision,
+                                   backend,
+                                   keep_io_types,
+                                   black_list);
+  pass.Run();
 }

 }  // namespace analysis
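SaveMixedModel above down-converts each eligible fp32 parameter element by element through the CONVERT_TENSOR_DTYPE macro (a static_cast into phi::dtype::float16 or bfloat16, followed by TensorCopySync back in place). The same loop pattern, stripped of Paddle's tensor types, as a sketch; HalfLike is a placeholder, not a real Paddle type:

    // Sketch of the per-element weight down-conversion used by SaveMixedModel.
    // HalfLike stands in for phi::dtype::float16 or phi::dtype::bfloat16.
    #include <cstddef>
    #include <vector>

    template <typename HalfLike>
    std::vector<HalfLike> DownConvert(const float* data, std::size_t n) {
      std::vector<HalfLike> out(n);
      for (std::size_t i = 0; i < n; ++i) {
        out[i] = static_cast<HalfLike>(data[i]);  // narrowing cast per element
      }
      return out;
    }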
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h (view file @ de6f15b6)

@@ -30,7 +30,7 @@ namespace paddle {
 namespace inference {
 namespace analysis {

-bool OpSupportPrecision(const std::string& phi_op_type,
+bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
                         const std::unordered_set<std::string>& blacklist);
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc (view file @ de6f15b6)

@@ -140,39 +140,12 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
       auto var_data_type = var_node->Var()->GetDataType();
       VLOG(5) << "var_name is " << var_name << ", data type is "
               << var_data_type;

-      if (var_data_type == paddle::framework::proto::VarType::FP16 &&
-          t->dtype() != paddle::experimental::DataType::FLOAT16) {
-        framework::Tensor half_tensor;
-        half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
-        half_tensor.Resize(t->dims());
-        auto *half_data =
-            half_tensor.mutable_data<float16>(platform::CPUPlace());
-        for (int i = 0; i < t->numel(); i++) {
-          auto *data = t->mutable_data<float16>(platform::CPUPlace());
-          half_data[i] = static_cast<float16>(data[i]);
-        }
-        t->clear();
-        paddle::framework::TensorCopySync(half_tensor, place, t);
-      } else if (var_data_type == paddle::framework::proto::VarType::BF16) {
-        framework::Tensor bf16_tensor;
-        bf16_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
-        bf16_tensor.Resize(t->dims());
-        auto *bf16_data = bf16_tensor.mutable_data<platform::bfloat16>(
-            platform::CPUPlace());
-        for (int i = 0; i < t->numel(); i++) {
-          auto *data = t->mutable_data<bfloat16>(platform::CPUPlace());
-          bf16_data[i] = static_cast<platform::bfloat16>(data[i]);
-        }
-        t->clear();
-        paddle::framework::TensorCopySync(bf16_tensor, place, t);
-      } else {
-        platform::CPUPlace cpu_place;
-        framework::LoDTensor temp_tensor;
-        temp_tensor.Resize(t->dims());
-        paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
-        t->clear();
-        paddle::framework::TensorCopySync(temp_tensor, place, t);
-      }
+      platform::CPUPlace cpu_place;
+      framework::LoDTensor temp_tensor;
+      temp_tensor.Resize(t->dims());
+      paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
+      t->clear();
+      paddle::framework::TensorCopySync(temp_tensor, place, t);
     }
   }
 }