Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
18d64025
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
18d64025
编写于
6月 01, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
6月 01, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
simplify inference api (#11104)
上级
86d8659c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
36 addition
and
27 deletion
+36
-27
paddle/contrib/inference/paddle_inference_api.h
paddle/contrib/inference/paddle_inference_api.h
+23
-17
paddle/contrib/inference/paddle_inference_api_impl.cc
paddle/contrib/inference/paddle_inference_api_impl.cc
+13
-9
paddle/contrib/inference/test_paddle_inference_api_impl.cc
paddle/contrib/inference/test_paddle_inference_api_impl.cc
+0
-1
未找到文件。
paddle/contrib/inference/paddle_inference_api.h
浏览文件 @
18d64025
...
@@ -40,14 +40,23 @@ struct PaddleBuf {
...
@@ -40,14 +40,23 @@ struct PaddleBuf {
struct
PaddleTensor
{
struct
PaddleTensor
{
std
::
string
name
;
// variable name.
std
::
string
name
;
// variable name.
std
::
vector
<
int
>
shape
;
std
::
vector
<
int
>
shape
;
// TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed.
PaddleBuf
data
;
// blob of data.
PaddleBuf
data
;
// blob of data.
PaddleDType
dtype
;
PaddleDType
dtype
;
};
};
enum
class
PaddleEngineKind
{
kNative
=
0
,
// Use the native Fluid facility.
// TODO(Superjomn) support following engines latter.
// kAnakin, // Use Anakin for inference.
// kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
// kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
};
/*
/*
* A simple Inference API for Paddle. Currently this API can be used by
* A simple Inference API for Paddle. Currently this API can be used by
* non-sequence scenerios.
* non-sequence scenerios.
* TODO(Superjomn) Support another API for NLP-related usages.
*/
*/
class
PaddlePredictor
{
class
PaddlePredictor
{
public:
public:
...
@@ -69,15 +78,6 @@ class PaddlePredictor {
...
@@ -69,15 +78,6 @@ class PaddlePredictor {
// Destroy the Predictor.
// Destroy the Predictor.
virtual
~
PaddlePredictor
()
{}
virtual
~
PaddlePredictor
()
{}
enum
class
EngineKind
{
kNative
=
-
1
,
// Use the native Fluid facility.
// TODO(Superjomn) support latter.
// kAnakin, // Use Anakin for inference.
// kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
// kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
};
// The common configs for all the predictors.
// The common configs for all the predictors.
struct
Config
{
struct
Config
{
std
::
string
model_dir
;
// path to the model directory.
std
::
string
model_dir
;
// path to the model directory.
...
@@ -86,18 +86,24 @@ class PaddlePredictor {
...
@@ -86,18 +86,24 @@ class PaddlePredictor {
};
};
struct
NativeConfig
:
public
PaddlePredictor
::
Config
{
struct
NativeConfig
:
public
PaddlePredictor
::
Config
{
// GPU related fields.
bool
use_gpu
{
false
};
bool
use_gpu
{
false
};
int
device
;
int
device
{
0
};
float
fraction_of_gpu_memory
;
float
fraction_of_gpu_memory
{
-
1.
f
};
// Negative to notify initialization.
std
::
string
prog_file
;
std
::
string
prog_file
;
std
::
string
param_file
;
std
::
string
param_file
;
bool
share_variables
;
};
};
// A factory to help create difference predictor.
// A factory to help create different predictors.
template
<
//
typename
ConfigT
,
// FOR EXTENSION DEVELOPER:
PaddlePredictor
::
EngineKind
engine
=
PaddlePredictor
::
EngineKind
::
kNative
>
// Different predictors are designated by config type and engine kind. Similar
// configs can be merged, but there shouldn't be a huge config containing
// different fields for more than one kind of predictors.
//
// Similarly, each engine kind should map to a unique predictor implementation.
template
<
typename
ConfigT
,
PaddleEngineKind
engine
=
PaddleEngineKind
::
kNative
>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
(
const
ConfigT
&
config
);
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
(
const
ConfigT
&
config
);
}
// namespace paddle
}
// namespace paddle
paddle/contrib/inference/paddle_inference_api_impl.cc
浏览文件 @
18d64025
...
@@ -57,8 +57,7 @@ std::string num2str(T a) {
...
@@ -57,8 +57,7 @@ std::string num2str(T a) {
bool
NativePaddlePredictor
::
Init
()
{
bool
NativePaddlePredictor
::
Init
()
{
VLOG
(
3
)
<<
"Predictor::init()"
;
VLOG
(
3
)
<<
"Predictor::init()"
;
// TODO(panyx0718): Should CPU vs GPU device be decided by id?
if
(
config_
.
use_gpu
)
{
if
(
config_
.
device
>=
0
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
}
else
{
}
else
{
place_
=
paddle
::
platform
::
CPUPlace
();
place_
=
paddle
::
platform
::
CPUPlace
();
...
@@ -85,11 +84,13 @@ bool NativePaddlePredictor::Init() {
...
@@ -85,11 +84,13 @@ bool NativePaddlePredictor::Init() {
}
}
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
// Create variables
// Create temporary variables first, so that the first batch do not need to
// TODO(panyx0718): Why need to test share_variables here?
// create variables in the runtime. This is the logics of the old inference
if
(
config_
.
share_variables
)
{
// API.
executor_
->
CreateVariables
(
*
inference_program_
,
scope_
.
get
(),
0
);
// TODO(Superjomn) this should be modified when `Clone` is valid for
}
// multi-thread application.
executor_
->
CreateVariables
(
*
inference_program_
,
scope_
.
get
(),
0
);
// Get the feed_target_names and fetch_target_names
// Get the feed_target_names and fetch_target_names
feed_target_names_
=
inference_program_
->
GetFeedTargetNames
();
feed_target_names_
=
inference_program_
->
GetFeedTargetNames
();
fetch_target_names_
=
inference_program_
->
GetFetchTargetNames
();
fetch_target_names_
=
inference_program_
->
GetFetchTargetNames
();
...
@@ -124,7 +125,7 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
...
@@ -124,7 +125,7 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
scope_
.
get
(),
scope_
.
get
(),
&
feed_targets
,
&
feed_targets
,
&
fetch_targets
,
&
fetch_targets
,
!
config_
.
share_variables
);
false
/* don't create variable eatch time */
);
if
(
!
GetFetch
(
fetchs
,
output_data
))
{
if
(
!
GetFetch
(
fetchs
,
output_data
))
{
LOG
(
ERROR
)
<<
"fail to get fetchs"
;
LOG
(
ERROR
)
<<
"fail to get fetchs"
;
return
false
;
return
false
;
...
@@ -242,11 +243,14 @@ bool NativePaddlePredictor::GetFetch(
...
@@ -242,11 +243,14 @@ bool NativePaddlePredictor::GetFetch(
template
<
>
template
<
>
std
::
unique_ptr
<
PaddlePredictor
>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
NativeConfig
,
Paddle
Predictor
::
EngineKind
::
kNative
>
(
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
const
NativeConfig
&
config
)
{
const
NativeConfig
&
config
)
{
VLOG
(
3
)
<<
"create NativePaddlePredictor"
;
VLOG
(
3
)
<<
"create NativePaddlePredictor"
;
if
(
config
.
use_gpu
)
{
if
(
config
.
use_gpu
)
{
// 1. GPU memeroy
// 1. GPU memeroy
PADDLE_ENFORCE
(
config
.
fraction_of_gpu_memory
>
0.
f
,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]"
);
std
::
vector
<
std
::
string
>
flags
;
std
::
vector
<
std
::
string
>
flags
;
if
(
config
.
fraction_of_gpu_memory
>=
0.0
f
||
if
(
config
.
fraction_of_gpu_memory
>=
0.0
f
||
config
.
fraction_of_gpu_memory
<=
0.95
f
)
{
config
.
fraction_of_gpu_memory
<=
0.95
f
)
{
...
...
paddle/contrib/inference/test_paddle_inference_api_impl.cc
浏览文件 @
18d64025
...
@@ -47,7 +47,6 @@ NativeConfig GetConfig() {
...
@@ -47,7 +47,6 @@ NativeConfig GetConfig() {
config
.
fraction_of_gpu_memory
=
0.15
;
config
.
fraction_of_gpu_memory
=
0.15
;
config
.
use_gpu
=
true
;
config
.
use_gpu
=
true
;
config
.
device
=
0
;
config
.
device
=
0
;
config
.
share_variables
=
true
;
return
config
;
return
config
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录