Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
4ee1c9e6
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4ee1c9e6
编写于
3月 19, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"add sequence expand kernel"
上级
b3f076a6
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
86 addition
and
19 deletion
+86
-19
paddle/fluid/operators/sequence_expand_op.cu
paddle/fluid/operators/sequence_expand_op.cu
+52
-0
paddle/fluid/operators/sequence_expand_op.h
paddle/fluid/operators/sequence_expand_op.h
+34
-19
未找到文件。
paddle/fluid/operators/sequence_expand_op.cu
浏览文件 @
4ee1c9e6
...
...
@@ -15,6 +15,58 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/sequence_expand_op.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
__global__
sequence_expand_kernel
(
const
T
*
x_data
,
T
*
out_data
,
size_t
*
lod
,
size_t
element_len
)
{
int
BLOCK_SIZE
=
1024
;
__shared__
T
shm_lod
[
BLOCK_SIZE
];
for
(
int
idx
=
threadIdx
.
x
;
idx
<
BLOCK_SIZE
;
++
idx
)
{
shm_lod
[
idx
]
=
lod
[
idx
];
}
for
(
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
idx
<
lod
.
size
();
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
scale
=
lod
[
i
]
}
}
template
<
typename
T
>
void
SequenceExpandFunctor
<
platform
::
CPUDeviceContext
,
T
>::
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
LoDTensor
&
x
,
LoDTensor
*
out
)
{
x_dims
=
x
.
dims
();
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_starts
=
out
->
lod
().
back
();
const
int
kThreadsPerBlock
=
1024
;
int
block_cols
=
kThreadsPerBlock
;
if
(
out_cols
<
kThreadsPerBlock
)
{
// block_cols is aligned by 32.
block_cols
=
((
out_cols
+
31
)
>>
5
)
<<
5
;
}
int
block_rows
=
kThreadsPerBlock
/
block_cols
;
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
int
max_threads
=
context
.
GetMaxPhysicalThreadCount
();
int
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
int
grid_cols
=
std
::
min
((
out_cols
+
block_cols
-
1
)
/
block_cols
,
max_blocks
);
int
grid_rows
=
std
::
min
(
max_blocks
/
grid_cols
,
std
::
max
(
out_rows
/
block_rows
,
1
));
dim3
grid_size
=
dim3
(
grid_cols
,
grid_rows
,
1
);
sequence_expand_kernel
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
x
.
data
<
T
>
(),
out
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
out_starts
.
CUDAData
(
context
.
GetPlace
()),
element_len
);
}
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
sequence_expand
,
...
...
paddle/fluid/operators/sequence_expand_op.h
浏览文件 @
4ee1c9e6
...
...
@@ -16,13 +16,44 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "
unsupported/Eigen/CXX11/Tensor
"
#include "
paddle/fluid/platform/device_context.h
"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
DeviceContext
,
typename
T
>
struct
SequenceExpandFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
LoDTensor
&
x
,
LoDTensor
*
out
);
};
// template <typename DeviceContext, typename T>
// struct SequenceExpandGradFunctor {};
template
<
typename
T
>
void
SequenceExpandFunctor
<
platform
::
CPUDeviceContext
,
T
>::
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
LoDTensor
&
x
,
LoDTensor
*
out
)
{
x_dims
=
x
.
dims
();
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_starts
=
out
->
lod
().
back
();
for
(
size_t
i
=
0
;
i
<
out_starts
.
size
()
-
1
;
i
++
)
{
int
scale
=
out_starts
[
i
+
1
]
-
out_starts
[
i
];
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
const
T
,
2
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
x_t
(
x_data
,
1
,
element_len
);
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
2
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
out_t
(
out_data
,
scale
,
element_len
);
Eigen
::
array
<
int
,
2
>
cast
({{
scale
,
1
}});
out_t
.
device
(
*
context
.
eigen_device
())
=
x_t
.
broadcast
(
cast
);
x_data
+=
element_len
;
out_data
+=
element_len
*
scale
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
SequenceExpandKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -38,24 +69,8 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
"The size of last lod level in Input(Y)"
"must be equal to dims[0] of Input(X)."
);
out
->
set_lod
(
y
->
lod
());
auto
*
place
=
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_starts
=
out
->
lod
().
back
();
for
(
size_t
i
=
0
;
i
<
out_starts
.
size
()
-
1
;
i
++
)
{
int
scale
=
out_starts
[
i
+
1
]
-
out_starts
[
i
];
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
const
T
,
2
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
x_t
(
x_data
,
1
,
element_len
);
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
2
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
out_t
(
out_data
,
scale
,
element_len
);
Eigen
::
array
<
int
,
2
>
cast
({{
scale
,
1
}});
out_t
.
device
(
*
place
)
=
x_t
.
broadcast
(
cast
);
x_data
+=
element_len
;
out_data
+=
element_len
*
scale
;
}
SequenceExpandFunctor
<
DeviceContext
,
T
>
functor
;
functor
(
context
.
template
device_context
<
DeviceContext
>(),
*
x
,
out
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录