BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 28ea9aad (unverified)
Authored Dec 14, 2022 by Yuanle Liu; committed via GitHub on Dec 14, 2022
[Paddle Inference] rewrite convert_to_mixed_precision (#48853)
Parent: b9fad5da

Showing 15 changed files with 324 additions and 950 deletions (+324 −950)
paddle/fluid/framework/ir/CMakeLists.txt (+1 −1)
paddle/fluid/framework/ir/auto_mixed_precision_pass.cc (+163 −159)
paddle/fluid/framework/ir/auto_mixed_precision_pass.h (+26 −12)
paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc (+1 −1)
paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc (+1 −1)
paddle/fluid/inference/analysis/argument.h (+1 −1)
paddle/fluid/inference/analysis/ir_pass_manager.cc (+8 −7)
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc (+8 −8)
paddle/fluid/inference/analysis/passes/CMakeLists.txt (+1 −1)
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc (+64 −740)
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h (+39 −9)
paddle/fluid/inference/api/analysis_config.cc (+4 −4)
paddle/fluid/inference/api/analysis_predictor.cc (+4 −4)
paddle/fluid/inference/api/paddle_analysis_config.h (+1 −1)
paddle/fluid/inference/api/paddle_pass_builder.cc (+2 −1)
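
The bulk of the hand-written conversion code in convert_to_mixed_precision.cc (−740 lines) is replaced by the graph pass renamed to auto_mixed_precision_pass, and the offline converter is now wrapped in the ConvertToMixedPrecisionPass class declared in convert_to_mixed_precision.h further down. As a rough, non-authoritative sketch of how that class might be driven (the file paths, the black-list entry, and the main() wrapper are illustrative assumptions, not part of this commit):

// Hypothetical driver for the rewritten offline converter; the constructor
// arguments mirror the declaration added in convert_to_mixed_precision.h.
#include <string>
#include <unordered_set>

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"

int main() {
  std::unordered_set<std::string> black_list{"softmax"};  // ops kept in FP32

  paddle::inference::analysis::ConvertToMixedPrecisionPass pass(
      "model.pdmodel",          // input program (placeholder path)
      "model.pdiparams",        // input weights (placeholder path)
      "mixed_model.pdmodel",    // output program
      "mixed_model.pdiparams",  // output weights
      phi::DataType::FLOAT16,   // target low precision (FP16 or BF16)
      phi::Backend::GPU,
      /*keep_io_types=*/true,   // keep feed/fetch variables in FP32
      black_list);
  pass.Run();  // LoadModel -> apply the mixed precision pass -> SaveMixedModel
  return 0;
}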
paddle/fluid/framework/ir/CMakeLists.txt

@@ -103,7 +103,7 @@ pass_library(delete_c_identity_op_pass inference)
 pass_library(preln_residual_bias_fuse_pass inference)
 pass_library(delete_fill_constant_op_pass inference)
 pass_library(constant_folding_pass inference)
-pass_library(float_to_half_pass inference)
+pass_library(auto_mixed_precision_pass inference)
 pass_library(conv2d_fusion_layout_transfer_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
paddle/fluid/framework/ir/float_to_half_pass.cc → paddle/fluid/framework/ir/auto_mixed_precision_pass.cc

This diff is collapsed.
paddle/fluid/framework/ir/float_to_half_pass.h → paddle/fluid/framework/ir/auto_mixed_precision_pass.h

@@ -27,13 +27,13 @@ namespace paddle {
 namespace framework {
 namespace ir {

-class FloatToHalfPass : public FusePassBase {
+class AutoMixedPrecisionPass : public FusePassBase {
  public:
   using VarType = framework::proto::VarType;

  public:
-  FloatToHalfPass() = default;
-  ~FloatToHalfPass() = default;
+  AutoMixedPrecisionPass() = default;
+  ~AutoMixedPrecisionPass() = default;

  protected:
   void ApplyImpl(Graph* graph) const override;

@@ -43,10 +43,6 @@ class FloatToHalfPass : public FusePassBase {
   void SetDefaultBlacklist() const;

-  bool OpSupportPrecision(const std::string& op_type,
-                          phi::DataType precision,
-                          phi::Backend backend = phi::Backend::GPU) const;
-
   void SetOpUniqueType() const;

   void RestoreOpOriginType() const;

@@ -70,9 +66,13 @@ class FloatToHalfPass : public FusePassBase {
   void ConvertWeightsData() const;

  private:
-  mutable bool keep_io_types_;
+  mutable bool skip_pass_{false};
+
+  mutable bool keep_io_types_{false};
   // float16 or bfloat16 now
-  mutable phi::DataType half_precision_;
+  mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
+
+  mutable phi::Backend backend_{phi::Backend::GPU};

   mutable std::unordered_set<std::string> black_list_;

@@ -84,12 +84,26 @@ class FloatToHalfPass : public FusePassBase {
   mutable std::vector<std::vector<Node*>> all_op_nodes_;
   // op's unique type -> the op's origin type
   mutable std::unordered_map<std::string, std::string> op_original_type_;
-  // op's unique type -> whether the op run at half precision
-  mutable std::unordered_set<std::string> op_run_half_;
+  // op's unique type -> whether the op run at low precision
+  mutable std::unordered_set<std::string> op_run_low_precision_;

-  mutable std::unordered_set<std::string> vars_convert_to_half_;
+  mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
 };

+bool OpSupportPrecision(const std::string& op_type,
+                        phi::Backend backend,
+                        phi::DataType precision,
+                        const std::unordered_set<std::string>& black_list);
+
+void DoInsertCastOp(Graph* graph,
+                    Node* var_node,
+                    Node* op_node,
+                    proto::VarType::Type from_type,
+                    proto::VarType::Type to_type,
+                    framework::BlockDesc* block_desc,
+                    int* suffix,
+                    std::unordered_map<Node*, Node*>* cache);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
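
Besides the class rename, the header now declares OpSupportPrecision and DoInsertCastOp as free functions at ir namespace scope so that other passes can reuse the same precision check and cast insertion. A minimal sketch of a caller, where the ExampleUse wrapper, the "conv2d" op type, and the variable names are illustrative assumptions rather than code from this commit:

// Sketch only: exercising the two helpers declared above.
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"

namespace paddle {
namespace framework {
namespace ir {

void ExampleUse(Graph* graph,
                Node* var_node,
                Node* op_node,
                BlockDesc* block_desc,
                const std::unordered_set<std::string>& black_list) {
  // May conv2d run in FP16 on the GPU backend, given this black list?
  if (OpSupportPrecision(
          "conv2d", phi::Backend::GPU, phi::DataType::FLOAT16, black_list)) {
    int suffix = 0;  // counter used to name the inserted cast variables
    std::unordered_map<Node*, Node*> cache;  // reuse one cast per variable
    DoInsertCastOp(graph,
                   var_node,
                   op_node,
                   proto::VarType::FP32,
                   proto::VarType::FP16,
                   block_desc,
                   &suffix,
                   &cache);
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle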
paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc

@@ -142,7 +142,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   bool is_fp16_precision =
       static_cast<phi::DataType>(Get<int>("model_precision")) ==
           phi::DataType::FLOAT16 ||
-      Get<bool>("enable_gpu_half");
+      Get<bool>("enable_gpu_mixed");
   bool cutlass_enable = false;
 #ifdef PADDLE_WITH_CUTLASS
paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc

@@ -165,7 +165,7 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   bool is_fp16_precision =
       static_cast<phi::DataType>(Get<int>("model_precision")) ==
           phi::DataType::FLOAT16 ||
-      Get<bool>("enable_gpu_half");
+      Get<bool>("enable_gpu_mixed");
   constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
   if (is_fp16_precision) {
 #ifdef PADDLE_WITH_CUTLASS
paddle/fluid/inference/analysis/argument.h

@@ -365,7 +365,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(mixed_black_list,
                       MixedBlackList,
                       std::unordered_set<std::string>);
-  DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
+  DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
   DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);

   // cinn compiler related
paddle/fluid/inference/analysis/ir_pass_manager.cc

@@ -45,8 +45,10 @@ IRPassManager::IRPassManager(Argument *argument) {
 void IRPassManager::CreatePasses(Argument *argument,
                                  const std::vector<std::string> &passes) {
+  // For graph_viz_pass
   std::string pre_pass;
   int pass_num = 0;
+
   for (const std::string &pass_name : passes) {
     auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
     pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));

@@ -87,14 +89,14 @@ void IRPassManager::CreatePasses(Argument *argument,
           argument->tensorrt_tuned_dynamic_shape();
       pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));

-      // mixed precision related
-      pass->Set("model_precision", new int(argument->model_precision()));
+      // Mixed precision related.
       pass->Set(
           "mixed_black_list",
           new std::unordered_set<std::string>(argument->mixed_black_list()));
-      pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half()));
+      pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
       pass->Set("mixed_precision_mode",
                 new int(argument->mixed_precision_mode()));
+      pass->Set("model_precision", new int(argument->model_precision()));

       if (pass_name == "graph_viz_pass") {
         std::string optim_cache_dir = argument->optim_cache_dir();

@@ -210,6 +212,7 @@ void IRPassManager::CreatePasses(Argument *argument,
                 new std::vector<std::string>(argument->tensorrt_disabled_ops()));
       pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
       pass->Set("trt_dla_core", new int(argument->tensorrt_dla_core()));
       // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
       // not run fp16.
       pass->Set("disable_trt_plugin_fp16",

@@ -238,8 +241,7 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("root_predictor_id", new int(argument->root_predictor_id()));
     } else if (pass_name == "build_cinn_pass") {
       pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
     }
-    else if (pass_name == "lite_subgraph_pass") {
+    if (pass_name == "lite_subgraph_pass") {
       bool lite_enable_int8 =
           argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
       pass->Set("program",

@@ -287,8 +289,7 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("nnadapter_model_cache_token",
                 new std::vector<std::string>(
                     argument->nnadapter_model_cache_token()));
     }
-    else if (pass_name == "fc_fuse_pass") {
+    if (pass_name == "fc_fuse_pass") {
       pass->Set("use_gpu", new bool(argument->use_gpu()));
       bool fc_mkldnn_pass = 0;
       for (const std::string &pass_n : passes) {
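
Every pass->Set call registered here has a matching Get in one of the IR passes, so the enable_gpu_half → enable_gpu_mixed rename has to land on both sides at once (compare the conv2d_fusion_layout_transfer_pass.cc hunk above). A minimal sketch of that handshake from the consumer side, assuming a hypothetical ExamplePass and the framework::ir::Pass Set/Get interface:

// Sketch only: how a registered IR pass reads back the attributes that
// IRPassManager::CreatePasses attaches. ExamplePass is a hypothetical name.
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/phi/common/data_type.h"

namespace paddle {
namespace framework {
namespace ir {

class ExamplePass : public Pass {
 protected:
  void ApplyImpl(Graph* graph) const override {
    // Same lookup pattern as the conv fuse passes touched by this commit.
    bool is_fp16_precision =
        static_cast<phi::DataType>(Get<int>("model_precision")) ==
            phi::DataType::FLOAT16 ||
        Get<bool>("enable_gpu_mixed");
    if (is_fp16_precision) {
      // ... apply FP16-specific rewrites to the graph here ...
    }
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle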
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc

@@ -83,13 +83,13 @@ void OutputProcess(framework::ir::Graph *graph,
                             backend,
                             precision,
                             blacklist)) {
-      AddCastOp(graph,
+      InsertCastOp(graph,
                 var_node,
                 next_op,
                 framework::proto::VarType::FP32,
                 to_type,
-                &suffix,
                 block_desc,
+                &suffix,
                 &var_to_cast_op_map);
       var_node->Var()->SetDataType(framework::proto::VarType::FP32);
     }
paddle/fluid/inference/analysis/passes/CMakeLists.txt

@@ -13,7 +13,7 @@ cc_library(
 cc_library(
   convert_to_mixed_precision
   SRCS convert_to_mixed_precision.cc
-  DEPS analysis_pass ir_graph_build_pass)
+  DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
 cc_library(
   ir_params_sync_among_devices_pass
   SRCS ir_params_sync_among_devices_pass.cc
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc

This diff is collapsed.
paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h

@@ -15,14 +15,12 @@
 #pragma once

 #include <string>
-#include <unordered_map>
 #include <unordered_set>

 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_helper.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/phi/common/backend.h"
 #include "paddle/phi/common/data_type.h"

@@ -30,20 +28,52 @@ namespace paddle {
 namespace inference {
 namespace analysis {

+class ConvertToMixedPrecisionPass {
+ public:
+  explicit ConvertToMixedPrecisionPass(
+      const std::string& model_file,
+      const std::string& params_file,
+      const std::string& mixed_model_file,
+      const std::string& mixed_params_file,
+      phi::DataType mixed_precision,
+      phi::Backend backend,
+      bool keep_io_types,
+      const std::unordered_set<std::string>& black_list);
+
+  void Run();
+
+ private:
+  void LoadModel();
+  void SaveMixedModel();
+
+ private:
+  std::string model_file_;
+  std::string params_file_;
+  std::string mixed_model_file_;
+  std::string mixed_params_file_;
+  phi::DataType mixed_precision_;
+  phi::Backend backend_;
+  bool keep_io_types_;
+  std::unordered_set<std::string> black_list_;
+
+  framework::Scope scope_;
+  std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
+};
+
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& blacklist);
+                        const std::unordered_set<std::string>& black_list);

-void AddCastOp(
+void InsertCastOp(
     framework::ir::Graph* graph,
-    framework::ir::Node* node,
-    framework::ir::Node* next_op,
+    framework::ir::Node* var_node,
+    framework::ir::Node* op_node,
     framework::proto::VarType::Type from_type,
     framework::proto::VarType::Type to_type,
-    int* suffix,
     framework::BlockDesc* block_desc,
-    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map);
+    int* suffix,
+    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited);

 void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& params_file,
...
paddle/fluid/inference/api/analysis_config.cc

@@ -99,7 +99,7 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
     // default
   } else if (precision_mode == Precision::kHalf ||
              precision_mode == Precision::kBf16) {
-    enable_gpu_half_ = true;
+    enable_gpu_mixed_ = true;
   } else {
     LOG(ERROR)
         << "The Paddle-GPU inference currently only supports "

@@ -396,7 +396,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   // Mixed precision related.
   CP_MEMBER(mixed_black_list_);
-  CP_MEMBER(enable_gpu_half_);
+  CP_MEMBER(enable_gpu_mixed_);
   CP_MEMBER(mixed_precision_mode_);

   CP_MEMBER(enable_memory_optim_);

@@ -1017,7 +1017,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << params_file_;
   ss << use_gpu_;
-  ss << enable_gpu_half_;
+  ss << enable_gpu_mixed_;
   ss << use_external_stream_;
   ss << exec_stream_;
   ss << use_fc_padding_;

@@ -1234,7 +1234,7 @@ std::string AnalysisConfig::Summary() {
   os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
   if (use_gpu_) {
     os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
-    os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)});
+    os.InsertRow({"enable_gpu_mixed_", std::to_string(enable_gpu_mixed_)});
     os.InsertRow({"memory_pool_init_size",
                   std::to_string(memory_pool_init_size_mb_) + "MB"});
     os.InsertRow(
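
For users, nothing changes in how mixed precision is requested: passing Precision::kHalf (or kBf16) to EnableUseGpu still sets the flag, now named enable_gpu_mixed_. A minimal usage sketch, assuming the existing three-argument EnableUseGpu overload and placeholder model paths:

// Sketch: requesting Paddle-GPU mixed precision inference through
// AnalysisConfig; this flips enable_gpu_mixed_ as shown in the hunk above.
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("model.pdmodel", "model.pdiparams");  // placeholder paths

  // 100 MB initial GPU memory pool, GPU device 0, run in FP16 mixed precision.
  config.EnableUseGpu(100, 0, paddle::AnalysisConfig::Precision::kHalf);
  return 0;
}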
paddle/fluid/inference/api/analysis_predictor.cc

@@ -1277,10 +1277,10 @@ void AnalysisPredictor::PrepareArgument() {
   if (!config_.ir_optim()) {
     argument_.SetEnableIrOptim(false);
-    if (config_.enable_gpu_half_) {
+    if (config_.enable_gpu_mixed_) {
       argument_.SetEnableIrOptim(true);
       pass_builder->ClearPasses();
-      pass_builder->AppendPass("float_to_half_pass");
+      pass_builder->AppendPass("auto_mixed_precision_pass");
       LOG(INFO)
           << "This model run in Paddle-GPU mixed precision mode with no ir "
              "optimization.";

@@ -1291,7 +1291,7 @@ void AnalysisPredictor::PrepareArgument() {
     if (config_.ir_debug_) {
       pass_builder->TurnOnDebug();
     }
-    if (config_.enable_gpu_half_) {
+    if (config_.enable_gpu_mixed_) {
       LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
     }
   }

@@ -1303,7 +1303,7 @@ void AnalysisPredictor::PrepareArgument() {
   // mixed precison.
   argument_.SetModelPrecision(static_cast<int>(model_precision_));
   argument_.SetMixedBlackList(config_.mixed_black_list_);
-  argument_.SetEnableGPUHalf(config_.enable_gpu_half_);
+  argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
   argument_.SetMixedPrecisionMode(static_cast<int>(
       paddle::ConvertPrecision(config_.mixed_precision_mode_)));
 }
paddle/fluid/inference/api/paddle_analysis_config.h

@@ -1049,7 +1049,7 @@ struct PD_INFER_DECL AnalysisConfig {
   bool use_gpu_{false};
   int gpu_device_id_{0};
   uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
-  bool enable_gpu_half_{false};
+  bool enable_gpu_mixed_{false};
   bool thread_local_stream_{false};

   bool use_cudnn_{false};
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -245,7 +245,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
-        "float_to_half_pass",                  //
+        "constant_folding_pass",               //
+        "auto_mixed_precision_pass",           //
   });

   use_gpu_ = true;