Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
c73c5ed5
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c73c5ed5
编写于
8月 22, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
差异文件
use for_range
上级
b548ecbc
e8b4e0d6
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
1123 addition
and
260 deletion
+1123
-260
cmake/external/anakin.cmake
cmake/external/anakin.cmake
+23
-6
doc/fluid/dev/new_op_cn.md
doc/fluid/dev/new_op_cn.md
+33
-1
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+1
-1
paddle/fluid/inference/analysis/CMakeLists.txt
paddle/fluid/inference/analysis/CMakeLists.txt
+33
-12
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+1
-2
paddle/fluid/inference/analysis/analyzer.h
paddle/fluid/inference/analysis/analyzer.h
+1
-2
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+263
-3
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+11
-5
paddle/fluid/inference/api/api.cc
paddle/fluid/inference/api/api.cc
+0
-3
paddle/fluid/inference/api/api_anakin_engine.cc
paddle/fluid/inference/api/api_anakin_engine.cc
+114
-23
paddle/fluid/inference/api/api_anakin_engine.h
paddle/fluid/inference/api/api_anakin_engine.h
+2
-4
paddle/fluid/inference/api/api_anakin_engine_rnn_tester.cc
paddle/fluid/inference/api/api_anakin_engine_rnn_tester.cc
+315
-0
paddle/fluid/inference/api/api_impl.cc
paddle/fluid/inference/api/api_impl.cc
+5
-2
paddle/fluid/inference/api/helper.h
paddle/fluid/inference/api/helper.h
+110
-0
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+4
-2
paddle/fluid/operators/mul_op.cc
paddle/fluid/operators/mul_op.cc
+3
-3
paddle/fluid/operators/stack_op.cc
paddle/fluid/operators/stack_op.cc
+4
-42
paddle/fluid/operators/stack_op.cu
paddle/fluid/operators/stack_op.cu
+4
-88
paddle/fluid/operators/stack_op.h
paddle/fluid/operators/stack_op.h
+116
-40
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+3
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+77
-16
python/paddle/fluid/tests/unittests/test_profiler.py
python/paddle/fluid/tests/unittests/test_profiler.py
+0
-5
未找到文件。
cmake/external/anakin.cmake
浏览文件 @
c73c5ed5
...
...
@@ -2,6 +2,11 @@ if (NOT WITH_ANAKIN)
return
()
endif
()
option
(
ANAKIN_ENABLE_OP_TIMER
"Get more detailed information with Anakin op time"
OFF
)
if
(
ANAKIN_ENABLE_OP_TIMER
)
add_definitions
(
-DPADDLE_ANAKIN_ENABLE_OP_TIMER
)
endif
()
INCLUDE
(
ExternalProject
)
set
(
ANAKIN_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/anakin
)
# the anakin install dir is only default one now
...
...
@@ -11,23 +16,34 @@ set(ANAKIN_LIBRARY ${ANAKIN_INSTALL_DIR})
set
(
ANAKIN_SHARED_LIB
${
ANAKIN_LIBRARY
}
/libanakin.so
)
set
(
ANAKIN_SABER_LIB
${
ANAKIN_LIBRARY
}
/libanakin_saber_common.so
)
# TODO(luotao): ANAKIN_MODLE_URL will move to demo ci later.
set
(
ANAKIN_MODLE_URL
"http://paddle-inference-dist.bj.bcebos.com/mobilenet_v2.anakin.bin"
)
# TODO(luotao): ANAKIN_MODLE_URL etc will move to demo ci later.
set
(
INFERENCE_URL
"http://paddle-inference-dist.bj.bcebos.com"
)
set
(
ANAKIN_MODLE_URL
"
${
INFERENCE_URL
}
/mobilenet_v2.anakin.bin"
)
set
(
ANAKIN_RNN_MODLE_URL
"
${
INFERENCE_URL
}
/anakin_test%2Fditu_rnn.anakin2.model.bin"
)
set
(
ANAKIN_RNN_DATA_URL
"
${
INFERENCE_URL
}
/anakin_test%2Fditu_rnn_data.txt"
)
execute_process
(
COMMAND bash -c
"mkdir -p
${
ANAKIN_SOURCE_DIR
}
"
)
execute_process
(
COMMAND bash -c
"cd
${
ANAKIN_SOURCE_DIR
}
; wget -q --no-check-certificate
${
ANAKIN_MODLE_URL
}
"
)
execute_process
(
COMMAND bash -c
"cd
${
ANAKIN_SOURCE_DIR
}
; wget -q --no-check-certificate
${
ANAKIN_MODLE_URL
}
-N"
)
execute_process
(
COMMAND bash -c
"cd
${
ANAKIN_SOURCE_DIR
}
; wget -q --no-check-certificate
${
ANAKIN_RNN_MODLE_URL
}
-N"
)
execute_process
(
COMMAND bash -c
"cd
${
ANAKIN_SOURCE_DIR
}
; wget -q --no-check-certificate
${
ANAKIN_RNN_DATA_URL
}
-N"
)
include_directories
(
${
ANAKIN_INCLUDE
}
)
include_directories
(
${
ANAKIN_INCLUDE
}
/saber/
)
include_directories
(
${
ANAKIN_INCLUDE
}
/saber/core/
)
include_directories
(
${
ANAKIN_INCLUDE
}
/saber/funcs/impl/x86/
)
include_directories
(
${
ANAKIN_INCLUDE
}
/saber/funcs/impl/cuda/base/cuda_c/
)
set
(
ANAKIN_COMPILE_EXTRA_FLAGS
-Wno-error=unused-but-set-variable -Wno-unused-but-set-variable
-Wno-error=unused-variable -Wno-unused-variable
-Wno-error=format-extra-args -Wno-format-extra-args
-Wno-error=comment -Wno-comment
-Wno-error=format -Wno-format
-Wno-error=comment -Wno-comment
-Wno-error=format -Wno-format
-Wno-error=maybe-uninitialized -Wno-maybe-uninitialized
-Wno-error=switch -Wno-switch
-Wno-error=return-type -Wno-return-type
-Wno-error=non-virtual-dtor -Wno-non-virtual-dtor
-Wno-error=ignored-qualifiers
-Wno-ignored-qualifiers
-Wno-sign-compare
-Wno-reorder
-Wno-error=cpp
)
...
...
@@ -38,7 +54,7 @@ ExternalProject_Add(
DEPENDS
${
MKLML_PROJECT
}
# Anakin codes error on Intel(R) Xeon(R) Gold 5117 CPU, temporary do not compile avx512 related code.
GIT_REPOSITORY
"https://github.com/luotao1/Anakin"
GIT_TAG
"
bcf17aabe7921ceb7bce591244b4f9dce7dba5c8
"
GIT_TAG
"
211d1fc5d813d70c0c14072f9083cf25f40940ea
"
PREFIX
${
ANAKIN_SOURCE_DIR
}
UPDATE_COMMAND
""
CMAKE_ARGS -DUSE_GPU_PLACE=YES
...
...
@@ -48,6 +64,7 @@ ExternalProject_Add(
-DMKLML_ROOT=
${
THIRD_PARTY_PATH
}
/install/mklml
-DCUDNN_ROOT=
${
CUDNN_ROOT
}
-DCUDNN_INCLUDE_DIR=
${
CUDNN_INCLUDE_DIR
}
-DENABLE_OP_TIMER=
${
ANAKIN_ENABLE_OP_TIMER
}
${
EXTERNAL_OPTIONAL_ARGS
}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=
${
ANAKIN_INSTALL_DIR
}
)
...
...
doc/fluid/dev/new_op_cn.md
浏览文件 @
c73c5ed5
...
...
@@ -119,10 +119,29 @@ $$Out = scale*X$$
这个例子有
`AddAttr<AttrType>("scale", "...").SetDefault(1.0);`
: 增加
`scale`
系数,作为参数属性,并且设置默认值为1.0。
### 定义GradProtoMaker类
每个Op的必须有一个对应的GraProtoMaker,若未定制对应前向Op的GradProtoMaker,fluid提供了DefaultGradProtoMaker,默认注册会使用全部输入输出,包括Input, Output, Output@Grad等,使用不需要的变量的会造成显存浪费。
下面示例定义了ScaleOp的GradProtoMaker。
```
cpp
class
ScaleGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetAttr
(
"scale"
,
GetAttr
(
"scale"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
};
```
### 定义Operator类
下面
的点
实现了MulOp的定义:
下面实现了MulOp的定义:
```
cpp
class
MulOp
:
public
framework
::
OperatorWithKernel
{
...
...
@@ -383,6 +402,19 @@ PADDLE_ENFORCE(forward_pd != nullptr,
"Fail to find eltwise_fwd_pd in device context"); //eltwise_fwd_pd用户可能看不懂
```
3.
OP内部调用非法接口:Op内部如果出现Output = ShareDataWith(Input)
问题示例:
```
cpp
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
out
->
ShareDataWith
(
*
in
);
```
Op内部如果出现Output = ShareDataWith(Input),相当于operator图的中有一条隐藏边,连接了Input和Output,这条边无法在图分析中表达,引发基于图优化的错误。
4.
OP实现的性能实践
调用了eigen的broadcast, chop等操作,性能会比手写cuda kernel差几倍以上。此时cpu的实现可以复用eigen,gpu实现可以实现cuda kernel.
#### OP InferShape检查提示信息特别说明
-
检查输入输出变量,请统一遵循以下格式
...
...
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
c73c5ed5
...
...
@@ -104,7 +104,7 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
for
(
auto
&
adj_n
:
var
->
inputs
)
{
PADDLE_ENFORCE
(
adj_n
->
NodeType
()
==
ir
::
Node
::
Type
::
kOperation
);
adj_list
[
n
].
insert
(
adj_n
);
VLOG
(
3
)
<<
"adj "
<<
adj_n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
adj_n
)
VLOG
(
4
)
<<
"adj "
<<
adj_n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
adj_n
)
<<
" -> "
<<
n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
n
)
<<
" via "
<<
var
->
Name
()
<<
reinterpret_cast
<
void
*>
(
var
);
}
...
...
paddle/fluid/inference/analysis/CMakeLists.txt
浏览文件 @
c73c5ed5
...
...
@@ -22,7 +22,7 @@ function (inference_analysis_test TARGET)
if
(
WITH_TESTING
)
set
(
options
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs SRCS
)
set
(
multiValueArgs SRCS
EXTRA_DEPS
)
cmake_parse_arguments
(
analysis_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
set
(
mem_opt
""
)
...
...
@@ -31,22 +31,43 @@ function (inference_analysis_test TARGET)
endif
()
cc_test
(
${
TARGET
}
SRCS
"
${
analysis_test_SRCS
}
"
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass
${
analysis_test_EXTRA_DEPS
}
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
${
mem_opt
}
)
set_tests_properties
(
${
TARGET
}
PROPERTIES DEPENDS test_word2vec
)
endif
(
WITH_TESTING
)
endfunction
(
inference_analysis_test
)
cc_test
(
test_analyzer SRCS analyzer_tester.cc DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
# ir
fc_fuse_pass
graph_viz_pass
infer_clean_graph_pass
graph_pattern_detecter
pass
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
)
#set_tests_properties(test_analyzer PROPERTIES DEPENDS test_word2vec)
#inference_api_test(test_analyzer SRC analyzer_tester.cc ARGS test_word2vec)
set
(
DITU_RNN_MODEL_URL
"http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fmodel.tar.gz"
)
set
(
DITU_RNN_DATA_URL
"http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fdata.txt.tar.gz"
)
set
(
DITU_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/install/ditu_rnn"
CACHE PATH
"Ditu RNN model and data root."
FORCE
)
set
(
DITU_RNN_MODEL
${
DITU_INSTALL_DIR
}
/model
)
set
(
DITU_RNN_DATA
${
DITU_INSTALL_DIR
}
/data.txt
)
function
(
inference_download_and_uncompress target url gz_filename
)
message
(
STATUS
"Download inference test stuff
${
gz_filename
}
from
${
url
}
"
)
execute_process
(
COMMAND bash -c
"mkdir -p
${
DITU_INSTALL_DIR
}
"
)
execute_process
(
COMMAND bash -c
"cd
${
DITU_INSTALL_DIR
}
&& wget -q
${
url
}
"
)
execute_process
(
COMMAND bash -c
"cd
${
DITU_INSTALL_DIR
}
&& tar xzf
${
gz_filename
}
"
)
message
(
STATUS
"finish downloading
${
gz_filename
}
"
)
endfunction
(
inference_download_and_uncompress
)
if
(
NOT EXISTS
${
DITU_INSTALL_DIR
}
)
inference_download_and_uncompress
(
ditu_rnn_model
${
DITU_RNN_MODEL_URL
}
"ditu_rnn_fluid%2Fmodel.tar.gz"
)
inference_download_and_uncompress
(
ditu_rnn_data
${
DITU_RNN_DATA_URL
}
"ditu_rnn_fluid%2Fdata.txt.tar.gz"
)
endif
()
inference_analysis_test
(
test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
# ir
fc_fuse_pass
graph_viz_pass
infer_clean_graph_pass
graph_pattern_detecter
infer_clean_graph_pass
pass
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
--infer_ditu_rnn_model=
${
DITU_INSTALL_DIR
}
/model
--infer_ditu_rnn_data=
${
DITU_INSTALL_DIR
}
/data.txt
)
inference_analysis_test
(
test_data_flow_graph SRCS data_flow_graph_tester.cc
)
inference_analysis_test
(
test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc
)
...
...
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
c73c5ed5
...
...
@@ -23,8 +23,6 @@
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
namespace
paddle
{
DEFINE_bool
(
IA_enable_tensorrt_subgraph_engine
,
false
,
"Enable subgraph to TensorRT engine for acceleration"
);
...
...
@@ -35,6 +33,7 @@ DEFINE_string(IA_graphviz_log_root, "./",
DEFINE_string
(
IA_output_storage_path
,
""
,
"optimized model output path"
);
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
...
...
paddle/fluid/inference/analysis/analyzer.h
浏览文件 @
c73c5ed5
...
...
@@ -39,8 +39,6 @@ limitations under the License. */
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
namespace
paddle
{
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available.
DECLARE_bool
(
IA_enable_tensorrt_subgraph_engine
);
...
...
@@ -48,6 +46,7 @@ DECLARE_string(IA_graphviz_log_root);
DECLARE_string
(
IA_output_storage_path
);
DECLARE_bool
(
IA_enable_ir
);
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
c73c5ed5
...
...
@@ -13,11 +13,17 @@
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
DEFINE_string
(
infer_ditu_rnn_data
,
""
,
"data path for ditu RNN"
);
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
...
...
@@ -38,7 +44,7 @@ TEST(Analyzer, analysis_with_tensorrt) {
analyser
.
Run
(
&
argument
);
}
void
TestWord2vecPrediction
(
const
std
::
string
&
model_path
)
{
void
TestWord2vecPrediction
(
const
std
::
string
&
model_path
)
{
NativeConfig
config
;
config
.
model_dir
=
model_path
;
config
.
use_gpu
=
false
;
...
...
@@ -69,12 +75,245 @@ void TestWord2vecPrediction(const std::string& model_path) {
// The outputs' buffers are in CPU memory.
for
(
size_t
i
=
0
;
i
<
std
::
min
(
5UL
,
num_elements
);
i
++
)
{
LOG
(
INFO
)
<<
"data: "
<<
static_cast
<
float
*>
(
outputs
.
front
().
data
.
data
())[
i
];
PADDLE_ENFORCE
(
static_cast
<
float
*>
(
outputs
.
front
().
data
.
data
())[
i
],
<<
static_cast
<
float
*>
(
outputs
.
front
().
data
.
data
())[
i
];
PADDLE_ENFORCE
(
static_cast
<
float
*>
(
outputs
.
front
().
data
.
data
())[
i
],
result
[
i
]);
}
}
namespace
{
struct
DataRecord
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>>>
link_step_data_all
;
std
::
vector
<
std
::
vector
<
float
>>
week_data_all
,
minute_data_all
;
std
::
vector
<
size_t
>
lod1
,
lod2
,
lod3
;
std
::
vector
<
std
::
vector
<
float
>>
rnn_link_data
,
rnn_week_datas
,
rnn_minute_datas
;
size_t
batch_iter
{
0
};
size_t
batch_size
{
1
};
DataRecord
()
=
default
;
DataRecord
(
const
std
::
string
&
path
,
int
batch_size
=
1
)
:
batch_size
(
batch_size
)
{
Load
(
path
);
}
DataRecord
NextBatch
()
{
DataRecord
data
;
size_t
batch_end
=
batch_iter
+
batch_size
;
// NOTE skip the final batch, if no enough data is provided.
if
(
batch_end
<=
link_step_data_all
.
size
())
{
data
.
link_step_data_all
.
assign
(
link_step_data_all
.
begin
()
+
batch_iter
,
link_step_data_all
.
begin
()
+
batch_end
);
data
.
week_data_all
.
assign
(
week_data_all
.
begin
()
+
batch_iter
,
week_data_all
.
begin
()
+
batch_end
);
data
.
minute_data_all
.
assign
(
minute_data_all
.
begin
()
+
batch_iter
,
minute_data_all
.
begin
()
+
batch_end
);
// Prepare LoDs
data
.
lod1
.
push_back
(
0
);
data
.
lod2
.
push_back
(
0
);
data
.
lod3
.
push_back
(
0
);
CHECK
(
!
data
.
link_step_data_all
.
empty
())
<<
"empty"
;
CHECK
(
!
data
.
week_data_all
.
empty
());
CHECK
(
!
data
.
minute_data_all
.
empty
());
CHECK_EQ
(
data
.
link_step_data_all
.
size
(),
data
.
week_data_all
.
size
());
CHECK_EQ
(
data
.
minute_data_all
.
size
(),
data
.
link_step_data_all
.
size
());
for
(
size_t
j
=
0
;
j
<
data
.
link_step_data_all
.
size
();
j
++
)
{
for
(
const
auto
&
d
:
data
.
link_step_data_all
[
j
])
{
data
.
rnn_link_data
.
push_back
(
d
);
}
data
.
rnn_week_datas
.
push_back
(
data
.
week_data_all
[
j
]);
data
.
rnn_minute_datas
.
push_back
(
data
.
minute_data_all
[
j
]);
// calculate lod
data
.
lod1
.
push_back
(
data
.
lod1
.
back
()
+
data
.
link_step_data_all
[
j
].
size
());
data
.
lod3
.
push_back
(
data
.
lod3
.
back
()
+
1
);
for
(
size_t
i
=
1
;
i
<
data
.
link_step_data_all
[
j
].
size
()
+
1
;
i
++
)
{
data
.
lod2
.
push_back
(
data
.
lod2
.
back
()
+
data
.
link_step_data_all
[
j
].
size
());
}
}
}
batch_iter
+=
batch_size
;
return
data
;
}
void
Load
(
const
std
::
string
&
path
)
{
std
::
ifstream
file
(
path
);
std
::
string
line
;
int
num_lines
=
0
;
while
(
std
::
getline
(
file
,
line
))
{
num_lines
++
;
std
::
vector
<
std
::
string
>
data
;
split
(
line
,
':'
,
&
data
);
std
::
vector
<
std
::
vector
<
float
>>
link_step_data
;
std
::
vector
<
std
::
string
>
link_datas
;
split
(
data
[
0
],
'|'
,
&
link_datas
);
for
(
auto
&
step_data
:
link_datas
)
{
std
::
vector
<
float
>
tmp
;
split_to_float
(
step_data
,
','
,
&
tmp
);
link_step_data
.
push_back
(
tmp
);
}
// load week data
std
::
vector
<
float
>
week_data
;
split_to_float
(
data
[
2
],
','
,
&
week_data
);
// load minute data
std
::
vector
<
float
>
minute_data
;
split_to_float
(
data
[
1
],
','
,
&
minute_data
);
link_step_data_all
.
push_back
(
std
::
move
(
link_step_data
));
week_data_all
.
push_back
(
std
::
move
(
week_data
));
minute_data_all
.
push_back
(
std
::
move
(
minute_data
));
}
}
};
void
PrepareInputs
(
std
::
vector
<
PaddleTensor
>
*
input_slots
,
DataRecord
*
data
,
int
batch_size
)
{
// DataRecord data(FLAGS_datapath, batch_size);
PaddleTensor
lod_attention_tensor
,
init_zero_tensor
,
lod_tensor_tensor
,
week_tensor
,
minute_tensor
;
lod_attention_tensor
.
name
=
"data_lod_attention"
;
init_zero_tensor
.
name
=
"cell_init"
;
lod_tensor_tensor
.
name
=
"data"
;
week_tensor
.
name
=
"week"
;
minute_tensor
.
name
=
"minute"
;
auto
one_batch
=
data
->
NextBatch
();
// clang-format off
std
::
vector
<
int
>
rnn_link_data_shape
({
static_cast
<
int
>
(
one_batch
.
rnn_link_data
.
size
()),
static_cast
<
int
>
(
one_batch
.
rnn_link_data
.
front
().
size
())});
lod_attention_tensor
.
shape
.
assign
({
1
,
2
});
lod_attention_tensor
.
lod
.
assign
({
one_batch
.
lod1
,
one_batch
.
lod2
});
init_zero_tensor
.
shape
.
assign
({
batch_size
,
15
});
init_zero_tensor
.
lod
.
assign
({
one_batch
.
lod3
});
lod_tensor_tensor
.
shape
=
rnn_link_data_shape
;
lod_tensor_tensor
.
lod
.
assign
({
one_batch
.
lod1
});
week_tensor
.
shape
.
assign
({(
int
)
one_batch
.
rnn_week_datas
.
size
(),
(
int
)
one_batch
.
rnn_week_datas
.
front
().
size
()});
week_tensor
.
lod
.
assign
({
one_batch
.
lod3
});
minute_tensor
.
shape
.
assign
({(
int
)
one_batch
.
rnn_minute_datas
.
size
(),
(
int
)
one_batch
.
rnn_minute_datas
.
front
().
size
()});
minute_tensor
.
lod
.
assign
({
one_batch
.
lod3
});
// assign data
TensorAssignData
(
&
lod_attention_tensor
,
std
::
vector
<
std
::
vector
<
float
>>
({{
0
,
0
}}));
std
::
vector
<
float
>
tmp_zeros
(
batch_size
*
15
,
0.
);
TensorAssignData
(
&
init_zero_tensor
,
{
tmp_zeros
});
TensorAssignData
(
&
lod_tensor_tensor
,
one_batch
.
rnn_link_data
);
TensorAssignData
(
&
week_tensor
,
one_batch
.
rnn_week_datas
);
TensorAssignData
(
&
minute_tensor
,
one_batch
.
rnn_minute_datas
);
// clang-format on
// Set inputs.
auto
init_zero_tensor1
=
init_zero_tensor
;
init_zero_tensor1
.
name
=
"hidden_init"
;
input_slots
->
assign
({
week_tensor
,
init_zero_tensor
,
minute_tensor
,
init_zero_tensor1
,
lod_attention_tensor
,
lod_tensor_tensor
});
for
(
auto
&
tensor
:
*
input_slots
)
{
tensor
.
dtype
=
PaddleDType
::
FLOAT32
;
}
}
std
::
string
DescribeTensor
(
const
PaddleTensor
&
tensor
)
{
std
::
stringstream
os
;
os
<<
"Tensor ["
<<
tensor
.
name
<<
"]
\n
"
;
os
<<
" - type: "
;
switch
(
tensor
.
dtype
)
{
case
PaddleDType
::
FLOAT32
:
os
<<
"float32"
;
break
;
case
PaddleDType
::
INT64
:
os
<<
"int64"
;
break
;
default:
os
<<
"unset"
;
}
os
<<
'\n'
;
os
<<
" - shape: "
<<
to_string
(
tensor
.
shape
)
<<
'\n'
;
os
<<
" - lod: "
;
for
(
auto
&
l
:
tensor
.
lod
)
{
os
<<
to_string
(
l
)
<<
"; "
;
}
os
<<
"
\n
"
;
os
<<
" - data: "
;
// clang-format off
int
dim
=
std
::
accumulate
(
tensor
.
shape
.
begin
(),
tensor
.
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
// clang-format on
for
(
size_t
i
=
0
;
i
<
dim
;
i
++
)
{
os
<<
static_cast
<
float
*>
(
tensor
.
data
.
data
())[
i
]
<<
" "
;
}
os
<<
'\n'
;
return
os
.
str
();
}
}
// namespace
const
float
ditu_rnn_target_data
[]
=
{
104.711
,
11.2431
,
1.35422
,
0
,
0
,
0
,
0
,
0
,
27.7039
,
1.41486
,
7.09526
,
0
,
0
,
0
,
0
,
0
,
7.6481
,
6.5324
,
56.383
,
2.88018
,
8.92918
,
132.007
,
4.27429
,
2.02934
,
14.1727
,
10.7461
,
25.0616
,
16.0197
,
14.4163
,
16.9199
,
6.75517
,
0
,
80.0249
,
4.77739
,
0
,
0
,
0
,
0
,
0
,
0
,
47.5643
,
2.67029
,
8.76252
,
0
,
0
,
0
,
0
,
0
,
51.8822
,
4.4411
,
0
,
0
,
0
,
0
,
0
,
0
,
10.7286
,
12.0595
,
10.6672
,
0
,
0
,
0
,
0
,
0
,
93.5771
,
3.84641
,
0
,
0
,
0
,
0
,
0
,
0
,
169.426
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
// Test with a really complicate model.
void
TestDituRNNPrediction
(
const
std
::
string
&
model_path
,
const
std
::
string
&
data_path
,
int
batch_size
,
bool
use_analysis
,
bool
activate_ir
,
int
num_times
=
1
)
{
FLAGS_IA_enable_ir
=
activate_ir
;
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_IA_output_storage_path
=
"./analysis.out"
;
std
::
string
model_out
;
if
(
use_analysis
)
{
Argument
argument
(
model_path
);
argument
.
model_output_store_path
.
reset
(
new
std
::
string
(
"./analysis.out"
));
Analyzer
analyzer
;
analyzer
.
Run
(
&
argument
);
// Should get the transformed model stored to ./analysis.out
model_out
=
"./analysis.out"
;
ASSERT_TRUE
(
PathExists
(
model_out
));
}
else
{
model_out
=
FLAGS_infer_ditu_rnn_model
;
}
NativeConfig
config
;
config
.
prog_file
=
model_out
+
"/__model__"
;
config
.
param_file
=
model_out
+
"/param"
;
config
.
use_gpu
=
false
;
config
.
device
=
0
;
config
.
specify_input_name
=
true
;
auto
predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
std
::
vector
<
PaddleTensor
>
input_slots
;
DataRecord
data
(
data_path
,
batch_size
);
// Prepare inputs.
PrepareInputs
(
&
input_slots
,
&
data
,
batch_size
);
std
::
vector
<
PaddleTensor
>
outputs
;
Timer
timer
;
timer
.
tic
();
for
(
int
i
=
0
;
i
<
num_times
;
i
++
)
{
predictor
->
Run
(
input_slots
,
&
outputs
);
}
LOG
(
INFO
)
<<
"time/batch: "
<<
timer
.
toc
()
/
num_times
;
for
(
auto
&
out
:
outputs
)
{
size_t
size
=
std
::
accumulate
(
out
.
shape
.
begin
(),
out
.
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
float
*
data
=
static_cast
<
float
*>
(
out
.
data
.
data
());
for
(
int
i
=
0
;
i
<
std
::
min
(
sizeof
(
ditu_rnn_target_data
)
/
sizeof
(
float
),
size
);
i
++
)
{
EXPECT_NEAR
(
data
[
i
],
ditu_rnn_target_data
[
i
],
1e-3
);
}
}
}
// Turn on the IR pass supportion, run a real inference and check the result.
TEST
(
Analyzer
,
SupportIRPass
)
{
FLAGS_IA_enable_ir
=
true
;
...
...
@@ -94,6 +333,27 @@ TEST(Analyzer, SupportIRPass) {
TestWord2vecPrediction
(
"./analysis.out"
);
}
// Directly infer with the original model.
TEST
(
Analyzer
,
DituRNN_without_analysis
)
{
TestDituRNNPrediction
(
FLAGS_infer_ditu_rnn_model
,
FLAGS_infer_ditu_rnn_data
,
10
,
false
,
false
);
}
// Inference with the original model with the analysis turned on, the analysis
// module will transform the program to a data flow graph.
TEST
(
Analyzer
,
DituRNN_with_analysis
)
{
LOG
(
INFO
)
<<
"ditu rnn with analysis"
;
TestDituRNNPrediction
(
FLAGS_infer_ditu_rnn_model
,
FLAGS_infer_ditu_rnn_data
,
10
,
true
,
false
,
1
);
}
// Inference with analysis and IR. The IR module will fuse some large kernels.
TEST
(
Analyzer
,
DituRNN_with_analysis_with_IR
)
{
LOG
(
INFO
)
<<
"ditu rnn with analysis and IR fuse"
;
TestDituRNNPrediction
(
FLAGS_infer_ditu_rnn_model
,
FLAGS_infer_ditu_rnn_data
,
10
,
true
,
true
,
1
);
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
...
...
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
c73c5ed5
...
...
@@ -18,7 +18,10 @@ if(APPLE)
endif
(
APPLE
)
set
(
inference_deps paddle_inference_api paddle_fluid_api
)
set
(
inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager
graph_viz_pass fc_fuse_pass
infer_clean_graph_pass
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
set
(
inference_deps
${
inference_deps
}
paddle_inference_tensorrt_subgraph_engine
)
...
...
@@ -62,7 +65,7 @@ endif()
if
(
WITH_ANAKIN AND WITH_GPU
)
# only needed in CI
# compile the libinference_anakin_api.a and anakin.so.
cc_library
(
inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber
)
cc_library
(
inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber
mklml
)
cc_library
(
inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber
)
function
(
anakin_target target_name
)
target_compile_options
(
${
target_name
}
BEFORE PUBLIC
${
ANAKIN_COMPILE_EXTRA_FLAGS
}
)
...
...
@@ -70,9 +73,12 @@ if (WITH_ANAKIN AND WITH_GPU) # only needed in CI
anakin_target
(
inference_anakin_api
)
anakin_target
(
inference_anakin_api_shared
)
if
(
WITH_TESTING
)
cc_test
(
inference_anakin_test SRCS api_anakin_engine_tester.cc
cc_test
(
api_anakin_engine_tester SRCS api_anakin_engine_tester.cc
ARGS --model=
${
ANAKIN_SOURCE_DIR
}
/mobilenet_v2.anakin.bin
DEPS inference_anakin_api dynload_cuda SERIAL
)
target_compile_options
(
inference_anakin_test BEFORE PUBLIC
${
ANAKIN_COMPILE_EXTRA_FLAGS
}
)
DEPS inference_anakin_api_shared dynload_cuda SERIAL
)
cc_test
(
api_anakin_engine_rnn_tester SRCS api_anakin_engine_rnn_tester.cc
ARGS --model=
${
ANAKIN_SOURCE_DIR
}
/anakin_test%2Fditu_rnn.anakin2.model.bin
--datapath=
${
ANAKIN_SOURCE_DIR
}
/anakin_test%2Fditu_rnn_data.txt
DEPS inference_anakin_api_shared dynload_cuda SERIAL
)
endif
(
WITH_TESTING
)
endif
()
paddle/fluid/inference/api/api.cc
浏览文件 @
c73c5ed5
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
paddle/fluid/inference/api/api_anakin_engine.cc
浏览文件 @
c73c5ed5
...
...
@@ -13,9 +13,22 @@
// limitations under the License.
#include "paddle/fluid/inference/api/api_anakin_engine.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif
#include <mkl_service.h>
#include <omp.h>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "framework/core/net/net.h"
#include "framework/operators/ops.h"
#include "saber/funcs/timer.h"
namespace
paddle
{
template
<
typename
Target
>
...
...
@@ -23,16 +36,24 @@ PaddleInferenceAnakinPredictor<Target>::PaddleInferenceAnakinPredictor(
const
AnakinConfig
&
config
)
{
CHECK
(
Init
(
config
));
}
template
<
>
PaddleInferenceAnakinPredictor
<
anakin
::
X86
>::
PaddleInferenceAnakinPredictor
(
const
AnakinConfig
&
config
)
{
omp_set_dynamic
(
0
);
omp_set_num_threads
(
1
);
mkl_set_num_threads
(
1
);
CHECK
(
Init
(
config
));
}
template
<
typename
Target
>
bool
PaddleInferenceAnakinPredictor
<
Target
>::
Init
(
const
AnakinConfig
&
config
)
{
if
(
!
(
graph_
.
load
(
config
.
model_file
)))
{
LOG
(
FATAL
)
<<
"fail to load graph from "
<<
config
.
model_file
;
VLOG
(
3
)
<<
"fail to load graph from "
<<
config
.
model_file
;
return
false
;
}
auto
inputs
=
graph_
.
get_ins
();
for
(
auto
&
input_str
:
inputs
)
{
graph_
.
ResetBatchSize
(
input_str
,
config
.
max_batch_size
);
max_batch_size_
=
config
.
max_batch_size
;
}
// optimization for graph
if
(
!
(
graph_
.
Optimize
()))
{
...
...
@@ -52,15 +73,15 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
std
::
vector
<
PaddleTensor
>
*
output_data
,
int
batch_size
)
{
for
(
const
auto
&
input
:
inputs
)
{
if
(
input
.
dtype
!=
PaddleDType
::
FLOAT32
)
{
LOG
(
ERROR
)
<<
"Only support float type inputs. "
<<
input
.
name
<<
"'s type is not float"
;
VLOG
(
3
)
<<
"Only support float type inputs. "
<<
input
.
name
<<
"'s type is not float"
;
return
false
;
}
auto
d_tensor_in_p
=
executor_p_
->
get_in
(
input
.
name
);
auto
net_shape
=
d_tensor_in_p
->
valid_
shape
();
auto
net_shape
=
d_tensor_in_p
->
shape
();
if
(
net_shape
.
size
()
!=
input
.
shape
.
size
())
{
LOG
(
ERROR
)
<<
" input "
<<
input
.
name
<<
"'s shape size should be equal to that of net"
;
VLOG
(
3
)
<<
" input "
<<
input
.
name
<<
"'s shape size should be equal to that of net"
;
return
false
;
}
int
sum
=
1
;
...
...
@@ -79,21 +100,45 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
}
d_tensor_in_p
->
reshape
(
tmp_shape
);
if
(
input
.
lod
.
size
()
>
0
)
{
if
(
input
.
lod
.
size
()
>
1
)
{
VLOG
(
3
)
<<
" input lod first dim should <=1, but you set "
<<
input
.
lod
.
size
();
return
false
;
}
std
::
vector
<
int
>
offset
(
input
.
lod
[
0
].
begin
(),
input
.
lod
[
0
].
end
());
d_tensor_in_p
->
set_seq_offset
(
offset
);
VLOG
(
3
)
<<
"offset.size(): "
<<
offset
.
size
();
for
(
int
i
=
0
;
i
<
offset
.
size
();
i
++
)
{
VLOG
(
3
)
<<
offset
[
i
];
}
}
float
*
d_data_p
=
d_tensor_in_p
->
mutable_data
();
if
(
cudaMemcpy
(
d_data_p
,
static_cast
<
float
*>
(
input
.
data
.
data
()),
d_tensor_in_p
->
valid_size
()
*
sizeof
(
float
),
cudaMemcpyHostToDevice
)
!=
0
)
{
LOG
(
ERROR
)
<<
"copy data from CPU to GPU error"
;
return
false
;
#ifdef PADDLE_WITH_CUDA
if
(
std
::
is_same
<
anakin
::
NV
,
Target
>::
value
)
{
if
(
cudaMemcpy
(
d_data_p
,
static_cast
<
float
*>
(
input
.
data
.
data
()),
d_tensor_in_p
->
valid_size
()
*
sizeof
(
float
),
cudaMemcpyHostToDevice
)
!=
0
)
{
VLOG
(
3
)
<<
"copy data from CPU to GPU error"
;
return
false
;
}
}
#endif
if
(
std
::
is_same
<
anakin
::
X86
,
Target
>::
value
)
{
memcpy
(
d_data_p
,
static_cast
<
float
*>
(
input
.
data
.
data
()),
d_tensor_in_p
->
valid_size
()
*
sizeof
(
float
));
}
cudaStreamSynchronize
(
NULL
);
}
#ifdef PADDLE_WITH_CUDA
cudaDeviceSynchronize
();
executor_p_
->
prediction
();
cudaDeviceSynchronize
();
#endif
if
(
output_data
->
empty
())
{
LOG
(
ERROR
)
<<
"At least one output should be set with tensors' names."
;
VLOG
(
3
)
<<
"At least one output should be set with tensors' names."
;
return
false
;
}
for
(
auto
&
output
:
*
output_data
)
{
...
...
@@ -102,14 +147,22 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
if
(
output
.
data
.
length
()
<
tensor
->
valid_size
()
*
sizeof
(
float
))
{
output
.
data
.
Resize
(
tensor
->
valid_size
()
*
sizeof
(
float
));
}
// Copy data from GPU -> CPU
if
(
cudaMemcpy
(
output
.
data
.
data
(),
tensor
->
mutable_data
(),
tensor
->
valid_size
()
*
sizeof
(
float
),
cudaMemcpyDeviceToHost
)
!=
0
)
{
LOG
(
ERROR
)
<<
"copy data from GPU to CPU error"
;
return
false
;
#if PADDLE_WITH_CUDA
if
(
std
::
is_same
<
anakin
::
NV
,
Target
>::
value
)
{
// Copy data from GPU -> CPU
if
(
cudaMemcpy
(
output
.
data
.
data
(),
tensor
->
mutable_data
(),
tensor
->
valid_size
()
*
sizeof
(
float
),
cudaMemcpyDeviceToHost
)
!=
0
)
{
VLOG
(
3
)
<<
"copy data from GPU to CPU error"
;
return
false
;
}
}
#endif
if
(
std
::
is_same
<
anakin
::
X86
,
Target
>::
value
)
{
memcpy
(
output
.
data
.
data
(),
tensor
->
mutable_data
(),
tensor
->
valid_size
()
*
sizeof
(
float
));
}
cudaStreamSynchronize
(
NULL
);
}
return
true
;
}
...
...
@@ -132,7 +185,7 @@ PaddleInferenceAnakinPredictor<Target>::Clone() {
auto
anakin_predictor_p
=
dynamic_cast
<
PaddleInferenceAnakinPredictor
<
Target
>
*>
(
cls
.
get
());
if
(
!
anakin_predictor_p
)
{
LOG
(
ERROR
)
<<
"fail to call Init"
;
VLOG
(
3
)
<<
"fail to call Init"
;
return
nullptr
;
}
anakin_predictor_p
->
get_executer
().
init
(
graph_
);
...
...
@@ -162,6 +215,44 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
VLOG
(
3
)
<<
"Anakin Predictor create on unknown platform."
;
return
nullptr
;
}
};
}
#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
template
<
typename
Target
>
using
executor_t
=
anakin
::
Net
<
Target
,
anakin
::
saber
::
AK_FLOAT
,
anakin
::
Precision
::
FP32
>
;
template
<
typename
Target
>
void
DisplayOpTimer
(
executor_t
<
Target
>
*
net_executor
,
int
epoch
)
{
std
::
vector
<
float
>
op_time
=
net_executor
->
get_op_time
();
auto
exec_funcs
=
net_executor
->
get_exec_funcs
();
auto
op_param
=
net_executor
->
get_op_param
();
for
(
int
i
=
0
;
i
<
op_time
.
size
();
i
++
)
{
LOG
(
INFO
)
<<
"name: "
<<
exec_funcs
[
i
].
name
<<
" op_type: "
<<
exec_funcs
[
i
].
op_name
<<
" op_param: "
<<
op_param
[
i
]
<<
" time "
<<
op_time
[
i
]
/
epoch
;
}
std
::
map
<
std
::
string
,
float
>
op_map
;
for
(
int
i
=
0
;
i
<
op_time
.
size
();
i
++
)
{
auto
it
=
op_map
.
find
(
op_param
[
i
]);
if
(
it
!=
op_map
.
end
())
op_map
[
op_param
[
i
]]
+=
op_time
[
i
];
else
op_map
.
insert
(
std
::
pair
<
std
::
string
,
float
>
(
op_param
[
i
],
op_time
[
i
]));
}
for
(
auto
it
=
op_map
.
begin
();
it
!=
op_map
.
end
();
++
it
)
{
LOG
(
INFO
)
<<
it
->
first
<<
" "
<<
(
it
->
second
)
/
epoch
<<
" ms"
;
}
}
#endif
template
<
typename
Target
>
PaddleInferenceAnakinPredictor
<
Target
>::~
PaddleInferenceAnakinPredictor
()
{
#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
DisplayOpTimer
<
Target
>
(
executor_p_
,
max_batch_size_
);
#endif
delete
executor_p_
;
executor_p_
=
nullptr
;
}
}
// namespace paddle
paddle/fluid/inference/api/api_anakin_engine.h
浏览文件 @
c73c5ed5
...
...
@@ -47,10 +47,7 @@ class PaddleInferenceAnakinPredictor : public PaddlePredictor {
anakin
::
Net
<
Target
,
anakin
::
saber
::
AK_FLOAT
,
anakin
::
Precision
::
FP32
>&
get_executer
();
~
PaddleInferenceAnakinPredictor
()
override
{
delete
executor_p_
;
executor_p_
=
nullptr
;
};
~
PaddleInferenceAnakinPredictor
()
override
;
private:
bool
Init
(
const
AnakinConfig
&
config
);
...
...
@@ -60,6 +57,7 @@ class PaddleInferenceAnakinPredictor : public PaddlePredictor {
anakin
::
Net
<
Target
,
anakin
::
saber
::
AK_FLOAT
,
anakin
::
Precision
::
FP32
>*
executor_p_
{
nullptr
};
AnakinConfig
config_
;
int
max_batch_size_
{
0
};
};
}
// namespace paddle
paddle/fluid/inference/api/api_anakin_engine_rnn_tester.cc
0 → 100644
浏览文件 @
c73c5ed5
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <sys/time.h>
#include <time.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <thread> // NOLINT
#include <vector>
#include "framework/core/net/net.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
DEFINE_string
(
model
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
datapath
,
""
,
"Path of the dataset."
);
DEFINE_int32
(
batch_size
,
1
,
"batch size."
);
DEFINE_int32
(
repeat
,
1
,
"Running the inference program repeat times."
);
// Timer for timer
class
Timer
{
public:
double
start
;
double
startu
;
void
tic
()
{
struct
timeval
tp
;
gettimeofday
(
&
tp
,
NULL
);
start
=
tp
.
tv_sec
;
startu
=
tp
.
tv_usec
;
}
double
toc
()
{
struct
timeval
tp
;
gettimeofday
(
&
tp
,
NULL
);
double
used_time_ms
=
(
tp
.
tv_sec
-
start
)
*
1000.0
+
(
tp
.
tv_usec
-
startu
)
/
1000.0
;
return
used_time_ms
;
}
};
std
::
vector
<
std
::
string
>
string_split
(
std
::
string
in_str
,
std
::
string
delimiter
)
{
std
::
vector
<
std
::
string
>
seq
;
int
found
=
in_str
.
find
(
delimiter
);
int
pre_found
=
-
1
;
while
(
found
!=
std
::
string
::
npos
)
{
if
(
pre_found
==
-
1
)
{
seq
.
push_back
(
in_str
.
substr
(
0
,
found
));
}
else
{
seq
.
push_back
(
in_str
.
substr
(
pre_found
+
delimiter
.
length
(),
found
-
delimiter
.
length
()
-
pre_found
));
}
pre_found
=
found
;
found
=
in_str
.
find
(
delimiter
,
pre_found
+
delimiter
.
length
());
}
seq
.
push_back
(
in_str
.
substr
(
pre_found
+
1
,
in_str
.
length
()
-
(
pre_found
+
1
)));
return
seq
;
}
std
::
vector
<
std
::
string
>
string_split
(
std
::
string
in_str
,
std
::
vector
<
std
::
string
>&
delimiter
)
{
// NOLINT
std
::
vector
<
std
::
string
>
in
;
std
::
vector
<
std
::
string
>
out
;
out
.
push_back
(
in_str
);
for
(
auto
del
:
delimiter
)
{
in
=
out
;
out
.
clear
();
for
(
auto
s
:
in
)
{
auto
out_s
=
string_split
(
s
,
del
);
for
(
auto
o
:
out_s
)
{
out
.
push_back
(
o
);
}
}
}
return
out
;
}
class
Data
{
public:
Data
(
std
::
string
file_name
,
int
batch_size
)
:
_batch_size
(
batch_size
),
_total_length
(
0
)
{
_file
.
open
(
file_name
);
_file
.
seekg
(
_file
.
end
);
_total_length
=
_file
.
tellg
();
_file
.
seekg
(
_file
.
beg
);
}
void
get_batch_data
(
std
::
vector
<
std
::
vector
<
float
>>&
fea
,
// NOLINT
std
::
vector
<
std
::
vector
<
float
>>&
week_fea
,
// NOLINT
std
::
vector
<
std
::
vector
<
float
>>&
time_fea
,
// NOLINT
std
::
vector
<
long
unsigned
int
>&
seq_offset
);
// NOLINT
private:
std
::
fstream
_file
;
int
_total_length
;
int
_batch_size
;
};
void
Data
::
get_batch_data
(
std
::
vector
<
std
::
vector
<
float
>>&
fea
,
// NOLINT
std
::
vector
<
std
::
vector
<
float
>>&
week_fea
,
// NOLINT
std
::
vector
<
std
::
vector
<
float
>>&
time_fea
,
// NOLINT
std
::
vector
<
long
unsigned
int
>&
seq_offset
)
{
// NOLINT
int
seq_num
=
0
;
long
unsigned
int
cum
=
0
;
// NOLINT
char
buf
[
10000
];
seq_offset
.
clear
();
seq_offset
.
push_back
(
0
);
fea
.
clear
();
week_fea
.
clear
();
time_fea
.
clear
();
while
(
_file
.
getline
(
buf
,
10000
))
{
std
::
string
s
=
buf
;
std
::
vector
<
std
::
string
>
deli_vec
=
{
":"
};
std
::
vector
<
std
::
string
>
data_vec
=
string_split
(
s
,
deli_vec
);
std
::
vector
<
std
::
string
>
seq
;
seq
=
string_split
(
data_vec
[
0
],
{
"|"
});
for
(
auto
link
:
seq
)
{
std
::
vector
<
std
::
string
>
data
=
string_split
(
link
,
","
);
std
::
vector
<
float
>
vec
;
for
(
int
i
=
0
;
i
<
data
.
size
();
i
++
)
{
vec
.
push_back
(
atof
(
data
[
i
].
c_str
()));
}
fea
.
push_back
(
vec
);
}
std
::
vector
<
std
::
string
>
week_data
;
std
::
vector
<
std
::
string
>
time_data
;
week_data
=
string_split
(
data_vec
[
2
],
","
);
std
::
vector
<
float
>
vec_w
;
for
(
int
i
=
0
;
i
<
week_data
.
size
();
i
++
)
{
vec_w
.
push_back
(
atof
(
week_data
[
i
].
c_str
()));
}
week_fea
.
push_back
(
vec_w
);
time_data
=
string_split
(
data_vec
[
1
],
","
);
std
::
vector
<
float
>
vec_t
;
for
(
int
i
=
0
;
i
<
time_data
.
size
();
i
++
)
{
vec_t
.
push_back
(
atof
(
time_data
[
i
].
c_str
()));
}
time_fea
.
push_back
(
vec_t
);
cum
+=
seq
.
size
();
seq_offset
.
push_back
(
cum
);
seq_num
++
;
if
(
seq_num
>=
_batch_size
)
{
break
;
}
}
}
namespace
paddle
{
AnakinConfig
GetConfig
()
{
AnakinConfig
config
;
// using AnakinConfig::X86 if you need to use cpu to do inference
config
.
target_type
=
AnakinConfig
::
X86
;
config
.
model_file
=
FLAGS_model
;
config
.
device
=
0
;
config
.
max_batch_size
=
1000
;
// the max number of token
return
config
;
}
void
set_tensor
(
std
::
string
name
,
std
::
vector
<
int
>
shape
,
std
::
vector
<
PaddleTensor
>&
vec
)
{
// NOLINT
int
sum
=
1
;
std
::
for_each
(
shape
.
begin
(),
shape
.
end
(),
[
&
](
int
n
)
{
sum
*=
n
;
});
float
*
data
=
new
float
[
sum
];
PaddleTensor
tensor
;
tensor
.
name
=
name
;
tensor
.
shape
=
shape
;
tensor
.
data
=
PaddleBuf
(
data
,
sum
);
tensor
.
dtype
=
PaddleDType
::
FLOAT32
;
vec
.
push_back
(
tensor
);
}
void
single_test
()
{
AnakinConfig
config
=
GetConfig
();
auto
predictor
=
CreatePaddlePredictor
<
AnakinConfig
,
PaddleEngineKind
::
kAnakin
>
(
config
);
int
max_batch_size
=
1000
;
std
::
string
feature_file
=
FLAGS_datapath
;
Data
map_data
(
feature_file
,
FLAGS_batch_size
);
std
::
vector
<
std
::
vector
<
float
>>
fea
;
std
::
vector
<
std
::
vector
<
float
>>
week_fea
;
std
::
vector
<
std
::
vector
<
float
>>
time_fea
;
std
::
vector
<
long
unsigned
int
>
seq_offset
;
// NOLINT
paddle
::
PaddleTensor
tensor_0
,
tensor_1
,
tensor_2
;
tensor_0
.
name
=
"input_0"
;
tensor_1
.
name
=
"input_4"
;
tensor_2
.
name
=
"input_5"
;
PaddleTensor
tensor_out
;
tensor_out
.
name
=
"final_output.tmp_1_gout"
;
tensor_out
.
shape
=
std
::
vector
<
int
>
({});
tensor_out
.
data
=
PaddleBuf
();
tensor_out
.
dtype
=
PaddleDType
::
FLOAT32
;
std
::
vector
<
PaddleTensor
>
inputs
;
std
::
vector
<
PaddleTensor
>
outputs
(
1
,
tensor_out
);
int
data_0_dim
=
38
;
int
data_1_dim
=
10
;
int
data_2_dim
=
10
;
float
data_0
[
max_batch_size
*
data_0_dim
];
// NOLINT
float
data_1
[
max_batch_size
*
data_1_dim
];
// NOLINT
float
data_2
[
max_batch_size
*
data_2_dim
];
// NOLINT
int
count
=
0
;
while
(
true
)
{
if
(
count
++
>
0
)
break
;
// only run the first batch in ci.
seq_offset
.
clear
();
map_data
.
get_batch_data
(
fea
,
week_fea
,
time_fea
,
seq_offset
);
if
(
seq_offset
.
size
()
<=
1
)
{
LOG
(
FATAL
)
<<
"seq_offset.size() <= 1, exit."
;
break
;
}
std
::
vector
<
std
::
vector
<
long
unsigned
int
>>
seq_offset_vec
;
// NOLINT
seq_offset_vec
.
push_back
(
seq_offset
);
tensor_0
.
lod
=
seq_offset_vec
;
int
p_shape_0
[]
=
{(
int
)
fea
.
size
(),
1
,
1
,
data_0_dim
};
// NOLINT
int
p_shape_1
[]
=
{(
int
)
week_fea
.
size
(),
data_1_dim
,
1
,
1
};
// NOLINT
int
p_shape_2
[]
=
{(
int
)
time_fea
.
size
(),
data_2_dim
,
1
,
1
};
// NOLINT
std
::
vector
<
int
>
shape_0
(
p_shape_0
,
p_shape_0
+
4
);
std
::
vector
<
int
>
shape_1
(
p_shape_1
,
p_shape_1
+
4
);
std
::
vector
<
int
>
shape_2
(
p_shape_2
,
p_shape_2
+
4
);
tensor_0
.
shape
=
shape_0
;
tensor_1
.
shape
=
shape_1
;
tensor_2
.
shape
=
shape_2
;
for
(
int
i
=
0
;
i
<
fea
.
size
();
i
++
)
{
memcpy
(
data_0
+
i
*
data_0_dim
,
&
fea
[
i
][
0
],
sizeof
(
float
)
*
data_0_dim
);
}
for
(
int
i
=
0
;
i
<
week_fea
.
size
();
i
++
)
{
memcpy
(
data_1
+
i
*
data_1_dim
,
&
week_fea
[
i
][
0
],
sizeof
(
float
)
*
data_1_dim
);
}
for
(
int
i
=
0
;
i
<
time_fea
.
size
();
i
++
)
{
memcpy
(
data_2
+
i
*
data_2_dim
,
&
time_fea
[
i
][
0
],
sizeof
(
float
)
*
data_2_dim
);
}
tensor_0
.
data
=
paddle
::
PaddleBuf
(
data_0
,
fea
.
size
()
*
sizeof
(
float
)
*
data_0_dim
);
tensor_1
.
data
=
paddle
::
PaddleBuf
(
data_1
,
week_fea
.
size
()
*
sizeof
(
float
)
*
data_1_dim
);
tensor_2
.
data
=
paddle
::
PaddleBuf
(
data_2
,
time_fea
.
size
()
*
sizeof
(
float
)
*
data_2_dim
);
tensor_0
.
dtype
=
paddle
::
PaddleDType
::
FLOAT32
;
tensor_1
.
dtype
=
paddle
::
PaddleDType
::
FLOAT32
;
tensor_2
.
dtype
=
paddle
::
PaddleDType
::
FLOAT32
;
inputs
.
clear
();
inputs
.
push_back
(
tensor_1
);
inputs
.
push_back
(
tensor_2
);
inputs
.
push_back
(
tensor_0
);
Timer
timer
;
timer
.
tic
();
for
(
int
i
=
0
;
i
<
FLAGS_repeat
;
i
++
)
predictor
->
Run
(
inputs
,
&
outputs
);
LOG
(
INFO
)
<<
"batch_size = "
<<
FLAGS_batch_size
<<
", repeat = "
<<
FLAGS_repeat
<<
", sequence_length = "
<<
seq_offset
[
seq_offset
.
size
()
-
1
]
<<
", latency: "
<<
timer
.
toc
()
/
FLAGS_repeat
<<
"ms"
;
float
*
data_o
=
static_cast
<
float
*>
(
outputs
[
0
].
data
.
data
());
VLOG
(
3
)
<<
"outputs[0].data.length() = "
<<
outputs
[
0
].
data
.
length
();
for
(
size_t
j
=
0
;
j
<
outputs
[
0
].
data
.
length
();
++
j
)
{
VLOG
(
3
)
<<
"output["
<<
j
<<
"]: "
<<
data_o
[
j
];
}
}
}
}
// namespace paddle
int
main
(
int
argc
,
char
**
argv
)
{
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
logger
::
init
(
argv
[
0
]);
paddle
::
single_test
();
/* multi-threads
std::vector<std::thread> threads;
int num = 1;
for (int i = 0; i < num; i++) {
LOG(INFO) << " thread id : " << i;
threads.emplace_back(paddle::single_test);
}
for (int i = 0; i < num; i++) {
threads[i].join();
}
threads.clear();
*/
return
0
;
}
paddle/fluid/inference/api/api_impl.cc
浏览文件 @
c73c5ed5
...
...
@@ -137,8 +137,11 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
return
false
;
}
for
(
size_t
i
=
0
;
i
<
feed_target_names_
.
size
();
++
i
)
{
VLOG
(
4
)
<<
"setting "
<<
i
<<
"-th target"
;
feed_targets
[
feed_target_names_
[
i
]]
=
&
feeds
[
i
];
if
(
config_
.
specify_input_name
)
{
feed_targets
[
inputs
[
i
].
name
]
=
&
feeds
[
i
];
}
else
{
feed_targets
[
feed_target_names_
[
i
]]
=
&
feeds
[
i
];
}
}
// get fetch variable
std
::
map
<
std
::
string
,
framework
::
LoDTensor
*>
fetch_targets
;
...
...
paddle/fluid/inference/api/helper.h
0 → 100644
浏览文件 @
c73c5ed5
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <sys/time.h>
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace
paddle
{
namespace
inference
{
// Timer for timer
class
Timer
{
public:
double
start
;
double
startu
;
void
tic
()
{
struct
timeval
tp
;
gettimeofday
(
&
tp
,
NULL
);
start
=
tp
.
tv_sec
;
startu
=
tp
.
tv_usec
;
}
double
toc
()
{
struct
timeval
tp
;
gettimeofday
(
&
tp
,
NULL
);
double
used_time_ms
=
(
tp
.
tv_sec
-
start
)
*
1000.0
+
(
tp
.
tv_usec
-
startu
)
/
1000.0
;
return
used_time_ms
;
}
};
void
split
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
std
::
string
>
*
pieces
)
{
pieces
->
clear
();
if
(
str
.
empty
())
{
return
;
}
size_t
pos
=
0
;
size_t
next
=
str
.
find
(
sep
,
pos
);
while
(
next
!=
std
::
string
::
npos
)
{
pieces
->
push_back
(
str
.
substr
(
pos
,
next
-
pos
));
pos
=
next
+
1
;
next
=
str
.
find
(
sep
,
pos
);
}
if
(
!
str
.
substr
(
pos
).
empty
())
{
pieces
->
push_back
(
str
.
substr
(
pos
));
}
}
void
split_to_float
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
float
>
*
fs
)
{
std
::
vector
<
std
::
string
>
pieces
;
split
(
str
,
sep
,
&
pieces
);
std
::
transform
(
pieces
.
begin
(),
pieces
.
end
(),
std
::
back_inserter
(
*
fs
),
[](
const
std
::
string
&
v
)
{
return
std
::
stof
(
v
);
});
}
template
<
typename
T
>
std
::
string
to_string
(
const
std
::
vector
<
T
>
&
vec
)
{
std
::
stringstream
ss
;
for
(
const
auto
&
c
:
vec
)
{
ss
<<
c
<<
" "
;
}
return
ss
.
str
();
}
template
<
>
std
::
string
to_string
<
std
::
vector
<
float
>>
(
const
std
::
vector
<
std
::
vector
<
float
>>
&
vec
)
{
std
::
stringstream
ss
;
for
(
const
auto
&
piece
:
vec
)
{
ss
<<
to_string
(
piece
)
<<
"
\n
"
;
}
return
ss
.
str
();
}
template
<
>
std
::
string
to_string
<
std
::
vector
<
std
::
vector
<
float
>>>
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>>>
&
vec
)
{
std
::
stringstream
ss
;
for
(
const
auto
&
line
:
vec
)
{
for
(
const
auto
&
rcd
:
line
)
{
ss
<<
to_string
(
rcd
)
<<
";
\t
"
;
}
ss
<<
'\n'
;
}
return
ss
.
str
();
}
// clang-format off
void
TensorAssignData
(
PaddleTensor
*
tensor
,
const
std
::
vector
<
std
::
vector
<
float
>>
&
data
)
{
// Assign buffer
int
dim
=
std
::
accumulate
(
tensor
->
shape
.
begin
(),
tensor
->
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
tensor
->
data
.
Resize
(
sizeof
(
float
)
*
dim
);
int
c
=
0
;
for
(
const
auto
&
f
:
data
)
{
for
(
float
v
:
f
)
{
static_cast
<
float
*>
(
tensor
->
data
.
data
())[
c
++
]
=
v
;
}
}
}
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
c73c5ed5
...
...
@@ -45,7 +45,7 @@ class PaddleBuf {
PaddleBuf
(
void
*
data
,
size_t
length
)
:
data_
(
data
),
length_
(
length
),
memory_owned_
{
false
}
{}
// Own memory.
explicit
PaddleBuf
(
size_t
length
)
PaddleBuf
(
size_t
length
)
:
data_
(
new
char
[
length
]),
length_
(
length
),
memory_owned_
(
true
)
{}
// Resize to `length` bytes.
void
Resize
(
size_t
length
);
...
...
@@ -70,7 +70,7 @@ struct PaddleTensor {
std
::
vector
<
int
>
shape
;
PaddleBuf
data
;
// blob of data.
PaddleDType
dtype
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
lod
;
// lod data
std
::
vector
<
std
::
vector
<
size_t
>>
lod
;
// Tensor+LoD equals LoDTensor
};
enum
class
PaddleEngineKind
{
...
...
@@ -120,6 +120,8 @@ struct NativeConfig : public PaddlePredictor::Config {
bool
use_gpu
{
false
};
int
device
{
0
};
float
fraction_of_gpu_memory
{
-
1.
f
};
// Negative to notify initialization.
// Specify the variable's name of each input.
bool
specify_input_name
{
false
};
std
::
string
prog_file
;
std
::
string
param_file
;
...
...
paddle/fluid/operators/mul_op.cc
浏览文件 @
c73c5ed5
...
...
@@ -54,9 +54,9 @@ class MulOp : public framework::OperatorWithKernel {
auto
x_mat_dims
=
framework
::
flatten_to_2d
(
x_dims
,
x_num_col_dims
);
auto
y_mat_dims
=
framework
::
flatten_to_2d
(
y_dims
,
y_num_col_dims
);
PADDLE_ENFORCE_EQ
(
x_mat_dims
[
1
],
y_mat_dims
[
0
],
"First matrix's width must be equal with second matrix's height.
"
);
PADDLE_ENFORCE_EQ
(
x_mat_dims
[
1
],
y_mat_dims
[
0
],
"First matrix's width must be equal with second matrix's "
"height. %s, %s
"
);
std
::
vector
<
int64_t
>
output_dims
;
output_dims
.
reserve
(
static_cast
<
size_t
>
(
x_num_col_dims
+
y_dims
.
size
()
-
y_num_col_dims
));
...
...
paddle/fluid/operators/stack_op.cc
浏览文件 @
c73c5ed5
...
...
@@ -14,53 +14,15 @@
#include "paddle/fluid/operators/stack_op.h"
namespace
paddle
{
namespace
operators
{
struct
CPUStackFunctor
{
template
<
typename
DeviceContext
,
typename
T
>
void
operator
()(
const
DeviceContext
&
ctx
,
const
std
::
vector
<
const
T
*>&
x
,
T
*
y
,
int
pre
,
int
n
,
int
post
)
const
{
int
total_num
=
pre
*
post
*
n
;
for
(
int
idx
=
0
;
idx
<
total_num
;
++
idx
)
{
int
i
=
idx
/
(
n
*
post
);
int
which_x
=
idx
/
post
-
i
*
n
;
int
x_index
=
i
*
post
+
idx
%
post
;
y
[
idx
]
=
x
[
which_x
][
x_index
];
}
}
};
struct
CPUStackGradFunctor
{
template
<
typename
DeviceContext
,
typename
T
>
void
operator
()(
const
DeviceContext
&
ctx
,
std
::
vector
<
T
*>&
dx
,
// NOLINT
const
T
*
dy
,
int
pre
,
int
n
,
int
post
)
const
{
int
total_num
=
pre
*
post
*
n
;
for
(
int
idx
=
0
;
idx
<
total_num
;
++
idx
)
{
int
i
=
idx
/
(
n
*
post
);
int
which_x
=
idx
/
post
-
i
*
n
;
int
x_index
=
i
*
post
+
idx
%
post
;
dx
[
which_x
][
x_index
]
=
dy
[
idx
];
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
plat
=
paddle
::
platform
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
stack
,
ops
::
StackOp
,
ops
::
StackOpMaker
,
ops
::
StackGradOpDescMaker
);
REGISTER_OPERATOR
(
stack_grad
,
ops
::
StackOpGrad
);
REGISTER_OP_CPU_KERNEL
(
stack
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
float
,
ops
::
CPUStackFunctor
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
double
,
ops
::
CPUStackFunctor
>
);
REGISTER_OP_CPU_KERNEL
(
stack
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
stack_grad
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
float
,
ops
::
CPUStackGradFunctor
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
double
,
ops
::
CPUStackGradFunctor
>
);
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/stack_op.cu
浏览文件 @
c73c5ed5
...
...
@@ -12,98 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/operators/stack_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
typename
VecXType
>
__global__
void
StackCUDAKernel
(
VecXType
x
,
T
*
y
,
int
total_num
,
int
n
,
int
post
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
total_num
)
{
int
i
=
idx
/
(
n
*
post
);
int
which_x
=
idx
/
post
-
i
*
n
;
int
x_index
=
i
*
post
+
idx
%
post
;
y
[
idx
]
=
x
[
which_x
][
x_index
];
}
}
template
<
typename
T
,
typename
VecDxType
>
__global__
void
StackGradCUDAKernel
(
VecDxType
dx
,
const
T
*
dy
,
int
total_num
,
int
n
,
int
post
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
total_num
)
{
int
i
=
idx
/
(
n
*
post
);
int
which_x
=
idx
/
post
-
i
*
n
;
int
x_index
=
i
*
post
+
idx
%
post
;
dx
[
which_x
][
x_index
]
=
dy
[
idx
];
}
}
struct
GPUStackFunctor
{
template
<
typename
DeviceContext
,
typename
T
>
void
operator
()(
const
DeviceContext
&
ctx
,
const
std
::
vector
<
const
T
*>&
x
,
T
*
y
,
int
pre
,
int
n
,
int
post
)
const
{
int
total_num
=
pre
*
post
*
n
;
int
threads
=
512
;
int
grid
=
(
total_num
+
threads
-
1
)
/
threads
;
constexpr
auto
kMaxThreshold
=
16
;
if
(
n
<=
kMaxThreshold
)
{
framework
::
Array
<
const
T
*
,
kMaxThreshold
>
arr
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
arr
[
i
]
=
x
[
i
];
StackCUDAKernel
<<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
arr
,
y
,
total_num
,
n
,
post
);
}
else
{
VLOG
(
10
)
<<
"Stack more than "
<<
kMaxThreshold
<<
" tensors may be slow on GPU."
;
thrust
::
device_vector
<
const
T
*>
dev_x
(
x
);
StackCUDAKernel
<<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
dev_x
.
data
().
get
(),
y
,
total_num
,
n
,
post
);
}
}
};
struct
GPUStackGradFunctor
{
template
<
typename
DeviceContext
,
typename
T
>
void
operator
()(
const
DeviceContext
&
ctx
,
std
::
vector
<
T
*>&
dx
,
// NOLINT
const
T
*
dy
,
int
pre
,
int
n
,
int
post
)
const
{
int
total_num
=
pre
*
post
*
n
;
int
threads
=
512
;
int
grid
=
(
total_num
+
threads
-
1
)
/
threads
;
constexpr
auto
kMaxThreshold
=
16
;
if
(
n
<=
kMaxThreshold
)
{
framework
::
Array
<
T
*
,
kMaxThreshold
>
arr
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
arr
[
i
]
=
dx
[
i
];
StackGradCUDAKernel
<<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
arr
,
dy
,
total_num
,
n
,
post
);
}
else
{
VLOG
(
10
)
<<
"Stack more than "
<<
kMaxThreshold
<<
" tensors may be slow on GPU."
;
thrust
::
device_vector
<
T
*>
dev_dx
(
dx
);
StackGradCUDAKernel
<<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
dev_dx
.
data
().
get
(),
dy
,
total_num
,
n
,
post
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
plat
=
paddle
::
platform
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
stack
,
ops
::
StackKernel
<
plat
::
CUDADeviceContext
,
float
,
ops
::
GPUStackFunctor
>
,
ops
::
StackKernel
<
plat
::
CUDADeviceContext
,
double
,
ops
::
GPUStackFunctor
>
);
REGISTER_OP_CUDA_KERNEL
(
stack
,
ops
::
StackKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
StackKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
stack_grad
,
ops
::
StackGradKernel
<
plat
::
CUDADeviceContext
,
float
,
ops
::
GPUStackGradFunctor
>
,
ops
::
StackGradKernel
<
plat
::
CUDADeviceContext
,
double
,
ops
::
GPUStackGradFunctor
>
);
ops
::
StackGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
StackGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/stack_op.h
浏览文件 @
c73c5ed5
...
...
@@ -11,20 +11,20 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#ifdef __NVCC__
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
#endif
namespace
paddle
{
namespace
operators
{
inline
void
GetPrePostForStackOp
(
const
framework
::
DDim
&
dim
,
int
axis
,
int
*
pre
,
int
*
post
)
{
*
pre
=
1
;
for
(
auto
i
=
0
;
i
<
axis
;
++
i
)
(
*
pre
)
*=
dim
[
i
];
*
post
=
1
;
for
(
auto
i
=
axis
;
i
<
dim
.
size
();
++
i
)
(
*
post
)
*=
dim
[
i
];
}
class
StackOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -72,7 +72,61 @@ class StackOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template
<
typename
DeviceContext
,
typename
T
,
typename
Functor
>
template
<
typename
VecXType
,
typename
T
>
struct
StackFunctor
{
HOSTDEVICE
StackFunctor
(
const
VecXType
&
x
,
T
*
y
,
int
n
,
int
post
)
:
x_
(
x
),
y_
(
y
),
n_
(
n
),
post_
(
post
)
{}
HOSTDEVICE
void
operator
()(
int
idx
)
{
int
i
=
idx
/
(
n_
*
post_
);
int
which_x
=
idx
/
post_
-
i
*
n_
;
int
x_index
=
i
*
post_
+
idx
%
post_
;
y_
[
idx
]
=
x_
[
which_x
][
x_index
];
}
private:
VecXType
x_
;
T
*
y_
;
int
n_
;
int
post_
;
};
template
<
typename
VecDxType
,
typename
T
>
struct
StackGradFunctor
{
HOSTDEVICE
StackGradFunctor
(
const
VecDxType
&
dx
,
const
T
*
dy
,
int
n
,
int
post
)
:
dx_
(
dx
),
dy_
(
dy
),
n_
(
n
),
post_
(
post
)
{}
HOSTDEVICE
void
operator
()(
int
idx
)
{
int
i
=
idx
/
(
n_
*
post_
);
int
which_x
=
idx
/
post_
-
i
*
n_
;
int
x_index
=
i
*
post_
+
idx
%
post_
;
dx_
[
which_x
][
x_index
]
=
dy_
[
idx
];
}
private:
VecDxType
dx_
;
const
T
*
dy_
;
int
n_
;
int
post_
;
};
template
<
typename
DeviceContext
,
typename
VecXType
,
typename
T
>
static
inline
void
StackFunctorForRange
(
const
DeviceContext
&
ctx
,
const
VecXType
&
x
,
T
*
y
,
int
total_num
,
int
n
,
int
post
)
{
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
,
total_num
);
for_range
(
StackFunctor
<
VecXType
,
T
>
(
x
,
y
,
n
,
post
));
}
template
<
typename
DeviceContext
,
typename
VecDxType
,
typename
T
>
static
inline
void
StackGradFunctorForRange
(
const
DeviceContext
&
ctx
,
const
VecDxType
&
dx
,
const
T
*
dy
,
int
total_num
,
int
n
,
int
post
)
{
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
,
total_num
);
for_range
(
StackGradFunctor
<
VecDxType
,
T
>
(
dx
,
dy
,
n
,
post
));
}
template
<
typename
DeviceContext
,
typename
T
>
class
StackKernel
:
public
framework
::
OpKernel
<
T
>
{
using
Tensor
=
framework
::
LoDTensor
;
...
...
@@ -93,10 +147,29 @@ class StackKernel : public framework::OpKernel<T> {
auto
&
dim
=
x
[
0
]
->
dims
();
for
(
auto
i
=
0
;
i
<
axis
;
++
i
)
pre
*=
dim
[
i
];
for
(
auto
i
=
axis
;
i
<
dim
.
size
();
++
i
)
post
*=
dim
[
i
];
int
total_num
=
pre
*
n
*
post
;
Functor
functor
;
functor
(
ctx
.
template
device_context
<
DeviceContext
>(),
x_datas
,
y_data
,
pre
,
n
,
post
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
constexpr
auto
kMaxThreshold
=
16
;
if
(
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
||
n
>
kMaxThreshold
)
{
#ifdef __NVCC__
thrust
::
device_vector
<
const
T
*>
device_x_vec
(
x_datas
);
auto
x_data_arr
=
device_x_vec
.
data
().
get
();
#else
auto
x_data_arr
=
x_datas
.
data
();
#endif
StackFunctorForRange
(
dev_ctx
,
x_data_arr
,
y_data
,
total_num
,
n
,
post
);
}
#ifdef __NVCC__
else
{
// NOLINT
VLOG
(
10
)
<<
"Stack more than "
<<
kMaxThreshold
<<
" tensors on GPU may be slow."
;
framework
::
Array
<
const
T
*
,
kMaxThreshold
>
x_data_arr
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
x_data_arr
[
i
]
=
x_datas
[
i
];
StackFunctorForRange
(
dev_ctx
,
x_data_arr
,
y_data
,
total_num
,
n
,
post
);
}
#endif
}
};
...
...
@@ -127,31 +200,11 @@ class StackOpGrad : public framework::OperatorWithKernel {
}
};
class
StackGradOpDescMaker
:
public
framework
::
SingleGradOpDescMaker
/*framework::GradOpDescMakerBase*/
{
class
StackGradOpDescMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
/*
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator ()() const override {
auto x_grads = InputGrad("X", false);
std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Y");
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
[&og](const std::string& x_grad) {
auto* grad_op = new framework::OpDesc();
grad_op->SetInput("X", og);
grad_op->SetOutput("Y", {x_grad});
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
});
return grad_ops;
}
*/
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
op
(
new
framework
::
OpDesc
());
op
->
SetType
(
"stack_grad"
);
...
...
@@ -162,7 +215,7 @@ class StackGradOpDescMaker
}
};
template
<
typename
DeviceContext
,
typename
T
,
typename
GradFunctor
>
template
<
typename
DeviceContext
,
typename
T
>
class
StackGradKernel
:
public
framework
::
OpKernel
<
T
>
{
using
Tensor
=
framework
::
LoDTensor
;
...
...
@@ -175,16 +228,39 @@ class StackGradKernel : public framework::OpKernel<T> {
int
n
=
dy
->
dims
()[
axis
];
std
::
vector
<
T
*>
dx_datas
(
n
);
// NOLINT
for
(
int
i
=
0
;
i
<
n
;
i
++
)
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
dx_datas
[
i
]
=
dx
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
auto
dy_data
=
dy
->
data
<
T
>
();
int
pre
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
pre
*=
dy
->
dims
()[
i
];
int
post
=
dy
->
numel
()
/
(
n
*
pre
);
GradFunctor
functor
;
functor
(
ctx
.
template
device_context
<
DeviceContext
>(),
dx_datas
,
dy_data
,
pre
,
n
,
post
);
int
total_num
=
dy
->
numel
();
int
post
=
total_num
/
(
n
*
pre
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
constexpr
auto
kMaxThreshold
=
16
;
if
(
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
||
n
>
kMaxThreshold
)
{
#ifdef __NVCC__
thrust
::
device_vector
<
T
*>
device_dx_vec
(
dx_datas
);
auto
dx_data_arr
=
device_dx_vec
.
data
().
get
();
#else
auto
dx_data_arr
=
dx_datas
.
data
();
#endif
StackGradFunctorForRange
(
dev_ctx
,
dx_data_arr
,
dy_data
,
total_num
,
n
,
post
);
}
#ifdef __NVCC__
else
{
// NOLINT
VLOG
(
10
)
<<
"Stack more than "
<<
kMaxThreshold
<<
" tensors on GPU may be slow."
;
framework
::
Array
<
T
*
,
kMaxThreshold
>
dx_data_arr
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
dx_data_arr
[
i
]
=
dx_datas
[
i
];
StackGradFunctorForRange
(
dev_ctx
,
dx_data_arr
,
dy_data
,
total_num
,
n
,
post
);
}
#endif
}
};
...
...
python/paddle/fluid/io.py
浏览文件 @
c73c5ed5
...
...
@@ -406,6 +406,9 @@ def load_vars(executor,
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
filename
)})
executor
.
run
(
load_prog
)
if
main_program
is
None
:
main_program
=
default_main_program
()
# load slice vars on pserver, if have it.
_load_slice_up_vars
(
executor
,
dirname
,
main_program
.
_slice_vars_and_attrs
)
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
c73c5ed5
...
...
@@ -29,21 +29,81 @@ from .. import unique_name
from
functools
import
reduce
__all__
=
[
'fc'
,
'embedding'
,
'dynamic_lstm'
,
'dynamic_lstmp'
,
'dynamic_gru'
,
'gru_unit'
,
'linear_chain_crf'
,
'crf_decoding'
,
'cos_sim'
,
'cross_entropy'
,
'square_error_cost'
,
'chunk_eval'
,
'sequence_conv'
,
'conv2d'
,
'conv3d'
,
'sequence_pool'
,
'sequence_softmax'
,
'softmax'
,
'pool2d'
,
'pool3d'
,
'batch_norm'
,
'beam_search_decode'
,
'conv2d_transpose'
,
'conv3d_transpose'
,
'sequence_expand'
,
'lstm_unit'
,
'reduce_sum'
,
'reduce_mean'
,
'reduce_max'
,
'reduce_min'
,
'reduce_prod'
,
'sequence_first_step'
,
'sequence_last_step'
,
'dropout'
,
'split'
,
'ctc_greedy_decoder'
,
'edit_distance'
,
'l2_normalize'
,
'matmul'
,
'topk'
,
'warpctc'
,
'sequence_reshape'
,
'transpose'
,
'im2sequence'
,
'nce'
,
'hsigmoid'
,
'beam_search'
,
'row_conv'
,
'multiplex'
,
'layer_norm'
,
'softmax_with_cross_entropy'
,
'smooth_l1'
,
'one_hot'
,
'autoincreased_step_counter'
,
'reshape'
,
'lod_reset'
,
'lrn'
,
'pad'
,
'label_smooth'
,
'roi_pool'
,
'dice_loss'
,
'image_resize'
,
'image_resize_short'
,
'resize_bilinear'
,
'gather'
,
'scatter'
,
'random_crop'
,
'mean_iou'
,
'relu'
,
'log'
,
'crop'
,
'rank_loss'
,
'prelu'
,
'flatten'
,
'stack'
'fc'
,
'embedding'
,
'dynamic_lstm'
,
'dynamic_lstmp'
,
'dynamic_gru'
,
'gru_unit'
,
'linear_chain_crf'
,
'crf_decoding'
,
'cos_sim'
,
'cross_entropy'
,
'square_error_cost'
,
'chunk_eval'
,
'sequence_conv'
,
'conv2d'
,
'conv3d'
,
'sequence_pool'
,
'sequence_softmax'
,
'softmax'
,
'pool2d'
,
'pool3d'
,
'batch_norm'
,
'beam_search_decode'
,
'conv2d_transpose'
,
'conv3d_transpose'
,
'sequence_expand'
,
'lstm_unit'
,
'reduce_sum'
,
'reduce_mean'
,
'reduce_max'
,
'reduce_min'
,
'reduce_prod'
,
'sequence_first_step'
,
'sequence_last_step'
,
'dropout'
,
'split'
,
'ctc_greedy_decoder'
,
'edit_distance'
,
'l2_normalize'
,
'matmul'
,
'topk'
,
'warpctc'
,
'sequence_reshape'
,
'transpose'
,
'im2sequence'
,
'nce'
,
'hsigmoid'
,
'beam_search'
,
'row_conv'
,
'multiplex'
,
'layer_norm'
,
'softmax_with_cross_entropy'
,
'smooth_l1'
,
'one_hot'
,
'autoincreased_step_counter'
,
'reshape'
,
'lod_reset'
,
'lrn'
,
'pad'
,
'label_smooth'
,
'roi_pool'
,
'dice_loss'
,
'image_resize'
,
'image_resize_short'
,
'resize_bilinear'
,
'gather'
,
'scatter'
,
'random_crop'
,
'mean_iou'
,
'relu'
,
'log'
,
'crop'
,
'rank_loss'
,
'prelu'
,
'flatten'
,
'stack'
,
]
...
...
@@ -5469,5 +5529,6 @@ def stack(x, axis=0):
out
=
helper
.
create_tmp_variable
(
x
[
0
].
dtype
)
helper
.
append_op
(
type
=
'stack'
,
inputs
=
{
'X'
:
x
},
outpus
=
{
'Y'
:
out
},
attrs
=
{
'axis'
:
axis
})
type
=
'stack'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Y'
:
out
},
attrs
=
{
'axis'
:
axis
})
return
out
python/paddle/fluid/tests/unittests/test_profiler.py
浏览文件 @
c73c5ed5
...
...
@@ -25,9 +25,6 @@ import paddle.fluid.core as core
class
TestProfiler
(
unittest
.
TestCase
):
def
net_profiler
(
self
,
state
,
profile_path
=
'/tmp/profile'
):
enable_if_gpu
=
state
==
'GPU'
or
state
==
"All"
if
enable_if_gpu
and
not
core
.
is_compiled_with_cuda
():
return
startup_program
=
fluid
.
Program
()
main_program
=
fluid
.
Program
()
...
...
@@ -81,8 +78,6 @@ class TestProfiler(unittest.TestCase):
pass_acc_calculator
.
add
(
value
=
acc
,
weight
=
b_size
)
pass_acc
=
pass_acc_calculator
.
eval
()
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"profiler is enabled only with GPU"
)
def
test_cpu_profiler
(
self
):
self
.
net_profiler
(
'CPU'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录