Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
a1c281f0
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a1c281f0
编写于
1月 19, 2018
作者:
Y
Yibing Liu
提交者:
GitHub
1月 19, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7603 from kuke/simplify_erase
Enhance GPU kernel of sequence erase op
上级
41b83884
8809d43a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
66 addition
and
48 deletion
+66
-48
paddle/operators/sequence_erase_op.cc
paddle/operators/sequence_erase_op.cc
+2
-1
paddle/operators/sequence_erase_op.cu
paddle/operators/sequence_erase_op.cu
+33
-46
python/paddle/v2/fluid/tests/test_sequence_erase_op.py
python/paddle/v2/fluid/tests/test_sequence_erase_op.py
+31
-1
未找到文件。
paddle/operators/sequence_erase_op.cc
浏览文件 @
a1c281f0
...
...
@@ -86,4 +86,5 @@ REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp,
ops
::
SequenceEraseOpMaker
);
REGISTER_OP_CPU_KERNEL
(
sequence_erase
,
ops
::
SequenceEraseKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int32_t
>
);
ops
::
SequenceEraseKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int32_t
>
,
ops
::
SequenceEraseKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/operators/sequence_erase_op.cu
浏览文件 @
a1c281f0
...
...
@@ -23,27 +23,22 @@ using platform::PADDLE_CUDA_NUM_THREADS;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
__global__
void
LabelErasedIdx
(
const
T
*
in_dat
,
const
int
in_len
,
const
T
*
tokens
,
const
in
t
tokens_len
,
in
t
*
num_erased
)
{
__global__
void
LabelErasedIdx
(
const
T
*
in_dat
,
const
int
64_t
in_len
,
const
int
*
tokens
,
const
size_
t
tokens_len
,
size_
t
*
num_erased
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
in_len
)
{
int
erased
=
0
;
for
(
int
i
=
0
;
i
<
tokens_len
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tokens_len
;
++
i
)
{
if
(
in_dat
[
index
]
==
tokens
[
i
])
{
erased
=
1
;
num_erased
[
index
+
1
]
=
1
;
break
;
}
}
num_erased
[
index
+
1
]
=
erased
;
if
(
index
==
0
)
{
num_erased
[
0
]
=
0
;
}
}
}
template
<
typename
T
>
__global__
void
GetOutLod
(
const
T
*
num_erased
,
const
int
*
in_lod
,
const
int
lod_len
,
int
*
out_lod0
)
{
__global__
void
GetOutLod
(
const
size_t
*
num_erased
,
const
size_t
*
in_lod
,
const
size_t
lod_len
,
size_t
*
out_lod0
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
lod_len
)
{
out_lod0
[
index
]
=
in_lod
[
index
]
-
num_erased
[
in_lod
[
index
]];
...
...
@@ -51,11 +46,11 @@ __global__ void GetOutLod(const T* num_erased, const int* in_lod,
}
template
<
typename
T
>
__global__
void
SetOutput
(
const
T
*
in_dat
,
const
int
in_len
,
const
in
t
*
num_erased
,
T
*
out_dat
)
{
__global__
void
SetOutput
(
const
T
*
in_dat
,
const
int
64_t
in_len
,
const
size_
t
*
num_erased
,
T
*
out_dat
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
in_len
)
{
if
(
in_dat
[
index
]
!=
in_dat
[
index
+
1
])
{
if
(
num_erased
[
index
]
==
num_erased
[
index
+
1
])
{
out_dat
[
index
-
num_erased
[
index
]]
=
in_dat
[
index
];
}
}
...
...
@@ -72,53 +67,44 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ
(
lod
.
size
(),
1UL
,
"Only support one level sequence now."
);
PADDLE_ENFORCE_EQ
(
lod
[
0
].
back
(),
(
size_t
)
in
->
numel
(),
"The actual size mismatches with the LoD information."
);
auto
tokens
=
ctx
.
Attr
<
std
::
vector
<
T
>>
(
"tokens"
);
auto
tokens_len
=
tokens
.
size
();
auto
tokens
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"tokens"
);
auto
in_len
=
in
->
numel
();
auto
in_dat
=
in
->
data
<
T
>
();
auto
lod0
=
lod
[
0
];
thrust
::
host_vector
<
T
>
host_tokens
(
tokens_len
);
for
(
size_t
i
=
0
;
i
<
tokens
.
size
();
++
i
)
{
host_tokens
[
i
]
=
tokens
[
i
];
}
thrust
::
device_vector
<
T
>
dev_tokens
=
host_tokens
;
thrust
::
device_vector
<
int
>
num_erased
(
in_len
+
1
);
T
*
dev_tokens_ptr
=
thrust
::
raw_pointer_cast
(
dev_tokens
.
data
());
int
*
num_erased_ptr
=
thrust
::
raw_pointer_cast
(
num_erased
.
data
());
// Copy tokens to GPU
thrust
::
device_vector
<
int
>
dev_tokens
(
tokens
.
begin
(),
tokens
.
end
());
int
*
dev_tokens_ptr
=
thrust
::
raw_pointer_cast
(
dev_tokens
.
data
());
// Count number of elements to be erased
thrust
::
device_vector
<
size_t
>
num_erased
(
in_len
+
1
,
0
);
size_t
*
num_erased_ptr
=
thrust
::
raw_pointer_cast
(
num_erased
.
data
());
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
LabelErasedIdx
<<<
(
in_len
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_dat
,
in_len
,
dev_tokens_ptr
,
tokens
_len
,
num_erased_ptr
);
in_dat
,
in_len
,
dev_tokens_ptr
,
tokens
.
size
()
,
num_erased_ptr
);
thrust
::
inclusive_scan
(
num_erased
.
begin
()
+
1
,
num_erased
.
end
(),
num_erased
.
begin
()
+
1
);
// Calc LoD
// Copy LoD to GPU
auto
lod0
=
lod
[
0
];
auto
lod_len
=
lod0
.
size
();
thrust
::
host_vector
<
int
>
host_lod
(
lod_len
);
for
(
size_t
i
=
0
;
i
<
lod_len
;
++
i
)
{
host_lod
[
i
]
=
lod0
[
i
];
}
thrust
::
device_vector
<
int
>
dev_in_lod
=
host_lod
;
thrust
::
device_vector
<
int
>
dev_out_lod
(
lod_len
);
int
*
dev_in_lod_ptr
=
thrust
::
raw_pointer_cast
(
dev_in_lod
.
data
());
int
*
dev_out_lod_ptr
=
thrust
::
raw_pointer_cast
(
dev_out_lod
.
data
());
thrust
::
device_vector
<
size_t
>
dev_in_lod
=
lod0
;
size_t
*
dev_in_lod_ptr
=
thrust
::
raw_pointer_cast
(
dev_in_lod
.
data
());
// Calc output LoD
thrust
::
device_vector
<
size_t
>
dev_out_lod
(
lod_len
);
size_t
*
dev_out_lod_ptr
=
thrust
::
raw_pointer_cast
(
dev_out_lod
.
data
());
GetOutLod
<<<
(
lod_len
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
num_erased_ptr
,
dev_in_lod_ptr
,
lod_len
,
dev_out_lod_ptr
);
thrust
::
host_vector
<
int
>
host_out_lod
=
dev_out_lod
;
std
::
vector
<
int
>
out_lod0
(
lod_len
,
0
);
for
(
size_t
i
=
0
;
i
<
lod_len
;
i
++
)
{
out_lod0
[
i
]
=
host_out_lod
[
i
];
}
// Set LoD for output
thrust
::
host_vector
<
size_t
>
out_lod0
=
dev_out_lod
;
framework
::
LoD
out_lod
;
out_lod
.
push_back
(
out_lod0
);
out
->
set_lod
(
out_lod
);
// Set output
out
->
Resize
({
out_lod0
.
back
(
),
1
});
out
->
Resize
({
static_cast
<
int64_t
>
(
out_lod0
.
back
()
),
1
});
auto
out_dat
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
SetOutput
<<<
(
in_len
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_dat
,
in_len
,
...
...
@@ -130,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
sequence_erase
,
paddle
::
operators
::
SequenceEraseOpCUDAKernel
<
int32_t
>
);
paddle
::
operators
::
SequenceEraseOpCUDAKernel
<
int32_t
>
,
paddle
::
operators
::
SequenceEraseOpCUDAKernel
<
int64_t
>
);
python/paddle/v2/fluid/tests/test_sequence_erase_op.py
浏览文件 @
a1c281f0
...
...
@@ -29,7 +29,7 @@ def sequence_erase(in_seq, lod0, tokens):
return
np
.
array
(
out_seq
).
astype
(
"int32"
),
new_lod0
class
TestSequenceEraseOp
(
OpTest
):
class
TestSequenceEraseOp
Int32
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"sequence_erase"
in_seq
=
np
.
random
.
randint
(
0
,
10
,
(
30
,
1
)).
astype
(
"int32"
)
...
...
@@ -44,5 +44,35 @@ class TestSequenceEraseOp(OpTest):
self
.
check_output
()
class
TestSequenceEraseOpInt64
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"sequence_erase"
in_seq
=
np
.
random
.
randint
(
0
,
10
,
(
30
,
1
)).
astype
(
"int64"
)
lod
=
[[
0
,
9
,
13
,
24
,
30
]]
tokens
=
[
2
,
3
,
5
]
out_seq
,
new_lod0
=
sequence_erase
(
in_seq
,
lod
[
0
],
tokens
)
self
.
attrs
=
{
'tokens'
:
tokens
}
self
.
inputs
=
{
'X'
:
(
in_seq
,
lod
)}
self
.
outputs
=
{
'Out'
:
(
out_seq
,
[
new_lod0
])}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestSequenceEraseOpEmpty
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"sequence_erase"
in_seq
=
np
.
random
.
randint
(
0
,
10
,
(
30
,
1
)).
astype
(
"int32"
)
lod
=
[[
0
,
9
,
13
,
24
,
30
]]
tokens
=
[]
out_seq
,
new_lod0
=
sequence_erase
(
in_seq
,
lod
[
0
],
tokens
)
self
.
attrs
=
{
'tokens'
:
tokens
}
self
.
inputs
=
{
'X'
:
(
in_seq
,
lod
)}
self
.
outputs
=
{
'Out'
:
(
out_seq
,
[
new_lod0
])}
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录