Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e4c35d83
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e4c35d83
编写于
3月 20, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"add details"
上级
26822bd7
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
19 addition
and
18 deletion
+19
-18
paddle/fluid/operators/sequence_expand_op.cu
paddle/fluid/operators/sequence_expand_op.cu
+9
-10
paddle/fluid/operators/sequence_expand_op.h
paddle/fluid/operators/sequence_expand_op.h
+10
-8
未找到文件。
paddle/fluid/operators/sequence_expand_op.cu
浏览文件 @
e4c35d83
...
...
@@ -54,15 +54,15 @@ __global__ void sequence_expand_grad_kernel(const T* dout_data, T* dx_data,
int
tid_z
=
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
;
int
item_start
=
tid_x
/
element_len
;
for
(;
tid_z
<
element_len
;
tid_z
+=
blockDim
.
z
*
gridDim
.
z
)
{
shm
[
item_start
+
tid_z
]
+=
dout
x
_data
[
item_start
*
scale
+
tid_z
];
shm
[
item_start
+
tid_z
]
+=
dout_data
[
item_start
*
scale
+
tid_z
];
}
}
}
// synchronize before write to dx
__syncthreads
();
for
(
int
idx
=
blockDimx
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
idx
<
static_cast
<
int
>
(
dout_size
);
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
dx_data
[
idx
]
=
shm
[
idx
;]
dx_data
[
idx
]
=
shm
[
idx
];
}
}
...
...
@@ -86,19 +86,18 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
template
<
typename
T
>
struct
SequenceExpandGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
LoDTensor
&
x
,
const
LoDTensor
&
out
,
const
LoDTensor
&
dout
,
LoDTensor
*
dx
)
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
LoDTensor
&
x
,
const
LoDTensor
&
out
,
const
LoDTensor
&
dout
,
LoDTensor
*
dx
)
{
auto
x_dims
=
x
.
dims
();
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
const
T
*
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_starts
=
out
->
lod
().
back
();
auto
out_starts
=
out
.
lod
().
back
();
dim3
block_size
(
16
,
32
,
element_len
);
dim3
grid_size
(
10
,
10
);
size_t
out_size
=
framework
::
product
(
dx
->
dims
());
sequence_expand_kernel
<<<
grid_size
,
block_size
,
out_size
*
sizeof
(
T
),
context
.
stream
()
>>>
(
sequence_expand_
grad_
kernel
<<<
grid_size
,
block_size
,
out_size
*
sizeof
(
T
),
context
.
stream
()
>>>
(
dout
.
data
<
T
>
(),
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
out_starts
.
CUDAData
(
context
.
GetPlace
()),
out_starts
.
size
(),
element_len
,
out_size
);
...
...
paddle/fluid/operators/sequence_expand_op.h
浏览文件 @
e4c35d83
...
...
@@ -40,7 +40,7 @@ struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
LoDTensor
*
out
)
{
auto
x_dims
=
x
.
dims
();
size_t
element_len
=
framework
::
product
(
x_dims
)
/
x_dims
[
0
];
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_starts
=
out
->
lod
().
back
();
...
...
@@ -92,12 +92,12 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
* */
template
<
typename
T
>
struct
SequenceExpandGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
c
tx
,
const
LoDTensor
&
x
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
c
ontext
,
const
LoDTensor
&
x
,
const
LoDTensor
&
out
,
const
LoDTensor
&
dout
,
LoDTensor
*
dx
)
{
auto
out_last_level
=
out
.
lod
().
back
();
const
T
*
d_out_data
=
d
_
out
.
data
<
T
>
();
T
*
d_x_data
=
d
_
x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
size_t
element_len
=
d
_out
.
numel
()
/
d_
out
.
dims
()[
0
];
const
T
*
d_out_data
=
dout
.
data
<
T
>
();
T
*
d_x_data
=
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
());
size_t
element_len
=
d
out
.
numel
()
/
d
out
.
dims
()[
0
];
for
(
size_t
i
=
0
;
i
<
out_last_level
.
size
()
-
1
;
++
i
)
{
size_t
repeat
=
out_last_level
[
i
+
1
]
-
out_last_level
[
i
];
Eigen
::
TensorMap
<
...
...
@@ -117,13 +117,15 @@ template <typename DeviceContext, typename T>
class
SequenceExpandGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
d_out
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
x
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Input
<
LoDTensor
>
(
"Out"
);
auto
*
d_out
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
d_x
->
set_lod
(
x
->
lod
());
SequenceExpandGradFunctor
(
context
.
template
device_context
(),
*
x
,
*
out
,
d_out
,
d_x
);
SequenceExpandGradFunctor
<
DeviceContext
,
T
>
functor
;
functor
(
context
.
template
device_context
<
DeviceContext
>(),
*
x
,
*
out
,
*
d_out
,
d_x
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录