Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
fbdb5b7b
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fbdb5b7b
编写于
3月 29, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"fix based on comment"
上级
a80bf702
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
32 addition
and
36 deletion
+32
-36
paddle/fluid/operators/sequence_expand_op.cu
paddle/fluid/operators/sequence_expand_op.cu
+32
-36
未找到文件。
paddle/fluid/operators/sequence_expand_op.cu
浏览文件 @
fbdb5b7b
...
...
@@ -25,27 +25,17 @@ using LoDTensor = framework::LoDTensor;
template
<
typename
T
>
__global__
void
sequence_expand_kernel
(
const
T
*
x_data
,
const
size_t
*
x_lod
,
const
size_t
*
ref_lod
,
const
size_t
*
offset
,
const
size_t
lod_size
,
/* default=1,
the instance length*/
const
int
x_item_length
,
T
*
out_data
)
{
constexpr
int
N
=
1024
;
__shared__
int
mem
[
N
];
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
lod_size
;
++
i
)
{
mem
[
i
]
=
offset
;
if
(
i
<
lod_size
-
1
)
{
offset
+=
(
ref_lod
[
i
+
1
]
-
ref_lod
[
i
])
*
(
x_lod
[
i
+
1
]
-
x_lod
[
i
]);
}
}
__syncthreads
();
int
bid
=
blockIdx
.
x
;
if
(
bid
>=
lod_size
-
1
)
return
;
int
x_item_count
=
x_lod
[
bid
+
1
]
-
x_lod
[
bid
];
int
repeats
=
ref_lod
[
bid
+
1
]
-
ref_lod
[
bid
];
int
out_offset
=
mem
[
bid
]
;
int
out_offset
=
static_cast
<
int
>
(
offset
[
bid
])
;
int
x_offset
=
x_lod
[
bid
];
for
(
int
tid_z
=
threadIdx
.
z
;
tid_z
<
repeats
;
tid_z
+=
blockDim
.
z
)
{
for
(
int
tid_y
=
threadIdx
.
y
;
tid_y
<
x_item_count
;
tid_y
+=
blockDim
.
y
)
{
...
...
@@ -59,32 +49,17 @@ __global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
}
template
<
typename
T
>
__global__
void
sequence_expand_grad_kernel
(
const
T
*
dout_data
,
const
size_t
*
ref_lod
,
const
size_t
*
dx_lod
,
const
size_t
lod_size
,
__global__
void
sequence_expand_grad_kernel
(
const
T
*
dout_data
,
const
size_t
*
ref_lod
,
const
size_t
*
dx_lod
,
const
size_t
*
offset
,
const
size_t
lod_size
,
/* default=1,
the instance length*/
const
int
x_item_length
,
T
*
dx_data
)
{
// TODO(dzhwinter) : too many atomicAdd
// use shared memory to reduce memory visits
constexpr
int
N
=
1024
;
__shared__
int
mem
[
N
];
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
lod_size
;
++
i
)
{
mem
[
i
]
=
offset
;
if
(
i
<
lod_size
-
1
)
{
offset
+=
(
ref_lod
[
i
+
1
]
-
ref_lod
[
i
])
*
(
dx_lod
[
i
+
1
]
-
dx_lod
[
i
]);
}
}
__syncthreads
();
const
int
x_item_length
,
T
*
dx_data
)
{
int
bid
=
blockIdx
.
x
;
if
(
bid
>=
lod_size
-
1
)
return
;
int
x_item_count
=
dx_lod
[
bid
+
1
]
-
dx_lod
[
bid
];
int
repeats
=
ref_lod
[
bid
+
1
]
-
ref_lod
[
bid
];
int
out_offset
=
mem
[
bid
]
;
int
out_offset
=
static_cast
<
int
>
(
offset
[
bid
])
;
int
x_offset
=
dx_lod
[
bid
];
for
(
int
tid_z
=
threadIdx
.
z
;
tid_z
<
repeats
;
tid_z
+=
blockDim
.
z
)
{
...
...
@@ -101,6 +76,19 @@ __global__ void sequence_expand_grad_kernel(const T* dout_data,
}
}
void
GetOutputOffset
(
const
framework
::
Vector
<
size_t
>&
x_lod
,
const
framework
::
Vector
<
size_t
>&
ref_lod
,
framework
::
Vector
<
size_t
>&
out_offset
)
{
size_t
offset
=
0
;
int
lod_size
=
static_cast
<
int
>
(
x_lod
.
size
());
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
x_lod
.
size
());
++
i
)
{
out_offset
[
i
]
=
offset
;
if
(
i
<
lod_size
-
1
)
{
offset
+=
(
ref_lod
[
i
+
1
]
-
ref_lod
[
i
])
*
(
x_lod
[
i
+
1
]
-
x_lod
[
i
]);
}
}
}
template
<
typename
T
>
struct
SequenceExpandFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
...
...
@@ -109,6 +97,9 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
const
framework
::
Vector
<
size_t
>&
ref_lod
,
/*expand referenced lod*/
LoDTensor
*
out
)
{
int
x_item_length
=
x
.
numel
()
/
x
.
dims
()[
0
];
framework
::
Vector
<
size_t
>
out_offset
(
x_lod
.
size
());
GetOutputOffset
(
x_lod
,
ref_lod
,
out_offset
);
int
thread_x
=
std
::
min
(
32
,
std
::
max
(
static_cast
<
int
>
(
ref_lod
.
size
()),
16
));
int
thread_y
=
16
;
int
thread_z
=
1024
/
thread_x
/
thread_y
;
...
...
@@ -118,7 +109,8 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
sequence_expand_kernel
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
x
.
data
<
T
>
(),
x_lod
.
CUDAData
(
context
.
GetPlace
()),
ref_lod
.
CUDAData
(
context
.
GetPlace
()),
x_lod
.
size
(),
x_item_length
,
ref_lod
.
CUDAData
(
context
.
GetPlace
()),
out_offset
.
CUDAData
(
context
.
GetPlace
()),
x_lod
.
size
(),
x_item_length
,
out
->
mutable_data
<
T
>
(
context
.
GetPlace
()));
}
};
...
...
@@ -131,6 +123,9 @@ struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
const
framework
::
Vector
<
size_t
>&
ref_lod
,
/*expand based lod*/
LoDTensor
*
dx
)
{
int
x_item_length
=
framework
::
product
(
dx
->
dims
())
/
dx
->
dims
()[
0
];
framework
::
Vector
<
size_t
>
out_offset
(
x_lod
.
size
());
GetOutputOffset
(
x_lod
,
ref_lod
,
out_offset
);
int
thread_x
=
std
::
min
(
32
,
std
::
max
(
static_cast
<
int
>
(
ref_lod
.
size
()),
16
));
int
thread_y
=
16
;
int
thread_z
=
1024
/
thread_x
/
thread_y
;
...
...
@@ -139,7 +134,8 @@ struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
dim3
grid_size
(
block_x
,
1
);
sequence_expand_grad_kernel
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
dout
.
data
<
T
>
(),
ref_lod
.
CUDAData
(
context
.
GetPlace
()),
x_lod
.
CUDAData
(
context
.
GetPlace
()),
ref_lod
.
size
(),
x_item_length
,
x_lod
.
CUDAData
(
context
.
GetPlace
()),
out_offset
.
CUDAData
(
context
.
GetPlace
()),
ref_lod
.
size
(),
x_item_length
,
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
()));
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录