s920243400 / PaddleDetection (forked from PaddlePaddle / PaddleDetection; in sync with the upstream project)
Commit 6a7b9957
Refine commit message to enable ci, test=develop

Authored Nov 16, 2018 by hjchen2
Parent: 413f5948

Showing 5 changed files with 19 additions and 36 deletions (+19 / -36):
paddle/fluid/inference/tensorrt/convert/prelu_op.cc         +16  -18
paddle/fluid/inference/tensorrt/convert/split_op.cc          +1   -1
paddle/fluid/inference/tensorrt/engine.h                     +1   -1
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu    +1  -14
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h     +0   -2

paddle/fluid/inference/tensorrt/convert/prelu_op.cc
@@ -26,7 +26,7 @@ class PReluOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
-    VLOG(40) << "convert fluid prelu op to tensorrt prelu layer";
+    VLOG(4) << "convert fluid prelu op to tensorrt prelu layer";
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs
@@ -43,33 +43,31 @@ class PReluOpConverter : public OpConverter {
     PADDLE_ENFORCE_NOT_NULL(alpha_var);
     auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();

-    platform::CPUPlace place;
-    std::unique_ptr<framework::LoDTensor> alpha_tensor_host(
+    platform::CUDAPlace place;
+    std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
         new framework::LoDTensor());
-    alpha_tensor_host->Resize(alpha_tensor->dims());
-    TensorCopySync(*alpha_tensor, place, alpha_tensor_host.get());
-    float* alpha_data = alpha_tensor_host->mutable_data<float>(place);
+    alpha_tensor_device->Resize(alpha_tensor->dims());
+    TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
+    float* alpha_data = alpha_tensor_device->mutable_data<float>(place);

     // Transform alpha to TensorRTEngine::Weight
     TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
                                     static_cast<void*>(alpha_data),
-                                    alpha_tensor_host->numel());
-    engine_->weight_map[op_desc.Input("Alpha")[0]] =
-        std::move(alpha_tensor_host);
-    //
+                                    alpha_tensor_device->numel());
     PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
     nvinfer1::IPluginLayer* layer =
         engine_->AddPlugin(&input, input_num, plugin);
+    // keep alpha tensor to avoid release it's memory
+    engine_->weight_map[op_desc.Input("Alpha")[0]] =
+        std::move(alpha_tensor_device);

     std::string layer_name = "prelu (Output: ";
-    for (size_t i = 0; i < output_num; i++) {
-      auto output_name = op_desc.Output("Out")[i];
-      layer->getOutput(i)->setName(output_name.c_str());
-      engine_->SetITensor(output_name, layer->getOutput(i));
-      layer_name += output_name;
-      if (test_mode) {
-        engine_->DeclareOutput(output_name);
-      }
-    }
+    auto output_name = op_desc.Output("Out")[0];
+    layer->getOutput(0)->setName(output_name.c_str());
+    engine_->SetITensor(output_name, layer->getOutput(0));
+    layer_name += output_name;
+    if (test_mode) {
+      engine_->DeclareOutput(output_name);
+    }
     layer->setName((layer_name + ")").c_str());
   }
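The prelu_op.cc change above copies the alpha tensor into device memory up front (CUDAPlace / alpha_tensor_device instead of CPUPlace / alpha_tensor_host) and parks the owning LoDTensor in engine_->weight_map, per the added comment, so that the raw alpha_data pointer wrapped in TensorRTEngine::Weight stays valid for as long as the engine does. The following is a minimal, self-contained sketch of that ownership pattern in plain C++; FakeWeight, FakePlugin, and the local weight_map are stand-ins invented for illustration, not Paddle or TensorRT types.

// Sketch of the "park the owner in a long-lived map, hand out a raw view" pattern
// used by the converter. All names are hypothetical; this is not Paddle code.
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct FakeWeight {              // stands in for TensorRTEngine::Weight: a raw view
  const float* values = nullptr;
  std::size_t count = 0;
};

struct FakePlugin {              // stands in for PReluPlugin: borrows, never owns
  FakeWeight alpha;
  float Apply(float x) const { return x < 0 ? alpha.values[0] * x : x; }
};

int main() {
  // Long-lived map that owns the buffer, mirroring
  // engine_->weight_map[op_desc.Input("Alpha")[0]] = std::move(alpha_tensor_device);
  std::map<std::string, std::unique_ptr<std::vector<float>>> weight_map;

  auto alpha_buf = std::make_unique<std::vector<float>>(1, 0.25f);
  FakeWeight alpha{alpha_buf->data(), alpha_buf->size()};

  // Move ownership into the map before the local unique_ptr goes out of scope;
  // the plugin's raw pointer then stays valid for the map's (engine's) lifetime.
  weight_map["prelu_alpha"] = std::move(alpha_buf);

  FakePlugin plugin{alpha};
  std::cout << plugin.Apply(-2.0f) << "\n";  // prints -0.5
  return 0;
}

The essential step is the std::move into the long-lived map before the local owner disappears; the plugin itself only ever holds a non-owning view.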

paddle/fluid/inference/tensorrt/convert/split_op.cc
@@ -26,7 +26,7 @@ class SplitOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
-    VLOG(40) << "convert a fluid split op to tensorrt split layer";
+    VLOG(4) << "convert a fluid split op to tensorrt split layer";
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs

paddle/fluid/inference/tensorrt/engine.h
@@ -46,7 +46,7 @@ class TensorRTEngine : public EngineBase {
       w_.values = value;
       w_.count = num_elem;
     }
-    nvinfer1::Weights& get() { return w_; }
+    const nvinfer1::Weights& get() { return w_; }

     std::vector<int64_t> dims;

paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
@@ -109,25 +109,12 @@ nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
   return output_dims;
 }

-int PReluPlugin::initialize() {
-  nvinfer1::Weights &alpha = cuda_alpha_.get();
-  alpha.type = alpha_.get().type;
-  alpha.count = alpha_.get().count;
-  CHECK_EQ(cudaMalloc(&alpha.values, alpha.count * sizeof(float)), cudaSuccess);
-  CHECK_EQ(cudaMemcpy(const_cast<void *>(alpha.values), alpha_.get().values,
-                      alpha.count * sizeof(float), cudaMemcpyHostToDevice),
-           cudaSuccess);
-  return 0;
-}
-
 int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
                          void **outputs, void *workspace, cudaStream_t stream) {
   // input dims is CHW.
   const auto &input_dims = this->getInputDims(0);
   const float *input = reinterpret_cast<const float *>(inputs[0]);
-  const float *alpha = reinterpret_cast<const float *>(cuda_alpha_.get().values);
+  const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
   float *output = reinterpret_cast<float **>(outputs)[0];
   if (mode_ == "channel") {
     PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);

paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
@@ -24,7 +24,6 @@ namespace tensorrt {
 class PReluPlugin : public PluginTensorRT {
   TensorRTEngine::Weight alpha_;
-  TensorRTEngine::Weight cuda_alpha_;
   std::string mode_;

  protected:
@@ -60,7 +59,6 @@ class PReluPlugin : public PluginTensorRT {
   int getNbOutputs() const override { return 1; }
   nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                      int nbInputDims) override;
-  int initialize() override;
   int enqueue(int batchSize, const void *const *inputs, void **outputs,
               void *workspace, cudaStream_t stream) override;
 };
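Taken together, the plugin-side diffs remove the cuda_alpha_ member and the initialize() override, and enqueue() now reads alpha_.get().values directly: the host-to-device copy that initialize() performed with cudaMalloc/cudaMemcpy is no longer needed because the converter already placed alpha in device memory. Below is a rough standard-C++ sketch (hypothetical names, no CUDA) of why the per-plugin copy step can disappear once the buffer arrives in its final location.

// Before/after shape of the plugin, reduced to plain C++. Hypothetical types;
// this only illustrates the design change, it is not the Paddle implementation.
#include <cassert>
#include <vector>

// Old shape: keep a second buffer and fill it lazily, like cuda_alpha_ + initialize().
struct CopyingPlugin {
  const std::vector<float>* src;   // host-side data set at construction
  std::vector<float> own_copy;     // duplicate storage filled in initialize()
  void initialize() { own_copy = *src; }
  float alpha() const { return own_copy[0]; }
};

// New shape: the buffer handed in at construction is already usable, so there is
// no initialize() step and no duplicate allocation.
struct BorrowingPlugin {
  const std::vector<float>* src;   // already points at the final buffer
  float alpha() const { return (*src)[0]; }
};

int main() {
  std::vector<float> prepared{0.25f};   // analogous to the device-resident alpha

  CopyingPlugin before{&prepared, {}};
  before.initialize();                  // extra allocation and extra failure point
  BorrowingPlugin after{&prepared};     // nothing to do prior to use

  assert(before.alpha() == after.alpha());
  return 0;
}

With the copy done once at conversion time, the engine.h change follows naturally: get() can return a const reference, since the plugin no longer mutates the Weights it was handed.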