Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
37f933b8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
37f933b8
编写于
1月 08, 2018
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add gpu kernel for sequence_erase_op
上级
2a54ddd2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
138 addition
and
3 deletion
+138
-3
paddle/operators/sequence_erase_op.cu
paddle/operators/sequence_erase_op.cu
+136
-0
paddle/operators/sequence_erase_op.h
paddle/operators/sequence_erase_op.h
+2
-2
python/paddle/v2/fluid/tests/test_sequence_erase_op.py
python/paddle/v2/fluid/tests/test_sequence_erase_op.py
+0
-1
未找到文件。
paddle/operators/sequence_erase_op.cu
0 → 100644
浏览文件 @
37f933b8
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include "paddle/operators/sequence_erase_op.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
__global__
void
LabelErasedIdx
(
const
T
*
in_dat
,
const
int
in_len
,
const
T
*
tokens
,
const
int
tokens_len
,
int
*
num_erased
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
in_len
)
{
int
erased
=
0
;
for
(
int
i
=
0
;
i
<
tokens_len
;
++
i
)
{
if
(
in_dat
[
index
]
==
tokens
[
i
])
{
erased
=
1
;
}
}
num_erased
[
index
+
1
]
=
erased
;
if
(
index
==
0
)
{
num_erased
[
0
]
=
0
;
}
}
}
template
<
typename
T
>
__global__
void
GetOutLod
(
const
T
*
num_erased
,
const
int
*
in_lod
,
const
int
lod_len
,
int
*
out_lod0
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
lod_len
)
{
out_lod0
[
index
]
=
in_lod
[
index
]
-
num_erased
[
in_lod
[
index
]];
}
}
template
<
typename
T
>
__global__
void
SetOutput
(
const
T
*
in_dat
,
const
int
in_len
,
const
int
*
num_erased
,
T
*
out_dat
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
in_len
)
{
if
(
in_dat
[
index
]
!=
in_dat
[
index
+
1
])
{
out_dat
[
index
-
num_erased
[
index
]]
=
in_dat
[
index
];
}
}
}
template
<
typename
T
>
class
SequenceEraseOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
auto
lod
=
in
->
lod
();
PADDLE_ENFORCE_EQ
(
lod
.
size
(),
1UL
,
"Only support one level sequence now."
);
auto
tokens
=
ctx
.
Attr
<
std
::
vector
<
T
>>
(
"tokens"
);
auto
tokens_len
=
tokens
.
size
();
auto
in_len
=
in
->
numel
();
auto
in_dat
=
in
->
data
<
T
>
();
auto
lod0
=
lod
[
0
];
thrust
::
host_vector
<
T
>
host_tokens
(
tokens_len
);
for
(
size_t
i
=
0
;
i
<
tokens
.
size
();
++
i
)
{
host_tokens
[
i
]
=
tokens
[
i
];
}
thrust
::
device_vector
<
T
>
dev_tokens
=
host_tokens
;
thrust
::
device_vector
<
int
>
num_erased
(
in_len
+
1
);
T
*
dev_tokens_ptr
=
thrust
::
raw_pointer_cast
(
dev_tokens
.
data
());
int
*
num_erased_ptr
=
thrust
::
raw_pointer_cast
(
num_erased
.
data
());
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
LabelErasedIdx
<<<
(
in_len
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_dat
,
in_len
,
dev_tokens_ptr
,
tokens_len
,
num_erased_ptr
);
thrust
::
inclusive_scan
(
num_erased
.
begin
()
+
1
,
num_erased
.
end
(),
num_erased
.
begin
()
+
1
);
// Reset LoD
auto
lod_len
=
lod0
.
size
();
thrust
::
host_vector
<
int
>
host_lod
(
lod_len
);
for
(
size_t
i
=
0
;
i
<
lod_len
;
++
i
)
{
host_lod
[
i
]
=
lod0
[
i
];
}
thrust
::
device_vector
<
int
>
dev_in_lod
=
host_lod
;
thrust
::
device_vector
<
int
>
dev_out_lod
(
lod_len
);
int
*
dev_in_lod_ptr
=
thrust
::
raw_pointer_cast
(
dev_in_lod
.
data
());
int
*
dev_out_lod_ptr
=
thrust
::
raw_pointer_cast
(
dev_out_lod
.
data
());
GetOutLod
<<<
(
lod_len
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
num_erased_ptr
,
dev_in_lod_ptr
,
lod_len
,
dev_out_lod_ptr
);
thrust
::
host_vector
<
int
>
host_out_lod
=
dev_out_lod
;
std
::
vector
<
int
>
out_lod0
(
lod_len
,
0
);
for
(
size_t
i
=
0
;
i
<
lod_len
;
i
++
)
{
out_lod0
[
i
]
=
host_out_lod
[
i
];
}
framework
::
LoD
out_lod
;
out_lod
.
push_back
(
out_lod0
);
out
->
Resize
({
out_lod0
.
back
(),
1
});
// Set output
auto
out_dat
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
SetOutput
<<<
(
in_len
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_dat
,
in_len
,
num_erased_ptr
,
out_dat
);
// Set LoD
out
->
set_lod
(
out_lod
);
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
sequence_erase
,
paddle
::
operators
::
SequenceEraseOpCUDAKernel
<
int32_t
>
);
paddle/operators/sequence_erase_op.h
浏览文件 @
37f933b8
...
...
@@ -27,8 +27,8 @@ template <typename DeviceContext, typename T>
class
SequenceEraseKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
lod
=
in
->
lod
();
PADDLE_ENFORCE_EQ
(
lod
.
size
(),
1UL
,
"Only support one level sequence now."
);
...
...
python/paddle/v2/fluid/tests/test_sequence_erase_op.py
浏览文件 @
37f933b8
...
...
@@ -32,7 +32,6 @@ class TestSequenceEraseOp(OpTest):
lod
=
[[
0
,
5
,
15
,
30
]]
tokens
=
[
2
,
5
]
out_seq
,
new_lod0
=
sequence_erase
(
in_seq
,
lod
[
0
],
tokens
)
self
.
attrs
=
{
'tokens'
:
tokens
}
self
.
inputs
=
{
'X'
:
(
in_seq
,
lod
)}
self
.
outputs
=
{
'Out'
:
(
out_seq
,
[
new_lod0
])}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录