Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8809d43a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8809d43a
编写于
1月 18, 2018
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove unnecessary dtype conversion & register int64 kernels
上级
7a2aa486
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
26 addition
and
35 deletion
+26
-35
paddle/operators/sequence_erase_op.cc
paddle/operators/sequence_erase_op.cc
+2
-1
paddle/operators/sequence_erase_op.cu
paddle/operators/sequence_erase_op.cu
+8
-33
python/paddle/v2/fluid/tests/test_sequence_erase_op.py
python/paddle/v2/fluid/tests/test_sequence_erase_op.py
+16
-1
未找到文件。
paddle/operators/sequence_erase_op.cc
浏览文件 @
8809d43a
...
...
@@ -86,4 +86,5 @@ REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp,
ops
::
SequenceEraseOpMaker
);
REGISTER_OP_CPU_KERNEL
(
sequence_erase
,
ops
::
SequenceEraseKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int32_t
>
);
ops
::
SequenceEraseKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int32_t
>
,
ops
::
SequenceEraseKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/operators/sequence_erase_op.cu
浏览文件 @
8809d43a
...
...
@@ -28,16 +28,12 @@ __global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
size_t
*
num_erased
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
in_len
)
{
int
erased
=
0
;
for
(
size_t
i
=
0
;
i
<
tokens_len
;
++
i
)
{
if
(
in_dat
[
index
]
==
tokens
[
i
])
{
erased
=
1
;
num_erased
[
index
+
1
]
=
1
;
break
;
}
}
num_erased
[
index
+
1
]
=
erased
;
if
(
index
==
0
)
{
num_erased
[
0
]
=
0
;
}
}
}
...
...
@@ -60,26 +56,6 @@ __global__ void SetOutput(const T* in_dat, const int64_t in_len,
}
}
template
<
typename
T
,
typename
Vector
>
thrust
::
device_vector
<
T
>
set_device_vector
(
Vector
&
vector
)
{
thrust
::
host_vector
<
T
>
host_vec
(
vector
.
size
());
for
(
size_t
i
=
0
;
i
<
vector
.
size
();
++
i
)
{
host_vec
[
i
]
=
vector
[
i
];
}
thrust
::
device_vector
<
T
>
dev_vec
=
host_vec
;
return
dev_vec
;
}
template
<
typename
T
>
std
::
vector
<
T
>
get_std_vector
(
thrust
::
device_vector
<
T
>&
dev_vec
)
{
thrust
::
host_vector
<
T
>
host_vec
=
dev_vec
;
std
::
vector
<
T
>
std_vec
(
host_vec
.
size
(),
0
);
for
(
size_t
i
=
0
;
i
<
host_vec
.
size
();
++
i
)
{
std_vec
[
i
]
=
host_vec
[
i
];
}
return
std_vec
;
}
template
<
typename
T
>
class
SequenceEraseOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -95,12 +71,11 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
auto
in_len
=
in
->
numel
();
auto
in_dat
=
in
->
data
<
T
>
();
// Copy tokens to GPU
thrust
::
device_vector
<
int
>
dev_tokens
=
set_device_vector
<
int
,
std
::
vector
<
int
>>
(
tokens
);
thrust
::
device_vector
<
int
>
dev_tokens
(
tokens
.
begin
(),
tokens
.
end
());
int
*
dev_tokens_ptr
=
thrust
::
raw_pointer_cast
(
dev_tokens
.
data
());
// Count number of elements to be erased
thrust
::
device_vector
<
size_t
>
num_erased
(
in_len
+
1
);
thrust
::
device_vector
<
size_t
>
num_erased
(
in_len
+
1
,
0
);
size_t
*
num_erased_ptr
=
thrust
::
raw_pointer_cast
(
num_erased
.
data
());
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
LabelErasedIdx
<<<
(
in_len
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
,
...
...
@@ -112,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
// Copy LoD to GPU
auto
lod0
=
lod
[
0
];
auto
lod_len
=
lod0
.
size
();
thrust
::
device_vector
<
size_t
>
dev_in_lod
=
set_device_vector
<
size_t
,
paddle
::
framework
::
Vector
<
size_t
>>
(
lod0
);
thrust
::
device_vector
<
size_t
>
dev_in_lod
=
lod0
;
size_t
*
dev_in_lod_ptr
=
thrust
::
raw_pointer_cast
(
dev_in_lod
.
data
());
// Calc output LoD
...
...
@@ -124,7 +98,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
num_erased_ptr
,
dev_in_lod_ptr
,
lod_len
,
dev_out_lod_ptr
);
// Set LoD for output
std
::
vector
<
size_t
>
out_lod0
=
get_std_vector
<
size_t
>
(
dev_out_lod
)
;
thrust
::
host_vector
<
size_t
>
out_lod0
=
dev_out_lod
;
framework
::
LoD
out_lod
;
out_lod
.
push_back
(
out_lod0
);
out
->
set_lod
(
out_lod
);
...
...
@@ -142,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
sequence_erase
,
paddle
::
operators
::
SequenceEraseOpCUDAKernel
<
int32_t
>
);
paddle
::
operators
::
SequenceEraseOpCUDAKernel
<
int32_t
>
,
paddle
::
operators
::
SequenceEraseOpCUDAKernel
<
int64_t
>
);
python/paddle/v2/fluid/tests/test_sequence_erase_op.py
浏览文件 @
8809d43a
...
...
@@ -29,7 +29,7 @@ def sequence_erase(in_seq, lod0, tokens):
return
np
.
array
(
out_seq
).
astype
(
"int32"
),
new_lod0
class
TestSequenceEraseOp
(
OpTest
):
class
TestSequenceEraseOp
Int32
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"sequence_erase"
in_seq
=
np
.
random
.
randint
(
0
,
10
,
(
30
,
1
)).
astype
(
"int32"
)
...
...
@@ -44,6 +44,21 @@ class TestSequenceEraseOp(OpTest):
self
.
check_output
()
class
TestSequenceEraseOpInt64
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"sequence_erase"
in_seq
=
np
.
random
.
randint
(
0
,
10
,
(
30
,
1
)).
astype
(
"int64"
)
lod
=
[[
0
,
9
,
13
,
24
,
30
]]
tokens
=
[
2
,
3
,
5
]
out_seq
,
new_lod0
=
sequence_erase
(
in_seq
,
lod
[
0
],
tokens
)
self
.
attrs
=
{
'tokens'
:
tokens
}
self
.
inputs
=
{
'X'
:
(
in_seq
,
lod
)}
self
.
outputs
=
{
'Out'
:
(
out_seq
,
[
new_lod0
])}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestSequenceEraseOpEmpty
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"sequence_erase"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录