Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6ccf8685
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6ccf8685
编写于
1月 07, 2019
作者:
Y
Yan Chunwei
提交者:
GitHub
1月 07, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor tensorrt node teller (#15181)
上级
c8f101e5
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
134 addition
and
47 deletion
+134
-47
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+0
-2
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+0
-10
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
+11
-7
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+5
-3
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
...uid/inference/analysis/passes/ir_analysis_compose_pass.cc
+0
-23
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h
...luid/inference/analysis/passes/ir_analysis_compose_pass.h
+0
-2
paddle/fluid/inference/tensorrt/CMakeLists.txt
paddle/fluid/inference/tensorrt/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+49
-0
paddle/fluid/inference/tensorrt/op_teller.h
paddle/fluid/inference/tensorrt/op_teller.h
+68
-0
未找到文件。
paddle/fluid/inference/analysis/argument.h
浏览文件 @
6ccf8685
...
...
@@ -123,8 +123,6 @@ struct Argument {
DECL_ARGUMENT_FIELD
(
use_gpu
,
UseGPU
,
bool
);
DECL_ARGUMENT_FIELD
(
gpu_device_id
,
GPUDeviceId
,
int
);
DECL_ARGUMENT_FIELD
(
use_tensorrt
,
UseTensorRT
,
bool
);
DECL_ARGUMENT_FIELD
(
tensorrt_node_teller
,
TensorRtNodeTeller
,
std
::
function
<
bool
(
const
framework
::
ir
::
Node
*
)
>
);
DECL_ARGUMENT_FIELD
(
tensorrt_max_batch_size
,
TensorRtMaxBatchSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_workspace_size
,
TensorRtWorkspaceSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_min_subgraph_size
,
TensorRtMinSubgraphSize
,
int
);
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
6ccf8685
...
...
@@ -49,13 +49,6 @@ void IRPassManager::CreatePasses(Argument *argument,
for
(
const
std
::
string
&
pass_name
:
passes
)
{
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
pass_name
);
// Set some pass attributes.
if
(
pass_name
==
"ir_analysis_pass"
)
{
pass
->
Set
(
"tensorrt_node_teller"
,
new
SubgraphDetector
::
NodeInsideSubgraphTeller
(
argument
->
tensorrt_node_teller
()));
}
if
(
pass_name
==
"graph_viz_pass"
)
{
std
::
string
dot_file_path
=
std
::
to_string
(
pass_num
)
+
"_ir_"
+
(
pre_pass
.
empty
()
?
"origin"
:
pre_pass
)
+
...
...
@@ -70,9 +63,6 @@ void IRPassManager::CreatePasses(Argument *argument,
}
if
(
pass_name
==
"tensorrt_subgraph_pass"
)
{
PADDLE_ENFORCE
(
argument
->
tensorrt_node_teller_valid
());
pass
->
SetNotOwned
(
"tensorrt_node_teller"
,
argument
->
tensorrt_node_teller_ptr
());
pass
->
Set
(
"workspace_size"
,
new
int
(
argument
->
tensorrt_workspace_size
()));
pass
->
Set
(
"max_batch_size"
,
new
int
(
argument
->
tensorrt_max_batch_size
()));
pass
->
Set
(
"min_subgraph_size"
,
...
...
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
浏览文件 @
6ccf8685
cc_library
(
subgraph_detector SRCS subgraph_detector.cc DEPS proto_desc
)
cc_library
(
tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_detector
)
set
(
analysis_deps
${
analysis_deps
}
subgraph_detector tensorrt_subgraph_pass
CACHE INTERNAL
""
)
set
(
pass_file
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/api/paddle_inference_pass.h
)
file
(
APPEND
${
pass_file
}
"USE_PASS(tensorrt_subgraph_pass);
\n
"
)
set
(
INFER_IR_PASSES
${
INFER_IR_PASSES
}
tensorrt_subgraph_pass CACHE INTERNAL
""
)
if
(
TENSORRT_FOUND
)
cc_library
(
tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_detector tensorrt_op_teller
)
set
(
analysis_deps
${
analysis_deps
}
subgraph_detector tensorrt_subgraph_pass
CACHE INTERNAL
""
)
set
(
pass_file
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/api/paddle_inference_pass.h
)
file
(
APPEND
${
pass_file
}
"USE_PASS(tensorrt_subgraph_pass);
\n
"
)
set
(
INFER_IR_PASSES
${
INFER_IR_PASSES
}
tensorrt_subgraph_pass CACHE INTERNAL
""
)
endif
()
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
6ccf8685
...
...
@@ -20,6 +20,7 @@
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -35,8 +36,10 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
graph
)
const
{
framework
::
ir
::
FusePassBase
::
Init
(
"tensorrt_subgraph_pass"
,
graph
.
get
());
auto
teller
=
Get
<
SubgraphDetector
::
NodeInsideSubgraphTeller
>
(
"tensorrt_node_teller"
);
auto
teller
=
[](
const
framework
::
ir
::
Node
*
node
)
{
if
(
!
node
->
IsOp
()
||
!
node
->
Op
())
return
false
;
return
tensorrt
::
OpTeller
::
Global
().
Tell
(
node
->
Op
()
->
Type
(),
*
node
->
Op
());
};
SubGraphFuser
fuser
(
graph
.
get
(),
teller
,
Get
<
int
>
(
"min_subgraph_size"
)
/*min subgraph size*/
);
...
...
@@ -232,7 +235,6 @@ std::vector<std::string> ExtractParameters(
REGISTER_PASS
(
tensorrt_subgraph_pass
,
paddle
::
inference
::
analysis
::
TensorRtSubgraphPass
)
.
RequirePassAttr
(
"tensorrt_node_teller"
)
.
RequirePassAttr
(
"max_batch_size"
)
.
RequirePassAttr
(
"workspace_size"
)
.
RequirePassAttr
(
"min_subgraph_size"
);
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
浏览文件 @
6ccf8685
...
...
@@ -27,9 +27,6 @@ namespace analysis {
void
IrAnalysisComposePass
::
RunImpl
(
Argument
*
argument
)
{
ARGUMENT_CHECK_FIELD
(
argument
,
ir_analysis_passes
);
if
(
argument
->
use_tensorrt_valid
()
&&
argument
->
use_tensorrt
())
{
InitTensorRTAttrs
(
argument
);
}
ApplyIrPasses
(
argument
);
CollectFusionStatis
(
argument
);
}
...
...
@@ -38,26 +35,6 @@ std::string IrAnalysisComposePass::repr() const {
return
"ir-analysis-compose-pass"
;
}
void
IrAnalysisComposePass
::
InitTensorRTAttrs
(
Argument
*
argument
)
{
if
(
argument
->
use_tensorrt_valid
()
&&
argument
->
use_tensorrt
())
{
LOG
(
INFO
)
<<
"Initing TensorRT pass"
;
argument
->
SetTensorRtNodeTeller
([](
const
framework
::
ir
::
Node
*
node
)
{
std
::
unordered_set
<
std
::
string
>
teller_set
(
{
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"sigmoid"
,
"depthwise_conv2d"
,
"batch_norm"
,
"concat"
,
"tanh"
,
"pad"
,
"elementwise_add"
,
"elementwise_mul"
,
"dropout"
,
"split"
,
"prelu"
,
"conv2d_transpose"
,
"leaky_relu"
});
if
(
!
node
->
IsOp
())
return
false
;
if
(
teller_set
.
count
(
node
->
Op
()
->
Type
()))
{
return
true
;
}
else
{
return
false
;
}
});
}
}
void
IrAnalysisComposePass
::
ApplyIrPasses
(
Argument
*
argument
)
{
std
::
vector
<
std
::
string
>
passes
({
"ir_graph_build_pass"
,
"ir_analysis_pass"
,
...
...
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h
浏览文件 @
6ccf8685
...
...
@@ -33,8 +33,6 @@ class IrAnalysisComposePass : public AnalysisPass {
std
::
string
repr
()
const
override
;
private:
void
InitTensorRTAttrs
(
Argument
*
argument
);
void
ApplyIrPasses
(
Argument
*
argument
);
void
CollectFusionStatis
(
Argument
*
argument
);
...
...
paddle/fluid/inference/tensorrt/CMakeLists.txt
浏览文件 @
6ccf8685
nv_library
(
tensorrt_engine SRCS engine.cc DEPS
${
GLOB_OPERATOR_DEPS
}
framework_proto device_context
)
nv_library
(
tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto
)
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine
)
add_subdirectory
(
plugin
)
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
0 → 100644
浏览文件 @
6ccf8685
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/op_teller.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
// Just tell by the op_types.
struct
SimpleOpTypeSetTeller
:
public
Teller
{
SimpleOpTypeSetTeller
()
{}
bool
operator
()(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
)
override
{
return
teller_set
.
count
(
op_type
);
}
private:
std
::
unordered_set
<
std
::
string
>
teller_set
{
{
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"sigmoid"
,
"depthwise_conv2d"
,
"batch_norm"
,
"concat"
,
"tanh"
,
"pad"
,
"elementwise_add"
,
"elementwise_mul"
,
"dropout"
,
"split"
,
"prelu"
,
"conv2d_transpose"
,
"leaky_relu"
}};
};
bool
OpTeller
::
Tell
(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
)
{
for
(
auto
&
teller
:
tellers_
)
{
if
((
*
teller
)(
op_type
,
desc
))
return
true
;
}
return
false
;
}
OpTeller
::
OpTeller
()
{
tellers_
.
emplace_back
(
new
SimpleOpTypeSetTeller
);
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/op_teller.h
0 → 100644
浏览文件 @
6ccf8685
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* Single Op teller definition.
* One can override this and define a more complex tell logic, considerring more
* issues such as op_desc.
*/
struct
Teller
{
virtual
bool
operator
()(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
)
=
0
;
virtual
~
Teller
()
=
default
;
};
/*
* A real example:
*
* struct SomeTeller : public Teller {
* bool operator()(const std::string& op_type,
* const framework::OpDesc& desc) override {
* return op_type == "fc" && desc.Inputs().size() == 2;
* }
*};
*/
/*
* class OpTeller helps to tell whether a fluid
* operator can be transformed to a TensorRT layer.
*/
class
OpTeller
{
public:
static
OpTeller
&
Global
()
{
static
std
::
unique_ptr
<
OpTeller
>
x
(
new
OpTeller
);
return
*
x
;
}
bool
Tell
(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
);
private:
OpTeller
();
private:
std
::
vector
<
std
::
unique_ptr
<
Teller
>>
tellers_
;
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录