Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3e3599f3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3e3599f3
编写于
11月 21, 2018
作者:
H
hjchen2
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine split tensorrt plugin
上级
33c65517
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
134 addition
and
35 deletion
+134
-35
paddle/fluid/inference/tensorrt/convert/split_op.cc
paddle/fluid/inference/tensorrt/convert/split_op.cc
+1
-2
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
+126
-31
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
+7
-2
未找到文件。
paddle/fluid/inference/tensorrt/convert/split_op.cc
浏览文件 @
3e3599f3
...
...
@@ -40,7 +40,7 @@ class SplitOpConverter : public OpConverter {
int
axis
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"axis"
));
std
::
vector
<
int
>
output_lengths
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"sections"
));
PADDLE_ENFORCE
(
axis
!=
0
);
//
PADDLE_ENFORCE(axis != 0);
if
(
axis
<
0
)
{
axis
+=
input_dims
.
nbDims
;
}
else
{
...
...
@@ -48,7 +48,6 @@ class SplitOpConverter : public OpConverter {
}
PADDLE_ENFORCE
(
output_lengths
.
size
()
==
output_num
);
//
plugin
::
SplitPlugin
*
plugin
=
new
plugin
::
SplitPlugin
(
axis
,
output_lengths
);
nvinfer1
::
IPluginLayer
*
layer
=
...
...
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
浏览文件 @
3e3599f3
...
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_fp16.h>
#include <algorithm>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace
paddle
{
...
...
@@ -19,6 +21,52 @@ namespace inference {
namespace
tensorrt
{
namespace
plugin
{
// copied from operators::math::SplitFunctor
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
in_row
,
const
int
in_col
,
const
int
*
out_cols
,
int
out_cols_size
,
T
**
outputs_data
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
curr_segment
=
0
;
int
curr_offset
=
out_cols
[
0
];
for
(;
tid_x
<
in_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
curr_col_offset
=
out_cols
[
curr_segment
+
1
];
while
(
curr_col_offset
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
curr_col_offset
=
out_cols
[
curr_segment
+
1
];
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
T
*
output_ptr
=
outputs_data
[
curr_segment
];
if
(
output_ptr
!=
nullptr
)
{
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
in_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
segment_width
+
local_col
]
=
input_data
[
tid_y
*
in_col
+
tid_x
];
}
}
}
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int
in_row
,
const
int
in_col
,
const
int
fixed_out_col
,
T
**
outputs_data
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(;
tid_x
<
in_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
/
fixed_out_col
;
int
in_offset
=
tid_x
-
split
*
fixed_out_col
;
T
*
output_ptr
=
outputs_data
[
split
];
if
(
output_ptr
!=
nullptr
)
{
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
in_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
fixed_out_col
+
in_offset
]
=
input_data
[
tid_y
*
in_col
+
tid_x
];
}
}
}
nvinfer1
::
Dims
SplitPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
input_dims
,
int
num_inputs
)
{
PADDLE_ENFORCE_EQ
(
num_inputs
,
1
);
...
...
@@ -31,48 +79,95 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(
int
SplitPlugin
::
initialize
()
{
PADDLE_ENFORCE_LE
(
axis_
,
nvinfer1
::
Dims
::
MAX_DIMS
);
// notice input dims is [C, H, W]
nvinfer1
::
Dims
dims
=
this
->
getInputDims
(
0
);
outer_rows_
=
1
;
inner_cols_
=
1
;
for
(
int
i
=
0
;
i
<
axis_
;
++
i
)
{
outer_rows_
*=
dims
.
d
[
i
];
}
for
(
int
i
=
axis_
+
1
;
i
<
dims
.
nbDims
;
++
i
)
{
inner_cols_
*=
dims
.
d
[
i
];
}
same_shape_
=
true
;
std
::
vector
<
int
>
segment_offsets
(
1
,
0
);
for
(
int
i
=
0
;
i
<
this
->
getNbOutputs
();
++
i
)
{
segment_offsets
.
push_back
(
segment_offsets
.
back
()
+
output_length_
[
i
]);
if
(
output_length_
[
i
]
!=
output_length_
[
0
])
{
same_shape_
=
false
;
}
segment_offsets
.
push_back
(
segment_offsets
.
back
()
+
output_length_
[
i
]
*
inner_cols_
);
}
segment_offsets_
=
segment_offsets
;
nvinfer1
::
Dims
dims
=
this
->
getInputDims
(
0
);
nx_
=
1
;
for
(
int
i
=
dims
.
nbDims
-
1
;
i
>
axis_
;
--
i
)
{
nx_
*=
dims
.
d
[
i
];
inner_cols_
*=
dims
.
d
[
axis_
];
d_segment_offsets_
=
segment_offsets
;
segment_offsets_
=
std
::
move
(
segment_offsets
);
d_output_ptrs_
.
resize
(
this
->
getNbOutputs
(),
nullptr
);
return
0
;
}
template
<
typename
T
>
inline
void
Split
(
cudaStream_t
stream
,
const
bool
same_shape
,
const
int
outer_rows
,
const
int
inner_cols
,
const
std
::
vector
<
int
>&
segment_offsets
,
const
int
*
d_segment_offsets
,
const
T
*
input
,
T
**
outputs
)
{
const
int
kThreadsPerBlock
=
1024
;
const
int
kMaxBlocks
=
65535
;
int
block_cols
=
kThreadsPerBlock
;
if
(
inner_cols
<
kThreadsPerBlock
)
{
// block_cols is aligned by 32.
block_cols
=
((
inner_cols
+
31
)
>>
5
)
<<
5
;
}
ny_
=
dims
.
d
[
axis_
];
nz_
=
1
;
for
(
int
i
=
axis_
-
1
;
i
>=
0
;
--
i
)
{
nz_
*=
dims
.
d
[
i
];
int
block_rows
=
kThreadsPerBlock
/
block_cols
;
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
int
grid_cols
=
std
::
min
((
inner_cols
+
block_cols
-
1
)
/
block_cols
,
kMaxBlocks
);
int
grid_rows
=
std
::
min
(
kMaxBlocks
/
grid_cols
,
std
::
max
(
outer_rows
/
block_rows
,
1
));
dim3
grid_size
=
dim3
(
grid_cols
,
grid_rows
,
1
);
if
(
same_shape
)
{
SplitKernel
<<<
grid_size
,
block_size
,
0
,
stream
>>>
(
input
,
outer_rows
,
inner_cols
,
segment_offsets
[
1
],
outputs
);
}
else
{
SplitKernel
<<<
grid_size
,
block_size
,
0
,
stream
>>>
(
input
,
outer_rows
,
inner_cols
,
d_segment_offsets
,
static_cast
<
int
>
(
segment_offsets
.
size
()),
outputs
);
}
return
0
;
}
int
SplitPlugin
::
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
const
&
input_dims
=
this
->
getInputDims
(
0
);
int
input_size
=
0
;
float
const
*
idata
=
reinterpret_cast
<
float
const
*>
(
inputs
[
0
]);
float
**
odatas
=
reinterpret_cast
<
float
**>
(
outputs
);
// kernel impl here.
int
inputBatchOffset
=
nx_
*
ny_
*
nz_
;
for
(
size_t
i
=
0
;
i
<
this
->
getNbOutputs
();
i
++
)
{
for
(
size_t
j
=
0
;
j
<
batchSize
;
j
++
)
{
cudaMemcpyAsync
(
odatas
[
i
]
+
j
*
(
segment_offsets_
[
i
+
1
]
-
segment_offsets_
[
i
])
*
nx_
*
sizeof
(
float
),
inputs
[
0
]
+
(
inputBatchOffset
*
j
+
segment_offsets_
[
i
]
*
nx_
)
*
sizeof
(
float
),
(
segment_offsets_
[
i
+
1
]
-
segment_offsets_
[
i
])
*
nx_
*
sizeof
(
float
),
cudaMemcpyDeviceToDevice
,
stream
);
float
const
*
input_ptr
=
reinterpret_cast
<
float
const
*>
(
inputs
[
0
]);
if
(
axis_
==
-
1
&&
this
->
getNbOutputs
()
<
10
)
{
float
**
output_ptrs
=
reinterpret_cast
<
float
**>
(
outputs
);
int
data_type_size
=
(
this
->
getDataType
()
==
nvinfer1
::
DataType
::
kFLOAT
)
?
sizeof
(
__half
)
:
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
this
->
getNbOutputs
();
++
i
)
{
PADDLE_ENFORCE
(
cudaMemcpyAsync
(
output_ptrs
[
i
],
input_ptr
+
segment_offsets_
[
i
],
(
segment_offsets_
[
i
+
1
]
-
segment_offsets_
[
i
])
*
data_type_size
,
cudaMemcpyDeviceToDevice
,
stream
)
==
cudaSuccess
);
}
}
else
{
outer_rows_
*=
batchSize
;
const
int
*
d_segment_offsets_ptr
=
thrust
::
raw_pointer_cast
(
&
d_segment_offsets_
[
0
]);
float
**
output_ptrs
=
thrust
::
raw_pointer_cast
(
&
d_output_ptrs_
[
0
]);
PADDLE_ENFORCE
(
cudaMemcpyAsync
(
output_ptrs
,
outputs
,
this
->
getNbOutputs
()
*
sizeof
(
float
*
),
cudaMemcpyHostToDevice
,
stream
)
==
cudaSuccess
);
if
(
this
->
getDataType
()
==
nvinfer1
::
DataType
::
kFLOAT
)
{
Split
(
stream
,
same_shape_
,
outer_rows_
,
inner_cols_
,
segment_offsets_
,
d_segment_offsets_ptr
,
input_ptr
,
output_ptrs
);
}
else
{
Split
(
stream
,
same_shape_
,
outer_rows_
,
inner_cols_
,
segment_offsets_
,
d_segment_offsets_ptr
,
(
__half
*
)
input_ptr
,
// NOLINT
(
__half
**
)
output_ptrs
);
// NOLINT
}
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
...
...
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
浏览文件 @
3e3599f3
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <thrust/device_vector.h>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
...
...
@@ -25,7 +26,7 @@ namespace plugin {
class
SplitPlugin
:
public
PluginTensorRT
{
public:
SplitPlugin
(
int
axis
,
std
::
vector
<
int
>
const
&
output_lengths
)
:
axis_
(
axis
),
output_length_
(
output_lengths
)
{}
:
axis_
(
axis
),
same_shape_
(
true
),
output_length_
(
output_lengths
)
{}
SplitPlugin
(
void
const
*
serial_data
,
size_t
serial_length
)
{
deserializeBase
(
serial_data
,
serial_length
);
...
...
@@ -60,9 +61,13 @@ class SplitPlugin : public PluginTensorRT {
}
int
axis_
;
int
outer_rows_
;
int
inner_cols_
;
bool
same_shape_
;
std
::
vector
<
int
>
output_length_
;
int
nx_
,
ny_
,
nz_
;
std
::
vector
<
int
>
segment_offsets_
;
thrust
::
device_vector
<
int
>
d_segment_offsets_
;
thrust
::
device_vector
<
float
*>
d_output_ptrs_
;
};
}
// namespace plugin
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录