Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4646c0f3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4646c0f3
编写于
5月 04, 2018
作者:
T
Tao Luo
提交者:
GitHub
5月 04, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10144 from luotao1/tr_convert_init
tensorrt convert init
上级
3bb99c4f
beb12455
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
372 addition
and
7 deletion
+372
-7
paddle/fluid/inference/tensorrt/CMakeLists.txt
paddle/fluid/inference/tensorrt/CMakeLists.txt
+4
-4
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+3
-0
paddle/fluid/inference/tensorrt/convert/activation_op.cc
paddle/fluid/inference/tensorrt/convert/activation_op.cc
+40
-0
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+34
-0
paddle/fluid/inference/tensorrt/convert/mul_op.cc
paddle/fluid/inference/tensorrt/convert/mul_op.cc
+33
-0
paddle/fluid/inference/tensorrt/convert/op_converter.h
paddle/fluid/inference/tensorrt/convert/op_converter.h
+91
-0
paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
...le/fluid/inference/tensorrt/convert/test_activation_op.cc
+94
-0
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+37
-0
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+28
-2
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+8
-0
paddle/fluid/inference/tensorrt/test_engine.cc
paddle/fluid/inference/tensorrt/test_engine.cc
+0
-1
未找到文件。
paddle/fluid/inference/tensorrt/CMakeLists.txt
浏览文件 @
4646c0f3
if
(
WITH_TESTING
)
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda
)
endif
(
)
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda
)
set
(
ENGINE_FILE
${
CMAKE_CURRENT_SOURCE_DIR
}
/engine.cc
)
add_subdirectory
(
convert
)
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
0 → 100644
浏览文件 @
4646c0f3
nv_test
(
test_tensorrt_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS
${
FLUID_CORE_MODULES
}
)
nv_test
(
test_tensorrt_activation_op SRCS test_activation_op.cc
${
ENGINE_FILE
}
activation_op.cc
DEPS
${
FLUID_CORE_MODULES
}
activation_op
)
paddle/fluid/inference/tensorrt/convert/activation_op.cc
0 → 100644
浏览文件 @
4646c0f3
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
ReluOpConverter
:
public
OpConverter
{
public:
ReluOpConverter
()
{}
void
operator
()(
const
framework
::
OpDesc
&
op
)
override
{
LOG
(
INFO
)
<<
"convert a fluid relu op to tensorrt activation layer whose "
"type is Relu"
;
const
nvinfer1
::
ITensor
*
input_tensor
=
engine_
->
GetITensor
(
op
.
Input
(
"X"
)[
0
]);
nvinfer1
::
IActivationLayer
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input_tensor
),
nvinfer1
::
ActivationType
::
kRELU
);
engine_
->
SetITensor
(
op
.
Output
(
"Out"
)[
0
],
layer
->
getOutput
(
0
));
}
};
REGISTER_TRT_OP_CONVERTER
(
relu
,
ReluOpConverter
);
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
0 → 100644
浏览文件 @
4646c0f3
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
Conv2dOpConverter
:
public
OpConverter
{
public:
Conv2dOpConverter
()
{}
void
operator
()(
const
framework
::
OpDesc
&
op
)
override
{
LOG
(
INFO
)
<<
"convert a fluid conv2d op to tensorrt conv layer without bias"
;
}
};
REGISTER_TRT_OP_CONVERTER
(
conv2d
,
Conv2dOpConverter
);
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/convert/mul_op.cc
0 → 100644
浏览文件 @
4646c0f3
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
MulOpConverter
:
public
OpConverter
{
public:
MulOpConverter
()
{}
void
operator
()(
const
framework
::
OpDesc
&
op
)
override
{
LOG
(
INFO
)
<<
"convert a fluid mul op to tensorrt fc layer without bias"
;
}
};
REGISTER_TRT_OP_CONVERTER
(
mul
,
MulOpConverter
);
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/convert/op_converter.h
0 → 100644
浏览文件 @
4646c0f3
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* Convert Op from Fluid to TensorRT Engine.
*/
class
OpConverter
{
public:
OpConverter
()
{}
virtual
void
operator
()(
const
framework
::
OpDesc
&
op
)
{}
void
Execute
(
const
framework
::
OpDesc
&
op
,
TensorRTEngine
*
engine
)
{
std
::
string
type
=
op
.
Type
();
auto
it
=
converters_
.
find
(
type
);
PADDLE_ENFORCE
(
it
!=
converters_
.
end
(),
"no OpConverter for optype [%s]"
,
type
);
it
->
second
->
SetEngine
(
engine
);
(
*
it
->
second
)(
op
);
}
static
OpConverter
&
Global
()
{
static
auto
*
x
=
new
OpConverter
;
return
*
x
;
}
template
<
typename
T
>
void
Register
(
const
std
::
string
&
key
)
{
converters_
[
key
]
=
new
T
;
}
// convert fluid op to tensorrt layer
void
ConvertOp
(
const
framework
::
OpDesc
&
op
,
TensorRTEngine
*
engine
)
{
OpConverter
::
Global
().
Execute
(
op
,
engine
);
}
// convert fluid block to tensorrt network
void
ConvertBlock
(
const
framework
::
BlockDesc
&
block
,
TensorRTEngine
*
engine
)
{
for
(
auto
op
:
block
.
AllOps
())
{
OpConverter
::
Global
().
Execute
(
*
op
,
engine
);
}
}
void
SetEngine
(
TensorRTEngine
*
engine
)
{
engine_
=
engine
;
}
virtual
~
OpConverter
()
{}
// TensorRT engine
TensorRTEngine
*
engine_
{
nullptr
};
private:
// registered op converter map, whose key is the fluid op type, and value is
// the pointer position of corresponding OpConverter class.
std
::
unordered_map
<
std
::
string
,
OpConverter
*>
converters_
;
// fluid inference scope
framework
::
Scope
*
scope_
{
nullptr
};
};
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter { \
trt_##op_type__##_converter() { \
OpConverter::Global().Register<Converter__>(#op_type__); \
} \
}; \
trt_##op_type__##_converter trt_##op_type__##_converter__;
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
0 → 100644
浏览文件 @
4646c0f3
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
USE_OP
(
relu
);
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
void
compare
(
float
input
,
float
expect
)
{
framework
::
Scope
scope
;
platform
::
CUDAPlace
place
;
platform
::
CUDADeviceContext
ctx
(
place
);
// init fluid op and variable
auto
x_var
=
scope
.
Var
(
"X"
);
auto
x_tensor
=
x_var
->
GetMutable
<
framework
::
LoDTensor
>
();
x_tensor
->
Resize
({
1
,
1
});
std
::
vector
<
float
>
init
;
init
.
push_back
(
input
);
framework
::
TensorFromVector
(
init
,
ctx
,
x_tensor
);
auto
out_var
=
scope
.
Var
(
"Out"
);
auto
out_tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
out_tensor
->
Resize
({
1
,
1
});
out_tensor
->
mutable_data
<
float
>
(
place
);
framework
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"relu"
);
op_desc
.
SetInput
(
"X"
,
{
"X"
});
op_desc
.
SetOutput
(
"Out"
,
{
"Out"
});
auto
relu_op
=
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
// run fluid op
relu_op
->
Run
(
scope
,
place
);
std
::
vector
<
float
>
out1
;
framework
::
TensorToVector
(
*
out_tensor
,
ctx
,
&
out1
);
// init tensorrt op
cudaStream_t
stream
;
ASSERT_EQ
(
0
,
cudaStreamCreate
(
&
stream
));
TensorRTEngine
*
engine
=
new
TensorRTEngine
(
1
,
1
<<
10
,
&
stream
);
engine
->
InitNetwork
();
engine
->
DeclareInput
(
"X"
,
nvinfer1
::
DataType
::
kFLOAT
,
nvinfer1
::
DimsCHW
{
1
,
1
,
1
});
OpConverter
op_converter
;
op_converter
.
ConvertOp
(
op_desc
,
engine
);
engine
->
DeclareOutput
(
"Out"
);
engine
->
FreezeNetwork
();
engine
->
SetInputFromCPU
(
"X"
,
&
input
,
1
*
sizeof
(
float
));
// run tensorrt op
engine
->
Execute
(
1
);
float
out2
;
engine
->
GetOutputInCPU
(
"Out"
,
&
out2
,
1
*
sizeof
(
float
));
ASSERT_EQ
(
out1
[
0
],
out2
);
ASSERT_EQ
(
out1
[
0
],
expect
);
delete
engine
;
cudaStreamDestroy
(
stream
);
}
TEST
(
OpConverter
,
ConvertRelu
)
{
compare
(
1
,
1
);
// relu(1) = 1
compare
(
-
5
,
0
);
// relu(-5) = 0
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
0 → 100644
浏览文件 @
4646c0f3
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
TEST
(
BlockConverter
,
ConvertBlock
)
{
framework
::
ProgramDesc
prog
;
auto
*
block
=
prog
.
MutableBlock
(
0
);
auto
*
mul_op
=
block
->
AppendOp
();
mul_op
->
SetType
(
"mul"
);
auto
*
conv2d_op
=
block
->
AppendOp
();
conv2d_op
->
SetType
(
"conv2d"
);
OpConverter
converter
;
converter
.
ConvertBlock
(
*
block
,
nullptr
/*TensorRTEngine*/
);
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
4646c0f3
...
...
@@ -80,8 +80,8 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
PADDLE_ENFORCE
(
infer_network_
!=
nullptr
,
"should initnetwork first"
);
auto
*
input
=
infer_network_
->
addInput
(
name
.
c_str
(),
dtype
,
dim
);
PADDLE_ENFORCE
(
input
,
"infer network add input %s failed"
,
name
);
buffer_sizes_
[
name
]
=
kDataTypeSize
[
static_cast
<
int
>
(
dtype
)]
*
AccumDims
(
dim
);
TensorRTEngine
::
SetITensor
(
name
,
input
);
return
input
;
}
...
...
@@ -99,6 +99,19 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
buffer_sizes_
[
name
]
=
0
;
}
void
TensorRTEngine
::
DeclareOutput
(
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_EQ
(
0
,
buffer_sizes_
.
count
(
name
),
"duplicate output name %s"
,
name
);
auto
*
output
=
TensorRTEngine
::
GetITensor
(
name
);
PADDLE_ENFORCE
(
output
!=
nullptr
);
output
->
setName
(
name
.
c_str
());
infer_network_
->
markOutput
(
*
output
);
// output buffers' size can only be decided latter, set zero here to mark this
// and will reset latter.
buffer_sizes_
[
name
]
=
0
;
}
void
*
TensorRTEngine
::
GetOutputInGPU
(
const
std
::
string
&
name
)
{
return
buffer
(
name
);
}
...
...
@@ -110,7 +123,6 @@ void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
PADDLE_ENFORCE
(
it
!=
buffer_sizes_
.
end
());
PADDLE_ENFORCE_GT
(
it
->
second
,
0
);
PADDLE_ENFORCE_GE
(
max_size
,
it
->
second
);
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
dst
,
buffer
(
name
),
it
->
second
,
cudaMemcpyDeviceToHost
,
*
stream_
));
}
...
...
@@ -126,10 +138,24 @@ void*& TensorRTEngine::buffer(const std::string& name) {
void
TensorRTEngine
::
SetInputFromCPU
(
const
std
::
string
&
name
,
void
*
data
,
size_t
size
)
{
void
*
buf
=
buffer
(
name
);
cudaMemcpyAsync
(
buf
,
data
,
size
,
cudaMemcpyHostToDevice
,
*
stream_
);
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
buf
,
data
,
size
,
cudaMemcpyHostToDevice
,
*
stream_
));
}
void
TensorRTEngine
::
SetITensor
(
const
std
::
string
&
name
,
nvinfer1
::
ITensor
*
tensor
)
{
PADDLE_ENFORCE
(
tensor
!=
nullptr
);
PADDLE_ENFORCE_EQ
(
0
,
itensor_map_
.
count
(
name
),
"duplicate itensor name %s"
,
name
);
itensor_map_
[
name
]
=
tensor
;
}
nvinfer1
::
ITensor
*
TensorRTEngine
::
GetITensor
(
const
std
::
string
&
name
)
{
PADDLE_ENFORCE
(
itensor_map_
.
count
(
name
),
"no itensor %s"
,
name
);
return
itensor_map_
[
name
];
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
4646c0f3
...
...
@@ -80,6 +80,8 @@ class TensorRTEngine : public EngineBase {
// name.
void
DeclareOutput
(
const
nvinfer1
::
ILayer
*
layer
,
int
offset
,
const
std
::
string
&
name
);
// Set the itensor_map_[name] as the network's output, and set its name.
void
DeclareOutput
(
const
std
::
string
&
name
);
// GPU memory address for an ITensor with specific name. One can operate on
// these memory directly for acceleration, for example, output the converted
...
...
@@ -98,6 +100,10 @@ class TensorRTEngine : public EngineBase {
// LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
// to CPU.
void
GetOutputInCPU
(
const
std
::
string
&
name
,
void
*
dst
,
size_t
max_size
);
// Fill an ITensor into map itensor_map_.
void
SetITensor
(
const
std
::
string
&
name
,
nvinfer1
::
ITensor
*
tensor
);
// Get an ITensor called name.
nvinfer1
::
ITensor
*
GetITensor
(
const
std
::
string
&
name
);
nvinfer1
::
ICudaEngine
*
engine
()
{
return
infer_engine_
.
get
();
}
nvinfer1
::
INetworkDefinition
*
network
()
{
return
infer_network_
.
get
();
}
...
...
@@ -113,6 +119,8 @@ class TensorRTEngine : public EngineBase {
std
::
vector
<
void
*>
buffers_
;
// max data size for the buffers.
std
::
unordered_map
<
std
::
string
/*name*/
,
size_t
/*max size*/
>
buffer_sizes_
;
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
itensor_map_
;
// TensorRT related internal members
template
<
typename
T
>
...
...
paddle/fluid/inference/tensorrt/test_engine.cc
浏览文件 @
4646c0f3
...
...
@@ -70,7 +70,6 @@ TEST_F(TensorRTEngineTest, add_layer) {
engine_
->
Execute
(
1
);
LOG
(
INFO
)
<<
"to get output"
;
// void* y_v =
float
y_cpu
;
engine_
->
GetOutputInCPU
(
"y"
,
&
y_cpu
,
sizeof
(
float
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录