Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
00ad7512
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
00ad7512
编写于
10月 20, 2017
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use stream while memory::Copy in GPU mode
上级
74b283c9
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
31 addition
and
9 deletion
+31
-9
paddle/operators/seq_expand_op.cc
paddle/operators/seq_expand_op.cc
+1
-1
paddle/operators/seq_expand_op.h
paddle/operators/seq_expand_op.h
+30
-8
未找到文件。
paddle/operators/seq_expand_op.cc
浏览文件 @
00ad7512
...
@@ -40,7 +40,7 @@ class SeqExpandOp : public framework::OperatorWithKernel {
...
@@ -40,7 +40,7 @@ class SeqExpandOp : public framework::OperatorWithKernel {
out_dim
[
0
]
=
out_dim
[
0
]
*
repeat
;
out_dim
[
0
]
=
out_dim
[
0
]
*
repeat
;
}
}
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of
Pa
dOp should not be null."
);
"Output(Out) of
SeqExpan
dOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
out_dim
);
ctx
->
SetOutputDim
(
"Out"
,
out_dim
);
}
}
};
};
...
...
paddle/operators/seq_expand_op.h
浏览文件 @
00ad7512
...
@@ -75,16 +75,38 @@ class SeqExpandKernel : public framework::OpKernel<T> {
...
@@ -75,16 +75,38 @@ class SeqExpandKernel : public framework::OpKernel<T> {
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// copy data
// copy data
Place
place
=
boost
::
get
<
Place
>
(
context
.
GetPlace
()
);
auto
place
=
context
.
GetPlace
(
);
size_t
count
=
0
;
size_t
count
=
0
;
if
(
platform
::
is_cpu_place
(
place
))
{
auto
&
cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
place
);
for
(
size_t
i
=
0
;
i
<
scales
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
scales
.
size
();
++
i
)
{
count
=
element_len
*
(
x_lod
[
0
][
i
+
1
]
-
x_lod
[
0
][
i
]);
count
=
element_len
*
(
x_lod
[
0
][
i
+
1
]
-
x_lod
[
0
][
i
]);
for
(
size_t
j
=
0
;
j
<
scales
[
i
];
++
j
)
{
for
(
size_t
j
=
0
;
j
<
scales
[
i
];
++
j
)
{
memory
::
Copy
(
place
,
out_data
,
place
,
x_data
,
sizeof
(
T
)
*
count
);
memory
::
Copy
(
cpu_place
,
out_data
,
cpu_place
,
x_data
,
sizeof
(
T
)
*
count
);
out_data
+=
count
;
out_data
+=
count
;
}
}
x_data
+=
count
;
x_data
+=
count
;
}
}
}
else
{
#ifdef PADDLE_WITH_CUDA
auto
&
gpu_place
=
boost
::
get
<
platform
::
GPUPlace
>
(
place
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
.
device_context
())
.
stream
();
for
(
size_t
i
=
0
;
i
<
scales
.
size
();
++
i
)
{
count
=
element_len
*
(
x_lod
[
0
][
i
+
1
]
-
x_lod
[
0
][
i
]);
for
(
size_t
j
=
0
;
j
<
scales
[
i
];
++
j
)
{
memory
::
Copy
(
gpu_place
,
out_data
,
gpu_place
,
x_data
,
sizeof
(
T
)
*
count
,
stream
);
out_data
+=
count
;
}
x_data
+=
count
;
}
#else
PADDLE_THROW
(
"Paddle is not compiled with GPU"
);
#endif
}
out
->
set_lod
(
out_lod
);
out
->
set_lod
(
out_lod
);
}
}
...
@@ -113,7 +135,7 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
...
@@ -113,7 +135,7 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
1
>>
d_x_t
(
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
1
>>
d_x_t
(
d_x_data
,
static_cast
<
int
>
((
ele_count
*
element_len
)
/
repeat
));
d_x_data
,
static_cast
<
int
>
((
ele_count
*
element_len
)
/
repeat
));
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
d_x_t
.
device
(
place
)
=
d_out_t
.
sum
(
Eigen
::
array
<
int
,
1
>
({
0
}));
d_x_t
.
device
(
place
)
=
d_out_t
.
sum
(
Eigen
::
array
<
int
,
1
>
({
{
0
}
}));
d_out_data
+=
(
ele_count
*
element_len
);
d_out_data
+=
(
ele_count
*
element_len
);
d_x_data
+=
((
ele_count
*
element_len
)
/
repeat
);
d_x_data
+=
((
ele_count
*
element_len
)
/
repeat
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录