Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
1b74fded
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1b74fded
编写于
12月 23, 2019
作者:
W
Wilber
提交者:
GitHub
12月 23, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sequence_pool_concat fuse and kernel test=develop (#2645)
add sequence_pool_concat fuse pass add fuse kernel
上级
3723451b
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
692 addition
and
1 deletion
+692
-1
lite/api/paddle_use_passes.h
lite/api/paddle_use_passes.h
+1
-0
lite/core/mir/CMakeLists.txt
lite/core/mir/CMakeLists.txt
+1
-0
lite/core/mir/fusion/CMakeLists.txt
lite/core/mir/fusion/CMakeLists.txt
+4
-0
lite/core/mir/fusion/sequence_pool_concat_fuse_pass.cc
lite/core/mir/fusion/sequence_pool_concat_fuse_pass.cc
+36
-0
lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h
lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h
+32
-0
lite/core/mir/fusion/sequence_pool_concat_fuser.cc
lite/core/mir/fusion/sequence_pool_concat_fuser.cc
+153
-0
lite/core/mir/fusion/sequence_pool_concat_fuser.h
lite/core/mir/fusion/sequence_pool_concat_fuser.h
+38
-0
lite/core/optimizer.h
lite/core/optimizer.h
+1
-0
lite/kernels/cuda/CMakeLists.txt
lite/kernels/cuda/CMakeLists.txt
+1
-0
lite/kernels/cuda/sequence_pool_concat_compute.cu
lite/kernels/cuda/sequence_pool_concat_compute.cu
+266
-0
lite/kernels/cuda/sequence_pool_concat_compute.h
lite/kernels/cuda/sequence_pool_concat_compute.h
+44
-0
lite/operators/CMakeLists.txt
lite/operators/CMakeLists.txt
+2
-1
lite/operators/op_params.h
lite/operators/op_params.h
+6
-0
lite/operators/sequence_pool_concat_op.cc
lite/operators/sequence_pool_concat_op.cc
+64
-0
lite/operators/sequence_pool_concat_op.h
lite/operators/sequence_pool_concat_op.h
+43
-0
未找到文件。
lite/api/paddle_use_passes.h
浏览文件 @
1b74fded
...
...
@@ -31,6 +31,7 @@ USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS
(
lite_shuffle_channel_fuse_pass
);
USE_MIR_PASS
(
lite_transpose_softmax_transpose_fuse_pass
);
USE_MIR_PASS
(
lite_interpolate_fuse_pass
);
USE_MIR_PASS
(
lite_sequence_pool_concat_fuse_pass
);
USE_MIR_PASS
(
identity_scale_eliminate_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_fuse_pass
);
USE_MIR_PASS
(
lite_conv_activation_fuse_pass
);
...
...
lite/core/mir/CMakeLists.txt
浏览文件 @
1b74fded
...
...
@@ -20,6 +20,7 @@ lite_cc_library(mir_passes
fusion/conv_bn_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
static_kernel_pick_pass.cc
...
...
lite/core/mir/fusion/CMakeLists.txt
浏览文件 @
1b74fded
...
...
@@ -28,6 +28,9 @@ lite_cc_library(fuse_transpose_softmax_transpose
lite_cc_library
(
fuse_interpolate
SRCS interpolate_fuser.cc
DEPS pattern_matcher_high_api
)
lite_cc_library
(
fuse_sequence_pool_concat
SRCS sequence_pool_concat_fuser.cc
DEPS pattern_matcher_high_api
)
set
(
mir_fusers
fuse_fc
...
...
@@ -40,6 +43,7 @@ set(mir_fusers
fuse_elementwise_add_activation
fuse_transpose_softmax_transpose
fuse_interpolate
fuse_sequence_pool_concat
CACHE INTERNAL
"fusers"
)
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
...
...
lite/core/mir/fusion/sequence_pool_concat_fuse_pass.cc
0 → 100644
浏览文件 @
1b74fded
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/sequence_pool_concat_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
SequencePoolConcatFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
SequencePoolConcatFuser
fuser
;
fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_sequence_pool_concat_fuse_pass
,
paddle
::
lite
::
mir
::
SequencePoolConcatFusePass
)
.
BindTargets
({
TARGET
(
kCUDA
)});
lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h
0 → 100644
浏览文件 @
1b74fded
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
SequencePoolConcatFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
lite/core/mir/fusion/sequence_pool_concat_fuser.cc
0 → 100644
浏览文件 @
1b74fded
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_pool_concat_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
// """
// merge {sequence_pool x 7, concat} => merge_sequence_pool_and_concat
// src1 src2 src7 src1 src2 src7
// | | | | |
// v v | | ... |
// sequence_pool sequence_pool ...(sequence_pool) | | |
// | | | => -------------------
// --------------------------------- |
// | |
// v v
// concat sequence_pool_concat
// """
void
SequencePoolConcatFuser
::
BuildPattern
()
{
// create nodes.
auto
*
concat
=
OpNode
(
"concat"
,
"concat"
)
->
AsIntermediate
();
#define STR1(R) #R
#define STR2(R) STR1(R)
#define POOL_CONCAT_PATTERN(num) \
auto* x_##num = VarNode(STR2(sequence_pool_x_##num)) \
->assert_is_op_input("sequence_pool", "X") \
->AsInput(); \
auto* sequence_pool_##num = \
OpNode(STR2(sequence_pool_##num), "sequence_pool")->AsIntermediate(); \
auto* sequence_pool_##num##_out = \
VarNode(STR2(sequence_pool_##num##_out)) \
->assert_is_op_output("sequence_pool", "Out") \
->assert_is_op_nth_input("concat", "X", num - 1) \
->AsIntermediate(); \
auto* sequence_pool_##num##_idx = \
VarNode(STR2(sequence_pool_##num##_idx)) \
->assert_is_op_output("sequence_pool", "MaxIndex") \
->AsIntermediate(); \
*sequence_pool_##num >> *sequence_pool_##num##_idx; \
*x_##num >> *sequence_pool_##num >> *sequence_pool_##num##_out >> *concat;
auto
*
concat_out
=
VarNode
(
"concat_out"
)
->
assert_is_op_output
(
"concat"
,
"Out"
);
*
concat
>>
*
concat_out
;
POOL_CONCAT_PATTERN
(
1
);
POOL_CONCAT_PATTERN
(
2
);
POOL_CONCAT_PATTERN
(
3
);
POOL_CONCAT_PATTERN
(
4
);
POOL_CONCAT_PATTERN
(
5
);
POOL_CONCAT_PATTERN
(
6
);
POOL_CONCAT_PATTERN
(
7
);
#undef POOL_CONCAT_PATTERN
#undef STR1
#undef STR2
}
void
SequencePoolConcatFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
sequence_pool_concat_op
=
LiteOpRegistry
::
Global
().
Create
(
"sequence_pool_concat"
);
auto
concat
=
matched
.
at
(
"concat"
)
->
stmt
()
->
op
();
auto
*
scope
=
concat
->
scope
();
auto
&
valid_places
=
concat
->
valid_places
();
sequence_pool_concat_op
->
Attach
(
op_desc
,
scope
);
auto
*
new_op_node
=
graph
->
GraphCreateInstructNode
(
sequence_pool_concat_op
,
valid_places
);
IR_NODE_LINK_TO
(
matched
.
at
(
"sequence_pool_x_1"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"sequence_pool_x_2"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"sequence_pool_x_3"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"sequence_pool_x_4"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"sequence_pool_x_5"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"sequence_pool_x_6"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"sequence_pool_x_7"
),
new_op_node
);
IR_NODE_LINK_TO
(
new_op_node
,
matched
.
at
(
"concat_out"
));
}
cpp
::
OpDesc
SequencePoolConcatFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
cpp
::
OpDesc
op_desc
=
*
matched
.
at
(
"concat"
)
->
stmt
()
->
op_info
();
op_desc
.
SetType
(
"sequence_pool_concat"
);
op_desc
.
SetInput
(
"X"
,
{
matched
.
at
(
"sequence_pool_x_1"
)
->
arg
()
->
name
,
matched
.
at
(
"sequence_pool_x_2"
)
->
arg
()
->
name
,
matched
.
at
(
"sequence_pool_x_3"
)
->
arg
()
->
name
,
matched
.
at
(
"sequence_pool_x_4"
)
->
arg
()
->
name
,
matched
.
at
(
"sequence_pool_x_5"
)
->
arg
()
->
name
,
matched
.
at
(
"sequence_pool_x_6"
)
->
arg
()
->
name
,
matched
.
at
(
"sequence_pool_x_7"
)
->
arg
()
->
name
});
std
::
vector
<
std
::
string
>
pooltypes
;
pooltypes
.
push_back
(
matched
.
at
(
"sequence_pool_1"
)
->
stmt
()
->
op_info
()
->
GetAttr
<
std
::
string
>
(
"pooltype"
));
pooltypes
.
push_back
(
matched
.
at
(
"sequence_pool_2"
)
->
stmt
()
->
op_info
()
->
GetAttr
<
std
::
string
>
(
"pooltype"
));
pooltypes
.
push_back
(
matched
.
at
(
"sequence_pool_3"
)
->
stmt
()
->
op_info
()
->
GetAttr
<
std
::
string
>
(
"pooltype"
));
pooltypes
.
push_back
(
matched
.
at
(
"sequence_pool_4"
)
->
stmt
()
->
op_info
()
->
GetAttr
<
std
::
string
>
(
"pooltype"
));
pooltypes
.
push_back
(
matched
.
at
(
"sequence_pool_5"
)
->
stmt
()
->
op_info
()
->
GetAttr
<
std
::
string
>
(
"pooltype"
));
pooltypes
.
push_back
(
matched
.
at
(
"sequence_pool_6"
)
->
stmt
()
->
op_info
()
->
GetAttr
<
std
::
string
>
(
"pooltype"
));
pooltypes
.
push_back
(
matched
.
at
(
"sequence_pool_7"
)
->
stmt
()
->
op_info
()
->
GetAttr
<
std
::
string
>
(
"pooltype"
));
op_desc
.
SetAttr
(
"pooltype"
,
pooltypes
);
op_desc
.
SetOutput
(
"Out"
,
{
matched
.
at
(
"concat_out"
)
->
arg
()
->
name
});
return
op_desc
;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
lite/core/mir/fusion/sequence_pool_concat_fuser.h
0 → 100644
浏览文件 @
1b74fded
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
class
SequencePoolConcatFuser
:
public
FuseBase
{
public:
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
lite/core/optimizer.h
浏览文件 @
1b74fded
...
...
@@ -69,6 +69,7 @@ class Optimizer {
"lite_interpolate_fuse_pass"
,
//
"identity_scale_eliminate_pass"
,
//
"elementwise_mul_constant_eliminate_pass"
,
//
"lite_sequence_pool_concat_fuse_pass"
,
//
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
(defined LITE_WITH_ARM)
"lite_elementwise_add_activation_fuse_pass"
,
//
...
...
lite/kernels/cuda/CMakeLists.txt
浏览文件 @
1b74fded
...
...
@@ -11,6 +11,7 @@ add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${
add_kernel
(
relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_pool_concat_compute_cuda CUDA extra SRCS sequence_pool_concat_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS
${
lite_kernel_deps
}
${
math_cuda
}
cuda_transpose
)
add_kernel
(
nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
...
...
lite/kernels/cuda/sequence_pool_concat_compute.cu
0 → 100644
浏览文件 @
1b74fded
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_pool_concat_compute.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
Dtype
>
__global__
void
sequence_pool_concat
(
const
uint64_t
*
input_locate_data
,
const
int
*
pool_type_list
,
Dtype
*
output_data
,
const
int
*
offset
,
int
batch
,
int
in_num
,
int
in_dim
)
{
int
tid
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
int
em_id
=
tid
%
in_dim
;
int
in_id
=
(
tid
/
in_dim
)
%
in_num
;
int
seq_id
=
tid
/
(
in_dim
*
in_num
);
if
(
seq_id
>=
batch
)
{
return
;
}
Dtype
*
out_data
=
output_data
+
tid
;
int
offset_id
=
in_id
*
(
batch
+
1
)
+
seq_id
;
if
(
pool_type_list
[
in_id
]
==
4
)
{
// last
const
Dtype
*
in_data
=
reinterpret_cast
<
const
Dtype
*>
(
reinterpret_cast
<
uintptr_t
>
(
input_locate_data
[
in_id
]))
+
em_id
;
output_data
[
tid
]
=
in_data
[(
offset
[
offset_id
+
1
]
-
1
)
*
in_dim
];
}
else
if
(
pool_type_list
[
in_id
]
==
6
)
{
// max
const
Dtype
*
in_data
=
reinterpret_cast
<
const
Dtype
*>
(
reinterpret_cast
<
uintptr_t
>
(
input_locate_data
[
in_id
]))
+
em_id
+
offset
[
offset_id
]
*
in_dim
;
Dtype
max
=
in_data
[
0
];
for
(
int
i
=
1
;
i
<
offset
[
offset_id
+
1
]
-
offset
[
offset_id
];
i
++
)
{
Dtype
cur_data
=
in_data
[
i
*
in_dim
];
max
=
cur_data
>
max
?
cur_data
:
max
;
}
output_data
[
tid
]
=
max
;
}
else
{
return
;
}
}
template
<
typename
Dtype
>
__global__
void
sequence_pool_concat
(
const
uint64_t
*
input_locate_data
,
const
int
*
pool_type_list
,
Dtype
*
output_data
,
const
int
*
offset
,
int
batch
,
int
in_num
,
const
int
*
out_offset
,
const
int
*
out_id_seq_map_data
,
int
out_dim
)
{
int
tid
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
int
em_id
=
tid
%
out_dim
;
int
seq_id
=
tid
/
out_dim
;
int
in_id
=
out_id_seq_map_data
[
em_id
];
em_id
=
em_id
-
out_offset
[
in_id
];
int
in_dim
=
out_offset
[
in_id
+
1
]
-
out_offset
[
in_id
];
if
(
seq_id
>=
batch
)
{
return
;
}
Dtype
*
out_data
=
output_data
+
tid
;
int
offset_id
=
in_id
*
(
batch
+
1
)
+
seq_id
;
if
(
pool_type_list
[
in_id
]
==
4
)
{
// last
const
Dtype
*
in_data
=
reinterpret_cast
<
const
Dtype
*>
(
reinterpret_cast
<
uintptr_t
>
(
input_locate_data
[
in_id
]))
+
em_id
;
output_data
[
tid
]
=
in_data
[(
offset
[
offset_id
+
1
]
-
1
)
*
in_dim
];
}
else
if
(
pool_type_list
[
in_id
]
==
6
)
{
// max
const
Dtype
*
in_data
=
reinterpret_cast
<
const
Dtype
*>
(
reinterpret_cast
<
uintptr_t
>
(
input_locate_data
[
in_id
]))
+
em_id
+
offset
[
offset_id
]
*
in_dim
;
Dtype
max
=
in_data
[
0
];
for
(
int
i
=
1
;
i
<
offset
[
offset_id
+
1
]
-
offset
[
offset_id
];
i
++
)
{
Dtype
cur_data
=
in_data
[
i
*
in_dim
];
max
=
cur_data
>
max
?
cur_data
:
max
;
}
output_data
[
tid
]
=
max
;
}
else
{
return
;
}
}
void
SequencePoolConcatCompute
::
PrepareForRun
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
int
in_num
=
param
.
X
.
size
();
std
::
vector
<
int64_t
>
shape
({
in_num
,
1
,
1
,
1
});
_in_offset_tensor
.
Resize
(
shape
);
_in_ptr_tensor
.
Resize
(
shape
);
_in_pool_type_tensor
.
Resize
(
shape
);
int
*
in_pool_type_data
=
_in_pool_type_tensor
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
));
std
::
vector
<
int
>
pool_type_list
;
for
(
auto
type
:
param
.
pool_type
)
{
if
(
type
==
"AVERAGE"
)
{
pool_type_list
.
push_back
(
1
);
}
else
if
(
type
==
"SUM"
)
{
pool_type_list
.
push_back
(
2
);
}
else
if
(
type
==
"SQRT"
)
{
pool_type_list
.
push_back
(
3
);
}
else
if
(
type
==
"LAST"
)
{
pool_type_list
.
push_back
(
4
);
}
else
if
(
type
==
"FIRST"
)
{
pool_type_list
.
push_back
(
5
);
}
else
if
(
type
==
"MAX"
)
{
pool_type_list
.
push_back
(
6
);
}
else
{
LOG
(
ERROR
)
<<
"pool type "
<<
type
<<
" is not supoorted."
;
}
}
_is_in_same_len
=
true
;
int
in_len
=
param
.
X
[
0
]
->
dims
().
count
(
1
,
param
.
X
[
0
]
->
dims
().
size
());
std
::
vector
<
int
>
out_id_seq_map_list
;
std
::
vector
<
int
>
out_offset_list
;
int
total_len
=
0
;
out_offset_list
.
push_back
(
total_len
);
for
(
int
i
=
0
;
i
<
in_num
;
++
i
)
{
int
cur_len
=
param
.
X
[
i
]
->
dims
().
count
(
1
,
param
.
X
[
i
]
->
dims
().
size
());
_is_in_same_len
=
_is_in_same_len
&&
in_len
==
cur_len
;
for
(
int
k
=
0
;
k
<
cur_len
;
++
k
)
{
out_id_seq_map_list
.
push_back
(
i
);
}
total_len
+=
cur_len
;
out_offset_list
.
push_back
(
total_len
);
}
std
::
vector
<
int64_t
>
out_id_seq_map_shape
({
total_len
,
1
,
1
,
1
});
std
::
vector
<
int64_t
>
out_offset_shape
({
in_num
+
1
,
1
,
1
,
1
});
_out_offset_tensor
.
Resize
(
out_offset_shape
);
_out_id_seq_map_tensor
.
Resize
(
out_id_seq_map_shape
);
int
*
out_offset_data
=
_out_offset_tensor
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
));
int
*
out_id_seq_map_data
=
_out_id_seq_map_tensor
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemcpyAsync
(
in_pool_type_data
,
&
pool_type_list
[
0
],
sizeof
(
int
)
*
param
.
X
.
size
(),
IoDirection
::
HtoD
,
stream
);
TargetWrapperCuda
::
MemcpyAsync
(
out_offset_data
,
&
out_offset_list
[
0
],
sizeof
(
int
)
*
out_offset_list
.
size
(),
IoDirection
::
HtoD
,
stream
);
TargetWrapperCuda
::
MemcpyAsync
(
out_id_seq_map_data
,
&
out_id_seq_map_list
[
0
],
sizeof
(
int
)
*
out_id_seq_map_list
.
size
(),
IoDirection
::
HtoD
,
stream
);
cudaStreamSynchronize
(
stream
);
}
void
SequencePoolConcatCompute
::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
auto
&
inputs
=
param
.
X
;
auto
offset
=
inputs
[
0
]
->
lod
()[
0
];
int
batch
=
offset
.
size
()
-
1
;
CHECK_GE
(
offset
.
size
(),
1
);
std
::
vector
<
int
>
all_offset
;
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
it
=
all_offset
.
end
();
auto
cur_offset
=
inputs
[
i
]
->
lod
()[
0
];
all_offset
.
insert
(
it
,
cur_offset
.
begin
(),
cur_offset
.
end
());
}
int
total_size
=
all_offset
.
size
();
std
::
vector
<
int64_t
>
offset_shape
({
total_size
,
1
,
1
,
1
});
_in_offset_tensor
.
Resize
(
offset_shape
);
int
*
offset_data
=
_in_offset_tensor
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemcpyAsync
(
offset_data
,
&
all_offset
[
0
],
sizeof
(
int
)
*
all_offset
.
size
(),
IoDirection
::
HtoD
,
stream
);
std
::
vector
<
uint64_t
>
in_locate_vec
;
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
in_locate_vec
.
push_back
(
reinterpret_cast
<
uintptr_t
>
(
inputs
[
i
]
->
data
<
float
>
()));
}
uint64_t
*
in_locate_data
=
_in_ptr_tensor
.
mutable_data
<
uint64_t
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemcpyAsync
(
in_locate_data
,
&
in_locate_vec
[
0
],
sizeof
(
uint64_t
)
*
inputs
.
size
(),
IoDirection
::
HtoD
,
stream
);
const
int
*
in_pool_type_data
=
_in_pool_type_tensor
.
data
<
int
>
();
const
int
*
out_id_seq_map_data
=
_out_id_seq_map_tensor
.
data
<
int
>
();
const
int
*
out_offset_data
=
_out_offset_tensor
.
data
<
int
>
();
int
count
=
param
.
Out
->
numel
();
int
in_dim
=
inputs
[
0
]
->
numel
()
/
inputs
[
0
]
->
dims
()[
0
];
float
*
out_data
=
param
.
Out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
int
in_num
=
inputs
.
size
();
if
(
_is_in_same_len
)
{
sequence_pool_concat
<
float
><<<
CUDA_GET_BLOCKS
(
count
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_locate_data
,
in_pool_type_data
,
out_data
,
offset_data
,
batch
,
in_num
,
in_dim
);
}
else
{
int
out_dim
=
param
.
Out
->
numel
()
/
param
.
Out
->
dims
()[
0
];
sequence_pool_concat
<
float
><<<
CUDA_GET_BLOCKS
(
count
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_locate_data
,
in_pool_type_data
,
out_data
,
offset_data
,
batch
,
in_num
,
out_offset_data
,
out_id_seq_map_data
,
out_dim
);
}
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
sequence_pool_concat
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
SequencePoolConcatCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
lite/kernels/cuda/sequence_pool_concat_compute.h
0 → 100644
浏览文件 @
1b74fded
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
class
SequencePoolConcatCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
SequencePoolConcatParam
;
void
Run
()
override
;
void
PrepareForRun
()
override
;
virtual
~
SequencePoolConcatCompute
()
=
default
;
private:
lite
::
Tensor
_in_offset_tensor
;
lite
::
Tensor
_in_ptr_tensor
;
lite
::
Tensor
_in_pool_type_tensor
;
lite
::
Tensor
_out_offset_tensor
;
lite
::
Tensor
_out_id_seq_map_tensor
;
bool
_is_in_same_len
;
};
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/operators/CMakeLists.txt
浏览文件 @
1b74fded
...
...
@@ -90,6 +90,8 @@ add_operator(merge_lod_tensor_op_lite extra SRCS merge_lod_tensor_op.cc DEPS ${o
add_operator
(
reduce_prod_op_lite extra SRCS reduce_prod_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_pool extra SRCS sequence_pool_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_pool_concat extra SRCS sequence_pool_concat_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS
${
op_DEPS
}
)
add_operator
(
match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS
${
op_DEPS
}
)
...
...
@@ -120,7 +122,6 @@ add_operator(greater_than extra SRCS compare_op.cc DEPS ${op_DEPS})
add_operator
(
greater_equal extra SRCS compare_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
read_from_array_op extra SRCS read_from_array_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
beam_search_op extra SRCS beam_search_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_pool extra SRCS sequence_pool_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
lod_reset_op extra SRCS lod_reset_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
is_empty extra SRCS is_empty_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
slice_op_lite basic SRCS slice_op.cc DEPS
${
op_DEPS
}
)
...
...
lite/operators/op_params.h
浏览文件 @
1b74fded
...
...
@@ -769,6 +769,12 @@ struct SequencePoolParam {
#endif
};
struct
SequencePoolConcatParam
{
std
::
vector
<
lite
::
Tensor
*>
X
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
std
::
string
>
pool_type
{};
};
struct
SearchGroupPaddingParam
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
out_emb_padding
{};
...
...
lite/operators/sequence_pool_concat_op.cc
0 → 100644
浏览文件 @
1b74fded
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_pool_concat_op.h"
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
bool
SequencePoolConcatOp
::
CheckShape
()
const
{
CHECK_GE
(
param_
.
X
.
size
(),
1
)
<<
"The number of input sequences is at least two."
;
CHECK_OR_FALSE
(
param_
.
Out
);
return
true
;
}
bool
SequencePoolConcatOp
::
InferShape
()
const
{
int
out_dim
=
0
;
for
(
int
i
=
0
;
i
<
param_
.
X
.
size
();
++
i
)
{
out_dim
+=
param_
.
X
[
i
]
->
dims
().
count
(
1
,
param_
.
X
[
i
]
->
dims
().
size
());
}
int
seq_num
=
param_
.
X
[
0
]
->
lod
()[
0
].
size
()
-
1
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
lod
(
1
);
for
(
int
i
=
0
;
i
<
seq_num
+
1
;
++
i
)
{
lod
[
0
].
push_back
(
i
);
}
param_
.
Out
->
set_lod
(
lod
);
param_
.
Out
->
Resize
({
seq_num
,
out_dim
});
return
true
;
}
bool
SequencePoolConcatOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
input_list
=
opdesc
.
Input
(
"X"
);
param_
.
X
.
clear
();
for
(
auto
var
:
input_list
)
{
param_
.
X
.
push_back
(
scope
->
FindVar
(
var
)
->
GetMutable
<
lite
::
Tensor
>
());
}
param_
.
Out
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
CHECK
(
param_
.
Out
)
<<
"Output(Out) of Sequence Concat Op should not be null."
;
param_
.
pool_type
=
opdesc
.
GetAttr
<
std
::
vector
<
std
::
string
>>
(
"pooltype"
);
return
true
;
}
}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
sequence_pool_concat
,
paddle
::
lite
::
operators
::
SequencePoolConcatOp
);
lite/operators/sequence_pool_concat_op.h
0 → 100644
浏览文件 @
1b74fded
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
SequencePoolConcatOp
:
public
OpLite
{
public:
SequencePoolConcatOp
()
{}
explicit
SequencePoolConcatOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"sequence_pool_concat"
;
}
private:
mutable
SequencePoolConcatParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录