Commit 7ef568e8
Authored Oct 14, 2017 by qijun

fix gpu unittest error

Parent: 4130e5fa
Showing 3 changed files with 36 additions and 21 deletions (+36 −21):

  paddle/operators/math/CMakeLists.txt         +2  −2
  paddle/operators/math/math_function.cu       +27 −13
  paddle/operators/math/math_function_test.cu  +7  −6
paddle/operators/math/CMakeLists.txt

 if(WITH_GPU)
     nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator)
-    nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor)
+    nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor selected_rows)
     nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
     nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
     nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
     nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
-    cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
     cc_library(softmax SRCS softmax.cc DEPS operator)
     cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
     cc_library(pooling SRCS pooling.cc DEPS device_context)
     cc_library(vol2col SRCS vol2col.cc DEPS device_context)
 endif()
+cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor selected_rows)
 cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
 cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
paddle/operators/math/math_function.cu

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/operators/math/math_function.h"
+#include "paddle/platform/cuda_helper.h"
 
 namespace paddle {
 namespace operators {
@@ -191,7 +192,7 @@ struct SelectedRowsAdd<platform::GPUPlace, T> {
     auto in2_place = input2.place();
     PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
     auto out_place = context.GetPlace();
-    PADDLE_ENFORCE(platform::is_gpu_place(out_place))
+    PADDLE_ENFORCE(platform::is_gpu_place(out_place));
 
     memory::Copy(boost::get<platform::GPUPlace>(out_place), out_data,
@@ -211,22 +212,26 @@ struct SelectedRowsAdd<platform::GPUPlace, T> {
 template struct SelectedRowsAdd<platform::GPUPlace, float>;
 
 namespace {
-template <int block_size, typename T>
-__global__ void SelectedRowsAddTensorKernel(T* selected_rows, int64_t* rows,
-                                            T* tensor_in, T* tensor_out,
-                                            const int64_t row_numel) {
-  const ty = blockIdx.y;
+template <typename T>
+__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
+                                            const int64_t* rows, T* tensor_out,
+                                            int64_t row_numel, int block_size) {
+  const int ty = blockIdx.y;
   int tid = threadIdx.x;
 
   selected_rows += ty * row_numel;
-  tensor_in += rows[ty] * row_numel;
   tensor_out += rows[ty] * row_numel;
 
   for (int index = tid; index < row_numel; index += block_size) {
-    tensor_out[index] = tensor_in[index] + selected_rows[index];
+    // Since index in rows of SelectedRows can be duplicate, we can not use
+    // tensor_out[index] += selected_rows[index]; Instead, we have to use
+    // AtomicAdd to avoid concurrent write error.
+    paddle::platform::CudaAtomicAdd(&tensor_out[index], selected_rows[index]);
   }
 }
 }  // namespace
 
 template <typename T>
 struct SelectedRowsAddTensor<platform::GPUPlace, T> {
@@ -250,13 +255,22 @@ struct SelectedRowsAddTensor<platform::GPUPlace, T> {
     auto* in2_data = input2.data<T>();
     auto* out_data = output->data<T>();
 
-    const int block_size = 256;
+    SetConstant<platform::GPUPlace, T> functor;
+    functor(context, output, 0.0);
+
+    int block_size = 256;
     dim3 threads(block_size, 1);
     dim3 grid(1, in1_height);
-    SelectedRowsAddTensorKernel<block_size, T><<<
+    SelectedRowsAddTensorKernel<T><<<
         grid, threads, 0,
-        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
-        in1_data, in1_rows.data(), in2_data, out_data, in1_row_numel);
+        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream()>>>(
+        in1_data, in1_rows.data(), out_data, in1_row_numel, block_size);
+
+    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
+    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
+    out_eigen.device(*context.GetEigenDevice<platform::GPUPlace>()) =
+        out_eigen + in2_eigen;
   }
 };
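The comment in the reworked kernel explains the core of the fix: the row index list of a SelectedRows may contain duplicates, so two thread blocks can end up writing into the same destination row, and a plain read-modify-write would race. The commit therefore switches to an atomic accumulation, and it also passes block_size as a runtime kernel argument instead of a template parameter. Below is a minimal standalone sketch of that scatter-add pattern, written with the raw CUDA runtime and atomicAdd rather than Paddle's CudaAtomicAdd/SelectedRows types; all names here are illustrative, not framework API.

// Standalone sketch (not Paddle code): one block per compressed row, with
// atomicAdd because two compressed rows may target the same dense row.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void ScatterAddRowsKernel(const float* selected_rows,
                                     const int64_t* rows, float* tensor_out,
                                     int64_t row_numel, int block_size) {
  const int ty = blockIdx.y;             // which compressed row this block handles
  const int tid = threadIdx.x;

  selected_rows += ty * row_numel;       // source row in the compressed values
  tensor_out += rows[ty] * row_numel;    // destination row in the dense tensor

  for (int64_t index = tid; index < row_numel; index += block_size) {
    // rows[] may contain duplicates, so several blocks can hit the same
    // destination row; atomicAdd avoids the lost-update race.
    atomicAdd(&tensor_out[index], selected_rows[index]);
  }
}

int main() {
  const int64_t height = 10, row_numel = 4;
  // Destination row 7 is duplicated on purpose.
  std::vector<int64_t> h_rows = {0, 7, 7};
  std::vector<float> h_vals(h_rows.size() * row_numel, 1.0f);
  std::vector<float> h_out(height * row_numel, 0.0f);

  int64_t* d_rows; float* d_vals; float* d_out;
  cudaMalloc(&d_rows, h_rows.size() * sizeof(int64_t));
  cudaMalloc(&d_vals, h_vals.size() * sizeof(float));
  cudaMalloc(&d_out, h_out.size() * sizeof(float));
  cudaMemcpy(d_rows, h_rows.data(), h_rows.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
  cudaMemcpy(d_vals, h_vals.data(), h_vals.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, h_out.data(), h_out.size() * sizeof(float), cudaMemcpyHostToDevice);

  // block_size is an ordinary kernel argument, mirroring the commit's move
  // away from a compile-time template parameter.
  const int block_size = 256;
  dim3 threads(block_size, 1);
  dim3 grid(1, static_cast<unsigned>(h_rows.size()));
  ScatterAddRowsKernel<<<grid, threads>>>(d_vals, d_rows, d_out, row_numel, block_size);

  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
  printf("row 7, col 0 = %.1f (expect 2.0 from the duplicated row index)\n",
         h_out[7 * row_numel + 0]);
  cudaFree(d_rows); cudaFree(d_vals); cudaFree(d_out);
  return 0;
}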
paddle/operators/math/math_function_test.cu

@@ -183,20 +183,21 @@ TEST(math_function, selected_rows_add) {
   using namespace paddle::platform;
   using namespace paddle::operators::math;
 
-  CPUPlace gpu_place(0);
+  GPUPlace gpu_place(0);
+  CPUPlace cpu_place;
   CUDADeviceContext ctx(gpu_place);
   SetConstant<GPUPlace, float> functor;
   int64_t height = 10;
   int64_t row_numel = 10;
 
-  Vector<int64_t> rows1{0, 4, 7};
+  std::vector<int64_t> rows1{0, 4, 7};
   std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), gpu_place);
   functor(ctx, in1_value, 1.0);
 
-  Vector<int64_t> rows2{0, 5, 7, 9};
+  std::vector<int64_t> rows2{0, 5, 7, 9};
   std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
@@ -228,7 +229,7 @@ TEST(math_function, selected_rows_add) {
   EXPECT_EQ(out_rows[6], 9);
 
   Tensor out_cpu;
-  out_cpu.CopyFrom<float>(*out_value, platform::CPUPlace(), ctx);
+  out_cpu.CopyFrom<float>(*out_value, cpu_place, ctx);
   ctx.Wait();
 
   auto* out_cpu_data = out_cpu.data<float>();
@@ -256,10 +257,10 @@ TEST(math_function, selected_rows_add) {
   add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
 
   Tensor tensor2_cpu;
-  tensor2_cpu.CopyFrom<float>(*tensor2, platform::CPUPlace(), ctx);
+  tensor2_cpu.CopyFrom<float>(*tensor2, cpu_place, ctx);
   ctx.Wait();
 
-  auto* tensor2_cpu_data = tensor2_cpu->data<float>();
+  auto* tensor2_cpu_data = tensor2_cpu.data<float>();
   // row0: 1.0 + 2.0 + 3.0
   EXPECT_EQ(tensor2_cpu_data[0 * row_numel + 0], 6.0);
   // row1: 3.0
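The test fix makes the device-side setup explicit: the allocation place becomes GPUPlace, a separate cpu_place is kept for copying results back, and ctx.Wait() is called before the host reads the copied buffer. The sketch below illustrates that synchronize-before-verify pattern with plain CUDA runtime calls (cudaMemcpyAsync plus cudaStreamSynchronize) rather than Paddle's Tensor::CopyFrom and DeviceContext::Wait wrappers, which are assumed here to play the equivalent roles.

// Standalone sketch of the synchronize-before-verify pattern used by the test.
#include <cassert>
#include <cuda_runtime.h>

__global__ void FillKernel(float* data, int n, float value) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = value;
}

int main() {
  const int n = 100;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  float* d_data;
  cudaMalloc(&d_data, n * sizeof(float));
  FillKernel<<<(n + 255) / 256, 256, 0, stream>>>(d_data, n, 1.0f);

  // Asynchronous copy back to pinned host memory on the same stream.
  float* h_data;
  cudaMallocHost(&h_data, n * sizeof(float));
  cudaMemcpyAsync(h_data, d_data, n * sizeof(float),
                  cudaMemcpyDeviceToHost, stream);

  // Counterpart of ctx.Wait(): without it, the host could inspect h_data
  // before the kernel and the copy have completed.
  cudaStreamSynchronize(stream);

  for (int i = 0; i < n; ++i) assert(h_data[i] == 1.0f);

  cudaFreeHost(h_data);
  cudaFree(d_data);
  cudaStreamDestroy(stream);
  return 0;
}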