Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
26822bd7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
26822bd7
编写于
3月 20, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"add sequence kernel"
上级
4ee1c9e6
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
123 addition
and
70 deletion
+123
-70
paddle/fluid/operators/sequence_expand_op.cu
paddle/fluid/operators/sequence_expand_op.cu
+74
-33
paddle/fluid/operators/sequence_expand_op.h
paddle/fluid/operators/sequence_expand_op.h
+49
-37
未找到文件。
paddle/fluid/operators/sequence_expand_op.cu
浏览文件 @
26822bd7
...
@@ -21,48 +21,89 @@ namespace operators {
...
@@ -21,48 +21,89 @@ namespace operators {
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
template
<
typename
T
>
__global__
sequence_expand_kernel
(
const
T
*
x_data
,
T
*
out_data
,
size_t
*
lod
,
__global__
void
sequence_expand_kernel
(
const
T
*
x_data
,
T
*
out_data
,
size_t
element_len
)
{
const
size_t
*
lod
,
size_t
lod_size
,
int
BLOCK_SIZE
=
1024
;
size_t
element_len
)
{
__shared__
T
shm_lod
[
BLOCK_SIZE
];
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
BLOCK_SIZE
;
++
idx
)
{
for
(;
tid_x
<
static_cast
<
int
>
(
lod_size
-
1
);
shm_lod
[
idx
]
=
lod
[
idx
];
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
scale
=
lod
[
tid_x
+
1
]
-
lod
[
tid_x
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
scale
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
{
int
tid_z
=
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
;
int
item_start
=
tid_x
/
element_len
;
for
(;
tid_z
<
element_len
;
tid_z
+=
blockDim
.
z
*
gridDim
.
z
)
{
out_data
[
item_start
*
scale
+
tid_z
]
=
x_data
[
item_start
+
tid_z
];
}
}
}
}
for
(
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
idx
<
lod
.
size
();
}
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
scale
=
lod
[
i
]
template
<
typename
T
>
__global__
void
sequence_expand_grad_kernel
(
const
T
*
dout_data
,
T
*
dx_data
,
const
size_t
*
lod
,
size_t
lod_size
,
size_t
element_len
,
size_t
dout_size
)
{
extern
__shared__
T
shm
[];
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(;
tid_x
<
static_cast
<
int
>
(
lod_size
-
1
);
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
scale
=
lod
[
tid_x
+
1
]
-
lod
[
tid_x
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
scale
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
{
int
tid_z
=
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
;
int
item_start
=
tid_x
/
element_len
;
for
(;
tid_z
<
element_len
;
tid_z
+=
blockDim
.
z
*
gridDim
.
z
)
{
shm
[
item_start
+
tid_z
]
+=
doutx_data
[
item_start
*
scale
+
tid_z
];
}
}
}
// synchronize before write to dx
__syncthreads
();
for
(
int
idx
=
blockDimx
*
blockIdx
.
x
+
threadIdx
.
x
;
idx
<
static_cast
<
int
>
(
dout_size
);
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
dx_data
[
idx
]
=
shm
[
idx
;]
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
void
SequenceExpandFunctor
<
platform
::
CPUDeviceContext
,
T
>::
operator
()(
struct
SequenceExpandFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
const
platform
::
CPUDeviceContext
&
context
,
const
LoDTensor
&
x
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
LoDTensor
*
out
)
{
const
LoDTensor
&
x
,
LoDTensor
*
out
)
{
x_dims
=
x
.
dims
();
auto
x_dims
=
x
.
dims
();
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_starts
=
out
->
lod
().
back
();
auto
out_starts
=
out
->
lod
().
back
();
const
int
kThreadsPerBlock
=
1024
;
dim3
block_size
(
16
,
32
,
element_len
);
int
block_cols
=
kThreadsPerBlock
;
dim3
grid_size
(
10
,
10
);
if
(
out_cols
<
kThreadsPerBlock
)
{
// block_cols is aligned by 32.
sequence_expand_kernel
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
block_cols
=
((
out_cols
+
31
)
>>
5
)
<<
5
;
x
.
data
<
T
>
(),
out
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
out_starts
.
CUDAData
(
context
.
GetPlace
()),
out_starts
.
size
(),
element_len
);
}
}
int
block_rows
=
kThreadsPerBlock
/
block_cols
;
};
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
int
max_threads
=
context
.
GetMaxPhysicalThreadCount
();
template
<
typename
T
>
int
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
struct
SequenceExpandGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
LoDTensor
&
x
,
const
LoDTensor
&
out
,
const
LoDTensor
&
dout
,
LoDTensor
*
dx
)
{
auto
x_dims
=
x
.
dims
();
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
const
T
*
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_starts
=
out
->
lod
().
back
();
int
grid_cols
=
dim3
block_size
(
16
,
32
,
element_len
);
std
::
min
((
out_cols
+
block_cols
-
1
)
/
block_cols
,
max_blocks
);
dim3
grid_size
(
10
,
10
);
int
grid_rows
=
size_t
out_size
=
framework
::
product
(
dx
->
dims
());
std
::
min
(
max_blocks
/
grid_cols
,
std
::
max
(
out_rows
/
block_rows
,
1
));
sequence_expand_kernel
<<<
grid_size
,
block_size
,
out_size
*
sizeof
(
T
),
dim3
grid_size
=
dim3
(
grid_cols
,
grid_rows
,
1
);
context
.
stream
()
>>>
(
sequence_expand_kernel
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
dout
.
data
<
T
>
(),
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
x
.
data
<
T
>
(),
out
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
out_starts
.
CUDAData
(
context
.
GetPlace
()),
out_starts
.
size
(),
element_len
,
out_starts
.
CUDAData
(
context
.
GetPlace
()),
element_len
);
out_size
);
}
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/operators/sequence_expand_op.h
浏览文件 @
26822bd7
...
@@ -28,31 +28,36 @@ struct SequenceExpandFunctor {
...
@@ -28,31 +28,36 @@ struct SequenceExpandFunctor {
void
operator
()(
const
DeviceContext
&
ctx
,
const
LoDTensor
&
x
,
LoDTensor
*
out
);
void
operator
()(
const
DeviceContext
&
ctx
,
const
LoDTensor
&
x
,
LoDTensor
*
out
);
};
};
// template <typename DeviceContext, typename T>
template
<
typename
DeviceContext
,
typename
T
>
// struct SequenceExpandGradFunctor {};
struct
SequenceExpandGradFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
LoDTensor
&
x
,
const
LoDTensor
&
out
,
const
LoDTensor
&
dout
,
LoDTensor
*
dx
);
};
template
<
typename
T
>
template
<
typename
T
>
void
SequenceExpandFunctor
<
platform
::
CPUDeviceContext
,
T
>::
operator
()(
struct
SequenceExpandFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
const
platform
::
CPUDeviceContext
&
context
,
const
LoDTensor
&
x
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
LoDTensor
&
x
,
LoDTensor
*
out
)
{
LoDTensor
*
out
)
{
x_dims
=
x
.
dims
();
auto
x_dims
=
x
.
dims
();
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
x_data
=
x
->
data
<
T
>
();
auto
out_starts
=
out
->
lod
().
back
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_starts
=
out
->
lod
().
back
();
for
(
size_t
i
=
0
;
i
<
out_starts
.
size
()
-
1
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
out_starts
.
size
()
-
1
;
i
++
)
{
int
scale
=
out_starts
[
i
+
1
]
-
out_starts
[
i
];
int
scale
=
out_starts
[
i
+
1
]
-
out_starts
[
i
];
Eigen
::
TensorMap
<
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
const
T
,
2
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
Eigen
::
Tensor
<
const
T
,
2
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
x_t
(
x_data
,
1
,
element_len
);
x_t
(
x_data
,
1
,
element_len
);
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
2
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
2
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
out_t
(
out_data
,
scale
,
element_len
);
out_t
(
out_data
,
scale
,
element_len
);
Eigen
::
array
<
int
,
2
>
cast
({{
scale
,
1
}});
Eigen
::
array
<
int
,
2
>
cast
({{
scale
,
1
}});
out_t
.
device
(
*
context
.
eigen_device
())
=
x_t
.
broadcast
(
cast
);
out_t
.
device
(
*
context
.
eigen_device
())
=
x_t
.
broadcast
(
cast
);
x_data
+=
element_len
;
x_data
+=
element_len
;
out_data
+=
element_len
*
scale
;
out_data
+=
element_len
*
scale
;
}
}
}
}
}
;
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
SequenceExpandKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SequenceExpandKernel
:
public
framework
::
OpKernel
<
T
>
{
...
@@ -60,7 +65,6 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
...
@@ -60,7 +65,6 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
x
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
auto
x_dims
=
x
->
dims
();
auto
x_dims
=
x
->
dims
();
auto
*
y
=
context
.
Input
<
LoDTensor
>
(
"Y"
);
auto
*
y
=
context
.
Input
<
LoDTensor
>
(
"Y"
);
PADDLE_ENFORCE
(
!
y
->
lod
().
empty
(),
"y should have lod"
);
PADDLE_ENFORCE
(
!
y
->
lod
().
empty
(),
"y should have lod"
);
...
@@ -86,19 +90,14 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
...
@@ -86,19 +90,14 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
* Grad(X).lod = Input(X).lod
* Grad(X).lod = Input(X).lod
*
*
* */
* */
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
>
class
SequenceExpandGradKernel
:
public
framework
::
OpKernel
<
T
>
{
struct
SequenceExpandGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
LoDTensor
&
x
,
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
LoDTensor
&
out
,
const
LoDTensor
&
dout
,
LoDTensor
*
dx
)
{
auto
*
d_out
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
out_last_level
=
out
.
lod
().
back
();
auto
*
x
=
context
.
Input
<
LoDTensor
>
(
"X"
);
const
T
*
d_out_data
=
d_out
.
data
<
T
>
();
auto
*
out
=
context
.
Input
<
LoDTensor
>
(
"Out"
);
auto
*
d_x
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
out_last_level
=
out
->
lod
().
back
();
d_x
->
set_lod
(
x
->
lod
());
const
T
*
d_out_data
=
d_out
->
data
<
T
>
();
T
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
size_t
element_len
=
d_out
->
numel
()
/
d_out
->
dims
()[
0
];
size_t
element_len
=
d_out
.
numel
()
/
d_out
.
dims
()[
0
];
for
(
size_t
i
=
0
;
i
<
out_last_level
.
size
()
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
out_last_level
.
size
()
-
1
;
++
i
)
{
size_t
repeat
=
out_last_level
[
i
+
1
]
-
out_last_level
[
i
];
size_t
repeat
=
out_last_level
[
i
+
1
]
-
out_last_level
[
i
];
Eigen
::
TensorMap
<
Eigen
::
TensorMap
<
...
@@ -106,14 +105,27 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
...
@@ -106,14 +105,27 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
d_out_t
(
d_out_data
,
static_cast
<
int
>
(
repeat
),
element_len
);
d_out_t
(
d_out_data
,
static_cast
<
int
>
(
repeat
),
element_len
);
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
1
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
1
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>>
d_x_t
(
d_x_data
,
static_cast
<
int
>
(
element_len
));
d_x_t
(
d_x_data
,
static_cast
<
int
>
(
element_len
));
auto
place
=
d_x_t
.
device
(
*
context
.
eigen_device
())
=
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
d_out_t
.
sum
(
Eigen
::
array
<
int
,
1
>
({{
0
}}));
d_x_t
.
device
(
*
place
)
=
d_out_t
.
sum
(
Eigen
::
array
<
int
,
1
>
({{
0
}}));
d_out_data
+=
(
repeat
*
element_len
);
d_out_data
+=
(
repeat
*
element_len
);
d_x_data
+=
element_len
;
d_x_data
+=
element_len
;
}
}
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
class
SequenceExpandGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
d_out
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
x
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Input
<
LoDTensor
>
(
"Out"
);
auto
*
d_x
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
d_x
->
set_lod
(
x
->
lod
());
SequenceExpandGradFunctor
(
context
.
template
device_context
(),
*
x
,
*
out
,
d_out
,
d_x
);
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录