Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9df2d8b5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9df2d8b5
编写于
9月 05, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
9月 05, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test/add text-classification test (#13081)
上级
4907d093
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
276 addition
and
61 deletion
+276
-61
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+22
-10
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
+2
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+0
-1
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+3
-0
paddle/fluid/inference/analysis/CMakeLists.txt
paddle/fluid/inference/analysis/CMakeLists.txt
+15
-3
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+18
-19
paddle/fluid/inference/analysis/analyzer.h
paddle/fluid/inference/analysis/analyzer.h
+20
-7
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+7
-2
paddle/fluid/inference/analysis/flags.h
paddle/fluid/inference/analysis/flags.h
+22
-0
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
+6
-3
paddle/fluid/inference/analysis/test_text_classification.cc
paddle/fluid/inference/analysis/test_text_classification.cc
+109
-0
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+13
-1
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+17
-11
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+4
-2
paddle/fluid/inference/api/api_impl.cc
paddle/fluid/inference/api/api_impl.cc
+2
-1
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+15
-0
paddle/fluid/operators/lookup_table_op.h
paddle/fluid/operators/lookup_table_op.h
+1
-1
未找到文件。
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
9df2d8b5
...
...
@@ -86,15 +86,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
}
op_desc
.
SetInput
(
"Bias"
,
{
new_bias_var
});
}
#undef GET_NODE
// Create temp variables.
scope
->
Var
(
name_scope
+
"/BatchedInput.new"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
scope
->
Var
(
name_scope
+
"/BatchCellPreAct.new"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
scope
->
Var
(
name_scope
+
"/BatchedGate.new"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
op_desc
.
SetOutput
(
"Hidden"
,
{
hidden_n
->
Name
()});
op_desc
.
SetOutput
(
"Cell"
,
{
cell_n
->
Name
()});
op_desc
.
SetOutput
(
"XX"
,
{
xx_n
->
Name
()});
op_desc
.
SetOutput
(
"BatchedInput"
,
{
"blstm_0.tmp_2"
});
op_desc
.
SetOutput
(
"BatchedGate"
,
{
name_scope
+
"/BatchedGate.new"
});
op_desc
.
SetOutput
(
"BatchCellPreAct"
,
{
name_scope
+
"/BatchCellPreAct.new"
});
op_desc
.
SetOutput
(
"BatchedInput"
,
{
name_scope
+
"/BatchedInput.new"
});
op_desc
.
SetAttr
(
"is_reverse"
,
lstm_n
->
Op
()
->
GetAttr
(
"is_reverse"
));
op_desc
.
SetAttr
(
"use_peepholes"
,
lstm_n
->
Op
()
->
GetAttr
(
"use_peepholes"
));
// TODO(TJ): get from attr
...
...
@@ -130,8 +139,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
int
fusion_count
{
0
};
auto
fc_no_bias_handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
#define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \
...
...
@@ -152,21 +161,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
if
(
with_fc_bias
)
{
GET_NODE
(
fc_bias
);
GET_NODE
(
elementwise_add
);
lstm_creator
(
lstm
,
x
,
w
,
Weight
,
Bias
,
Hidden
,
Cell
,
fc_out
,
fc_bias
);
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
{
mul_n
,
lstm_n
,
elementwise_add_n
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
}
else
{
lstm_creator
(
lstm
,
x
,
w
,
Weight
,
Bias
,
Hidden
,
Cell
,
fc_out
,
-
1
);
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
mul_n
,
lstm_n
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
}
#undef GET_NODE
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
mul_n
,
lstm_n
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
fc_no_bias_
handler
);
gpd
(
graph
,
handler
);
return
fusion_count
;
}
...
...
paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
浏览文件 @
9df2d8b5
...
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
9df2d8b5
...
...
@@ -73,7 +73,6 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
void
GraphPatternDetector
::
operator
()(
Graph
*
graph
,
GraphPatternDetector
::
handle_t
handler
)
{
if
(
!
MarkPDNodesInGraph
(
*
graph
))
{
LOG
(
INFO
)
<<
"Mark failed"
;
return
;
}
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
9df2d8b5
...
...
@@ -19,6 +19,9 @@
#endif
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/dot.h"
...
...
paddle/fluid/inference/analysis/CMakeLists.txt
浏览文件 @
9df2d8b5
...
...
@@ -58,7 +58,7 @@ endif()
inference_analysis_test
(
test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
ARGS --infer_ditu_rnn_model=
${
DITU_INSTALL_DIR
}
/model
--infer_ditu_rnn_data=
${
DITU_INSTALL_DIR
}
/data.txt
)
--infer_ditu_rnn_data=
${
DITU_INSTALL_DIR
}
/data.txt
)
inference_analysis_test
(
test_data_flow_graph SRCS data_flow_graph_tester.cc
)
inference_analysis_test
(
test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc
)
...
...
@@ -74,7 +74,7 @@ inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
set
(
CHINESE_NER_MODEL_URL
"http://paddle-inference-dist.bj.bcebos.com/chinese_ner_model.tar.gz"
)
set
(
CHINESE_NER_DATA_URL
"http://paddle-inference-dist.bj.bcebos.com/chinese_ner-data.txt.tar.gz"
)
set
(
CHINESE_NER_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/inference_demo/chinese_ner"
CACHE PATH
"Chinese ner model and data root."
FORCE
)
if
(
NOT EXISTS
${
CHINESE_NER_INSTALL_DIR
}
AND WITH_TESTING
)
if
(
NOT EXISTS
${
CHINESE_NER_INSTALL_DIR
}
AND WITH_TESTING
AND WITH_INFERENCE
)
inference_download_and_uncompress
(
${
CHINESE_NER_INSTALL_DIR
}
${
CHINESE_NER_MODEL_URL
}
"chinese_ner_model.tar.gz"
)
inference_download_and_uncompress
(
${
CHINESE_NER_INSTALL_DIR
}
${
CHINESE_NER_DATA_URL
}
"chinese_ner-data.txt.tar.gz"
)
endif
()
...
...
@@ -87,7 +87,7 @@ inference_analysis_test(test_analyzer_ner SRCS analyzer_ner_tester.cc
set
(
LAC_MODEL_URL
"http://paddle-inference-dist.bj.bcebos.com/lac_model.tar.gz"
)
set
(
LAC_DATA_URL
"http://paddle-inference-dist.bj.bcebos.com/lac_data.txt.tar.gz"
)
set
(
LAC_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/inference_demo/lac"
CACHE PATH
"LAC model and data root."
FORCE
)
if
(
NOT EXISTS
${
LAC_INSTALL_DIR
}
AND WITH_TESTING
)
if
(
NOT EXISTS
${
LAC_INSTALL_DIR
}
AND WITH_TESTING
AND WITH_INFERENCE
)
inference_download_and_uncompress
(
${
LAC_INSTALL_DIR
}
${
LAC_MODEL_URL
}
"lac_model.tar.gz"
)
inference_download_and_uncompress
(
${
LAC_INSTALL_DIR
}
${
LAC_DATA_URL
}
"lac_data.txt.tar.gz"
)
endif
()
...
...
@@ -96,3 +96,15 @@ inference_analysis_test(test_analyzer_lac SRCS analyzer_lac_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api
ARGS --infer_model=
${
LAC_INSTALL_DIR
}
/model
--infer_data=
${
LAC_INSTALL_DIR
}
/data.txt
)
set
(
TEXT_CLASSIFICATION_MODEL_URL
"http://paddle-inference-dist.bj.bcebos.com/text-classification-Senta.tar.gz"
)
set
(
TEXT_CLASSIFICATION_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/inference_demo/text_classification"
CACHE PATH
"Text Classification model and data root."
FORCE
)
if
(
NOT EXISTS
${
TEXT_CLASSIFICATION_INSTALL_DIR
}
AND WITH_TESTING AND WITH_INFERENCE
)
inference_download_and_uncompress
(
${
TEXT_CLASSIFICATION_INSTALL_DIR
}
${
TEXT_CLASSIFICATION_MODEL_URL
}
"text-classification-Senta.tar.gz"
)
endif
()
inference_analysis_test
(
test_text_classification SRCS test_text_classification.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api analysis_predictor
ARGS --infer_model=
${
TEXT_CLASSIFICATION_INSTALL_DIR
}
/text-classification-Senta
)
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
9df2d8b5
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
...
...
@@ -41,20 +42,16 @@ class DfgPassManagerImpl final : public DfgPassManager {
public:
DfgPassManagerImpl
()
{
// TODO(Superjomn) set the key with pass reprs.
LOG
(
INFO
)
<<
"-----------------------------------------------------------------"
;
if
(
FLAGS_IA_enable_ir
)
{
AddPass
(
"fluid-to-ir-pass"
,
new
FluidToIrPass
);
}
else
{
if
(
!
FLAGS_IA_enable_ir
)
{
AddPass
(
"fluid-to-data-flow-graph"
,
new
FluidToDataFlowGraphPass
);
}
else
{
AddPass
(
"fluid-to-ir-pass"
,
new
FluidToIrPass
);
}
TryAddTensorRtPass
();
AddPass
(
"data-flow-graph-to-fluid"
,
new
DataFlowGraphToFluidPass
);
if
(
!
FLAGS_IA_output_storage_path
.
empty
())
{
AddPass
(
"model-store-pass"
,
new
ModelStorePass
);
}
LOG
(
INFO
)
<<
"-----------------------------------------------------------------"
;
}
std
::
string
repr
()
const
override
{
return
"dfg-pass-manager"
;
}
...
...
@@ -101,19 +98,16 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer
::
Analyzer
()
{
Register
(
"manager1"
,
new
DfgPassManagerImpl
);
}
void
Analyzer
::
Run
(
Argument
*
argument
)
{
std
::
vector
<
std
::
string
>
passes
;
for
(
auto
&
pass
:
all_ir_passes_
)
{
if
(
!
disabled_ir_passes_
.
count
(
pass
))
{
passes
.
push_back
(
pass
);
passes
.
push_back
(
"graph_viz_pass"
);
// add graphviz for debug.
}
}
passes
.
push_back
(
"graph_viz_pass"
);
// Ugly support fluid-to-ir-pass
argument
->
Set
(
kFluidToIrPassesAttr
,
new
std
::
vector
<
std
::
string
>
({
// Manual update the passes here.
"graph_viz_pass"
,
//
"infer_clean_graph_pass"
,
"graph_viz_pass"
,
//
"attention_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"fc_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"mul_lstm_fuse_pass"
,
"graph_viz_pass"
,
//
"seq_concat_fc_fuse_pass"
,
"graph_viz_pass"
,
//
"fc_fuse_pass"
,
"graph_viz_pass"
//
}));
argument
->
Set
(
kFluidToIrPassesAttr
,
new
std
::
vector
<
std
::
string
>
(
passes
));
for
(
auto
&
x
:
data_
)
{
PADDLE_ENFORCE
(
x
->
Initialize
(
argument
));
...
...
@@ -122,6 +116,11 @@ void Analyzer::Run(Argument* argument) {
}
}
Analyzer
&
Analyzer
::
DisableIrPasses
(
const
std
::
vector
<
std
::
string
>&
passes
)
{
disabled_ir_passes_
.
insert
(
passes
.
begin
(),
passes
.
end
());
return
*
this
;
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/analysis/analyzer.h
浏览文件 @
9df2d8b5
...
...
@@ -36,16 +36,10 @@ limitations under the License. */
*/
#include <gflags/gflags.h>
#include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available.
DECLARE_bool
(
IA_enable_tensorrt_subgraph_engine
);
DECLARE_string
(
IA_graphviz_log_root
);
DECLARE_string
(
IA_output_storage_path
);
DECLARE_bool
(
IA_enable_ir
);
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
...
...
@@ -57,7 +51,26 @@ class Analyzer : public OrderedRegistry<PassManager> {
void
Run
(
Argument
*
argument
);
Analyzer
&
DisableIrPasses
(
const
std
::
vector
<
std
::
string
>&
passes
);
DISABLE_COPY_AND_ASSIGN
(
Analyzer
);
private:
// All avaiable IR passes.
// The bigger fuse comes first, so that the small operators prefer to be
// merged in a larger fuse op. The small fusion will not break the pattern of
// larger fusion.
const
std
::
vector
<
std
::
string
>
all_ir_passes_
{{
// Manual update the passes here.
"infer_clean_graph_pass"
,
//
"attention_lstm_fuse_pass"
,
//
"fc_lstm_fuse_pass"
,
//
"mul_lstm_fuse_pass"
,
//
"seq_concat_fc_fuse_pass"
,
//
"fc_fuse_pass"
,
//
}};
std
::
unordered_set
<
std
::
string
>
disabled_ir_passes_
;
};
}
// namespace analysis
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
9df2d8b5
...
...
@@ -271,17 +271,22 @@ void TestDituRNNPrediction(const std::string &model_path,
const
std
::
string
&
data_path
,
int
batch_size
,
bool
use_analysis
,
bool
activate_ir
,
int
num_times
=
1
)
{
Native
Config
config
;
Analysis
Config
config
;
config
.
prog_file
=
FLAGS_infer_ditu_rnn_model
+
"/__model__"
;
config
.
param_file
=
FLAGS_infer_ditu_rnn_model
+
"/param"
;
config
.
use_gpu
=
false
;
config
.
device
=
0
;
config
.
specify_input_name
=
true
;
config
.
enable_ir_optim
=
activate_ir
;
PADDLE_ENFORCE
(
config
.
ir_mode
==
AnalysisConfig
::
IrPassMode
::
kExclude
);
// default
config
.
ir_passes
.
clear
();
// Do not exclude any pass.
auto
base_predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
auto
predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kAnalysis
>
(
config
);
CreatePaddlePredictor
<
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
config
);
std
::
vector
<
PaddleTensor
>
input_slots
;
DataRecord
data
(
data_path
,
batch_size
);
// Prepare inputs.
...
...
paddle/fluid/inference/analysis/flags.h
0 → 100644
浏览文件 @
9df2d8b5
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available.
DECLARE_bool
(
IA_enable_tensorrt_subgraph_engine
);
DECLARE_string
(
IA_graphviz_log_root
);
DECLARE_string
(
IA_output_storage_path
);
DECLARE_bool
(
IA_enable_ir
);
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
浏览文件 @
9df2d8b5
...
...
@@ -15,6 +15,7 @@
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
...
...
@@ -85,9 +86,11 @@ class FluidToIrPass final : public DataFlowGraphPass {
new
Scope
*
(
&
argument_
->
Get
<
Scope
>
(
ir
::
kParamScopeAttr
)));
}
const
auto
&
ir_passes_to_apply
=
argument_
->
Get
<
std
::
vector
<
std
::
string
>>
(
kFluidToIrPassesAttr
);
ir_passes
.
Apply
(
ir_passes_to_apply
);
if
(
FLAGS_IA_enable_ir
)
{
const
auto
&
ir_passes_to_apply
=
argument_
->
Get
<
std
::
vector
<
std
::
string
>>
(
kFluidToIrPassesAttr
);
ir_passes
.
Apply
(
ir_passes_to_apply
);
}
PADDLE_ENFORCE
(
argument_
->
main_dfg
.
get
());
argument_
->
main_dfg
->
Build
(
ir_passes
.
graph
());
...
...
paddle/fluid/inference/analysis/test_text_classification.cc
0 → 100644
浏览文件 @
9df2d8b5
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <glog/logging.h> // use glog instead of PADDLE_ENFORCE to avoid importing other paddle header files.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/timer.h"
DEFINE_string
(
infer_model
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
infer_data
,
""
,
"Path of the dataset."
);
DEFINE_int32
(
batch_size
,
1
,
"batch size."
);
DEFINE_int32
(
repeat
,
1
,
"How many times to repeat run."
);
namespace
paddle
{
template
<
typename
T
>
std
::
string
to_string
(
const
std
::
vector
<
T
>
&
vec
)
{
std
::
stringstream
ss
;
for
(
const
auto
&
c
:
vec
)
{
ss
<<
c
<<
" "
;
}
return
ss
.
str
();
}
void
PrintTime
(
const
double
latency
,
const
int
bs
,
const
int
repeat
)
{
LOG
(
INFO
)
<<
"===========profile result==========="
;
LOG
(
INFO
)
<<
"batch_size: "
<<
bs
<<
", repeat: "
<<
repeat
<<
", avg latency: "
<<
latency
/
repeat
<<
"ms"
;
LOG
(
INFO
)
<<
"====================================="
;
}
void
Main
(
int
batch_size
)
{
// Three sequence inputs.
std
::
vector
<
PaddleTensor
>
input_slots
(
1
);
// one batch starts
// data --
int64_t
data0
[]
=
{
0
,
1
,
2
};
for
(
auto
&
input
:
input_slots
)
{
input
.
data
.
Reset
(
data0
,
sizeof
(
data0
));
input
.
shape
=
std
::
vector
<
int
>
({
3
,
1
});
// dtype --
input
.
dtype
=
PaddleDType
::
INT64
;
// LoD --
input
.
lod
=
std
::
vector
<
std
::
vector
<
size_t
>>
({{
0
,
3
}});
}
// shape --
// Create Predictor --
AnalysisConfig
config
;
config
.
model_dir
=
FLAGS_infer_model
;
config
.
use_gpu
=
false
;
config
.
enable_ir_optim
=
true
;
config
.
ir_passes
.
push_back
(
"fc_lstm_fuse_pass"
);
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
config
);
inference
::
Timer
timer
;
double
sum
=
0
;
std
::
vector
<
PaddleTensor
>
output_slots
;
for
(
int
i
=
0
;
i
<
FLAGS_repeat
;
i
++
)
{
timer
.
tic
();
CHECK
(
predictor
->
Run
(
input_slots
,
&
output_slots
));
sum
+=
timer
.
toc
();
}
PrintTime
(
sum
,
batch_size
,
FLAGS_repeat
);
// Get output
LOG
(
INFO
)
<<
"get outputs "
<<
output_slots
.
size
();
for
(
auto
&
output
:
output_slots
)
{
LOG
(
INFO
)
<<
"output.shape: "
<<
to_string
(
output
.
shape
);
// no lod ?
CHECK_EQ
(
output
.
lod
.
size
(),
0UL
);
LOG
(
INFO
)
<<
"output.dtype: "
<<
output
.
dtype
;
std
::
stringstream
ss
;
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
ss
<<
static_cast
<
float
*>
(
output
.
data
.
data
())[
i
]
<<
" "
;
}
LOG
(
INFO
)
<<
"output.data summary: "
<<
ss
.
str
();
// one batch ends
}
}
TEST
(
text_classification
,
basic
)
{
Main
(
FLAGS_batch_size
);
}
}
// namespace paddle
USE_PASS
(
fc_fuse_pass
);
USE_PASS
(
seq_concat_fc_fuse_pass
);
USE_PASS
(
fc_lstm_fuse_pass
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
infer_clean_graph_pass
);
USE_PASS
(
attention_lstm_fuse_pass
);
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
9df2d8b5
...
...
@@ -44,7 +44,19 @@ function(inference_api_test TARGET_NAME)
endfunction
(
inference_api_test
)
cc_library
(
paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor
)
cc_library
(
analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api
)
cc_library
(
analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api
analysis
ir_pass_manager
pass
fc_fuse_pass
fc_lstm_fuse_pass
seq_concat_fc_fuse_pass
graph_viz_pass
infer_clean_graph_pass
graph_pattern_detector
infer_clean_graph_pass
attention_lstm_fuse_pass
)
cc_test
(
test_paddle_inference_api
SRCS api_tester.cc
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
9df2d8b5
...
...
@@ -14,6 +14,8 @@
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
...
...
@@ -28,6 +30,8 @@ bool AnalysisPredictor::Init(
VLOG
(
3
)
<<
"Predictor::init()"
;
if
(
config_
.
use_gpu
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
LOG
(
WARNING
)
<<
"ir optimize only supports CPU currently"
;
config_
.
enable_ir_optim
=
false
;
}
else
{
place_
=
paddle
::
platform
::
CPUPlace
();
}
...
...
@@ -73,7 +77,7 @@ bool AnalysisPredictor::Init(
void
AnalysisPredictor
::
OptimizeInferenceProgram
()
{
LOG
(
INFO
)
<<
"optimize begin"
;
FLAGS_IA_enable_ir
=
true
;
FLAGS_IA_enable_ir
=
config_
.
enable_ir_optim
;
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_IA_output_storage_path
=
""
;
// Don't output the model.
// Analyze inference_program
...
...
@@ -90,24 +94,26 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
}
argument_
.
origin_program_desc
.
reset
(
new
ProgramDesc
(
*
inference_program_
->
Proto
()));
Analyzer
().
Run
(
&
argument_
);
PADDLE_ENFORCE
(
config_
.
ir_mode
==
AnalysisConfig
::
IrPassMode
::
kExclude
,
"Only kExclude is supported yet."
);
Analyzer
().
DisableIrPasses
(
config_
.
ir_passes
).
Run
(
&
argument_
);
CHECK
(
argument_
.
transformed_program_desc
);
VLOG
(
5
)
<<
"to prepare executor"
;
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_
.
reset
(
new
framework
::
ProgramDesc
(
*
argument_
.
transformed_program_desc
));
PADDLE_ENFORCE
(
argument_
.
Has
(
framework
::
ir
::
kParamScopeAttr
));
// Update scope.
scope_
.
reset
(
argument_
.
Release
<
framework
::
Scope
>
(
framework
::
ir
::
kParamScopeAttr
));
LOG
(
INFO
)
<<
"optimize end =="
;
if
(
argument_
.
Has
(
framework
::
ir
::
kParamScopeAttr
))
{
// Update scope.
scope_
.
reset
(
argument_
.
Release
<
framework
::
Scope
>
(
framework
::
ir
::
kParamScopeAttr
));
}
LOG
(
INFO
)
<<
"== optimize end =="
;
}
template
<
>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kAnalysis
>
(
const
Native
Config
&
config
)
{
VLOG
(
3
)
<<
"create
NativePredictor
"
;
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
const
Analysis
Config
&
config
)
{
VLOG
(
3
)
<<
"create
AnalysisConfig
"
;
if
(
config
.
use_gpu
)
{
// 1. GPU memeroy
PADDLE_ENFORCE_GT
(
...
...
paddle/fluid/inference/api/analysis_predictor.h
浏览文件 @
9df2d8b5
...
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
...
...
@@ -28,7 +30,7 @@ using framework::proto::ProgramDesc;
*/
class
AnalysisPredictor
:
public
NativePaddlePredictor
{
public:
explicit
AnalysisPredictor
(
const
Native
Config
&
config
)
explicit
AnalysisPredictor
(
const
Analysis
Config
&
config
)
:
NativePaddlePredictor
(
config
),
config_
(
config
)
{}
bool
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
);
...
...
@@ -44,7 +46,7 @@ class AnalysisPredictor : public NativePaddlePredictor {
Argument
&
analysis_argument
()
{
return
argument_
;
}
private:
Native
Config
config_
;
Analysis
Config
config_
;
Argument
argument_
;
};
...
...
paddle/fluid/inference/api/api_impl.cc
浏览文件 @
9df2d8b5
...
...
@@ -176,7 +176,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
framework
::
Scope
*
scope
)
{
VLOG
(
3
)
<<
"Predictor::set_feed"
;
if
(
inputs
.
size
()
!=
feeds_
.
size
())
{
LOG
(
ERROR
)
<<
"wrong feed input size."
;
LOG
(
ERROR
)
<<
"wrong feed input size, need "
<<
feeds_
.
size
()
<<
" but get "
<<
inputs
.
size
();
return
false
;
}
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
...
...
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
9df2d8b5
...
...
@@ -150,6 +150,21 @@ struct TensorRTConfig : public NativeConfig {
int
workspace_size
{
1
<<
30
};
};
// NOTE WIP, not stable yet.
struct
AnalysisConfig
:
public
NativeConfig
{
//
enum
class
IrPassMode
{
kSystem
,
// Use system default passes, not customize.
kInclude
,
// Specify the passes in `ir_passes`.
kExclude
// Specify the disabled passes in `ir_passes`.
};
bool
enable_ir_optim
=
true
;
IrPassMode
ir_mode
{
IrPassMode
::
kExclude
};
// attention lstm fuse works only on some specific models, disable as default.
std
::
vector
<
std
::
string
>
ir_passes
{
"attention_lstm_fuse_pass"
};
};
// A factory to help create different predictors.
//
// FOR EXTENSION DEVELOPER:
...
...
paddle/fluid/operators/lookup_table_op.h
浏览文件 @
9df2d8b5
...
...
@@ -57,7 +57,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
memset
(
output
+
i
*
row_width
,
0
,
row_width
*
sizeof
(
T
));
}
else
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
row_number
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
,
"ids %d"
,
i
);
memcpy
(
output
+
i
*
row_width
,
table
+
ids
[
i
]
*
row_width
,
row_width
*
sizeof
(
T
));
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录