Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
86acf39c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
86acf39c
编写于
10月 16, 2017
作者:
Q
QI JUN
提交者:
GitHub
10月 16, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4801 from QiJune/add_selected_rows_functor
add some basic math functor for SelectedRows
上级
240a37ee
5993497c
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
732 addition
and
196 deletion
+732
-196
paddle/framework/selected_rows.h
paddle/framework/selected_rows.h
+7
-3
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+2
-1
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+2
-1
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+6
-2
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+2
-0
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+2
-0
paddle/operators/math/math_function.h
paddle/operators/math/math_function.h
+8
-5
paddle/operators/math/math_function_test.cc
paddle/operators/math/math_function_test.cc
+4
-183
paddle/operators/math/math_function_test.cu
paddle/operators/math/math_function_test.cu
+179
-0
paddle/operators/math/selected_rows_functor.cc
paddle/operators/math/selected_rows_functor.cc
+114
-0
paddle/operators/math/selected_rows_functor.cu
paddle/operators/math/selected_rows_functor.cu
+142
-0
paddle/operators/math/selected_rows_functor.h
paddle/operators/math/selected_rows_functor.h
+41
-0
paddle/operators/math/selected_rows_functor_test.cc
paddle/operators/math/selected_rows_functor_test.cc
+106
-0
paddle/operators/math/selected_rows_functor_test.cu
paddle/operators/math/selected_rows_functor_test.cu
+115
-0
paddle/operators/sequence_pool_op.h
paddle/operators/sequence_pool_op.h
+2
-1
未找到文件。
paddle/framework/selected_rows.h
浏览文件 @
86acf39c
...
...
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
namespace
paddle
{
...
...
@@ -34,9 +35,9 @@ class SelectedRows {
void
set_height
(
int64_t
height
)
{
height_
=
height
;
}
const
std
::
v
ector
<
int64_t
>&
rows
()
const
{
return
rows_
;
}
const
V
ector
<
int64_t
>&
rows
()
const
{
return
rows_
;
}
void
set_rows
(
const
std
::
v
ector
<
int64_t
>&
rows
)
{
rows_
=
rows
;
}
void
set_rows
(
const
V
ector
<
int64_t
>&
rows
)
{
rows_
=
rows
;
}
DDim
GetCompleteDims
()
const
{
std
::
vector
<
int64_t
>
dims
=
vectorize
(
value_
->
dims
());
...
...
@@ -45,7 +46,10 @@ class SelectedRows {
}
private:
std
::
vector
<
int64_t
>
rows_
;
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
// SelectedRows are simplely concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
Vector
<
int64_t
>
rows_
;
std
::
unique_ptr
<
Tensor
>
value_
{
nullptr
};
int64_t
height_
;
};
...
...
paddle/operators/cross_entropy_op.cu
浏览文件 @
86acf39c
...
...
@@ -91,7 +91,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
batch_size
,
class_num
);
}
else
{
math
::
SetConstant
<
platform
::
GPUPlace
,
T
>
(
ctx
.
device_context
(),
dx
,
0
);
math
::
SetConstant
<
platform
::
GPUPlace
,
T
>
functor
;
functor
(
ctx
.
device_context
(),
dx
,
0
);
auto
*
label_data
=
label
->
data
<
int
>
();
grid
=
(
batch_size
+
block
-
1
)
/
block
;
CrossEntropyGradientKernel
<
T
><<<
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
86acf39c
...
...
@@ -70,7 +70,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
const
T
*
x_data
=
x
->
data
<
T
>
();
const
int
*
label_data
=
label
->
data
<
int
>
();
math
::
SetConstant
<
platform
::
CPUPlace
,
T
>
(
ctx
.
device_context
(),
dx
,
0
);
math
::
SetConstant
<
platform
::
CPUPlace
,
T
>
functor
;
functor
(
ctx
.
device_context
(),
dx
,
0
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
PADDLE_ASSERT
(
label_data
[
i
]
>=
0
||
label_data
[
i
]
<
class_num
);
...
...
paddle/operators/math/CMakeLists.txt
浏览文件 @
86acf39c
if
(
WITH_GPU
)
nv_library
(
math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator
)
nv_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
nv_test
(
math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor
)
nv_library
(
selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function
)
nv_test
(
selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor
)
nv_library
(
softmax SRCS softmax.cc softmax.cu DEPS operator
)
nv_library
(
cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator
)
nv_library
(
pooling SRCS pooling.cc pooling.cu DEPS device_context
)
nv_library
(
vol2col SRCS vol2col.cc vol2col.cu DEPS device_context
)
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator
)
cc_
test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
cc_
library
(
selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function
)
cc_library
(
softmax SRCS softmax.cc DEPS operator
)
cc_library
(
cross_entropy SRCS cross_entropy.cc DEPS operator
)
cc_library
(
pooling SRCS pooling.cc DEPS device_context
)
cc_library
(
vol2col SRCS vol2col.cc DEPS device_context
)
endif
()
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
cc_test
(
selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor
)
cc_test
(
im2col_test SRCS im2col_test.cc DEPS math_function tensor
)
cc_test
(
vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor
)
paddle/operators/math/math_function.cc
浏览文件 @
86acf39c
...
...
@@ -130,6 +130,8 @@ void matmul<platform::CPUPlace, double>(
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
template
struct
SetConstant
<
platform
::
CPUPlace
,
float
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function.cu
浏览文件 @
86acf39c
...
...
@@ -155,6 +155,8 @@ void matmul<platform::GPUPlace, double>(
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
template
struct
SetConstant
<
platform
::
GPUPlace
,
float
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function.h
浏览文件 @
86acf39c
...
...
@@ -86,11 +86,14 @@ void matmul(const platform::DeviceContext& context,
framework
::
Tensor
*
matrix_out
,
T
beta
);
template
<
typename
Place
,
typename
T
>
void
SetConstant
(
const
platform
::
DeviceContext
&
context
,
struct
SetConstant
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
T
num
)
{
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
t
.
device
(
*
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
static_cast
<
T
>
(
num
));
}
t
.
device
(
*
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
static_cast
<
T
>
(
num
));
}
};
}
// namespace math
}
// namespace operators
...
...
paddle/operators/math/math_function_test.cc
浏览文件 @
86acf39c
#include "paddle/operators/math/math_function.h"
#include "gtest/gtest.h"
#ifdef PADDLE_WITH_CUDA
TEST
(
math_function
,
notrans_mul_trans
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
out_gpu
;
paddle
::
framework
::
Tensor
out
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr
,
6
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
input2_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
out_gpu
.
mutable_data
<
float
>
({
2
,
2
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
input1_gpu
,
false
,
input2_gpu
,
true
,
1
,
&
out_gpu
,
0
);
out
.
CopyFrom
<
float
>
(
out_gpu
,
*
cpu_place
,
context
);
float
*
out_ptr
=
out
.
data
<
float
>
();
context
.
Wait
();
EXPECT_EQ
(
out_ptr
[
0
],
5
);
EXPECT_EQ
(
out_ptr
[
1
],
14
);
EXPECT_EQ
(
out_ptr
[
2
],
14
);
EXPECT_EQ
(
out_ptr
[
3
],
50
);
delete
gpu_place
;
}
TEST
(
math_function
,
trans_mul_notrans
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
out_gpu
;
paddle
::
framework
::
Tensor
out
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr
,
6
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
input2_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
out_gpu
.
mutable_data
<
float
>
({
3
,
3
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
input1_gpu
,
true
,
input2_gpu
,
false
,
1
,
&
out_gpu
,
0
);
out
.
CopyFrom
<
float
>
(
out_gpu
,
*
cpu_place
,
context
);
float
*
out_ptr
=
out
.
data
<
float
>
();
context
.
Wait
();
EXPECT_EQ
(
out_ptr
[
0
],
9
);
EXPECT_EQ
(
out_ptr
[
1
],
12
);
EXPECT_EQ
(
out_ptr
[
2
],
15
);
EXPECT_EQ
(
out_ptr
[
3
],
12
);
EXPECT_EQ
(
out_ptr
[
4
],
17
);
EXPECT_EQ
(
out_ptr
[
5
],
22
);
EXPECT_EQ
(
out_ptr
[
6
],
15
);
EXPECT_EQ
(
out_ptr
[
7
],
22
);
EXPECT_EQ
(
out_ptr
[
8
],
29
);
delete
gpu_place
;
}
TEST
(
math_function
,
gemm_notrans_cublas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
input3_gpu
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
3
,
4
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
input2_gpu
.
CopyFrom
<
float
>
(
input2
,
*
gpu_place
,
context
);
input3_gpu
.
CopyFrom
<
float
>
(
input3
,
*
gpu_place
,
context
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
false
,
false
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
1
,
4
,
1
,
c
+
1
,
4
);
input3
.
CopyFrom
<
float
>
(
input3_gpu
,
*
cpu_place
,
context
);
// numpy code:
// a = np.arange(6).reshape(2, 3)
// b = np.arange(12).reshape(3, 4)[:, 1:]
// c = np.arange(8).reshape(2, 4)[:, 1:]
// out = np.arange(8).reshape(2, 4)
// out[:, 1:] = np.dot(a, b) + c
context
.
Wait
();
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
delete
gpu_place
;
}
TEST
(
math_function
,
gemm_trans_cublas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
input3_gpu
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
4
,
3
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
4
,
8
,
1
,
5
,
9
,
2
,
6
,
10
,
3
,
7
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
input2_gpu
.
CopyFrom
<
float
>
(
input2
,
*
gpu_place
,
context
);
input3_gpu
.
CopyFrom
<
float
>
(
input3
,
*
gpu_place
,
context
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
false
,
true
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
3
,
3
,
1
,
c
+
1
,
4
);
input3
.
CopyFrom
<
float
>
(
input3_gpu
,
*
cpu_place
,
context
);
context
.
Wait
();
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
delete
gpu_place
;
}
#endif
TEST
(
math_function
,
gemm_notrans_cblas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
...
...
@@ -253,15 +74,15 @@ TEST(math_function, zero) {
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
t
=
tensor
.
mutable_data
<
float
>
({
2
,
2
},
*
cpu_place
);
paddle
::
platform
::
CPUDeviceContext
context
(
*
cpu_place
);
paddle
::
operators
::
math
::
SetConstant
<
paddle
::
platform
::
CPUPlace
,
float
>
(
context
,
&
tensor
,
0
);
paddle
::
operators
::
math
::
SetConstant
<
paddle
::
platform
::
CPUPlace
,
float
>
functor
;
functor
(
context
,
&
tensor
,
0
);
EXPECT_EQ
(
t
[
0
],
0
);
EXPECT_EQ
(
t
[
1
],
0
);
EXPECT_EQ
(
t
[
2
],
0
);
EXPECT_EQ
(
t
[
3
],
0
);
paddle
::
operators
::
math
::
SetConstant
<
paddle
::
platform
::
CPUPlace
,
float
>
(
context
,
&
tensor
,
1
);
functor
(
context
,
&
tensor
,
1
);
EXPECT_EQ
(
t
[
0
],
1
);
EXPECT_EQ
(
t
[
1
],
1
);
...
...
paddle/operators/math/math_function_test.cu
0 → 100644
浏览文件 @
86acf39c
#include "gtest/gtest.h"
#include "paddle/operators/math/math_function.h"
TEST
(
math_function
,
notrans_mul_trans
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
out_gpu
;
paddle
::
framework
::
Tensor
out
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr
,
6
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
input2_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
out_gpu
.
mutable_data
<
float
>
({
2
,
2
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
input1_gpu
,
false
,
input2_gpu
,
true
,
1
,
&
out_gpu
,
0
);
out
.
CopyFrom
<
float
>
(
out_gpu
,
*
cpu_place
,
context
);
float
*
out_ptr
=
out
.
data
<
float
>
();
context
.
Wait
();
EXPECT_EQ
(
out_ptr
[
0
],
5
);
EXPECT_EQ
(
out_ptr
[
1
],
14
);
EXPECT_EQ
(
out_ptr
[
2
],
14
);
EXPECT_EQ
(
out_ptr
[
3
],
50
);
delete
gpu_place
;
}
TEST
(
math_function
,
trans_mul_notrans
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
out_gpu
;
paddle
::
framework
::
Tensor
out
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr
,
6
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
input2_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
out_gpu
.
mutable_data
<
float
>
({
3
,
3
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
input1_gpu
,
true
,
input2_gpu
,
false
,
1
,
&
out_gpu
,
0
);
out
.
CopyFrom
<
float
>
(
out_gpu
,
*
cpu_place
,
context
);
float
*
out_ptr
=
out
.
data
<
float
>
();
context
.
Wait
();
EXPECT_EQ
(
out_ptr
[
0
],
9
);
EXPECT_EQ
(
out_ptr
[
1
],
12
);
EXPECT_EQ
(
out_ptr
[
2
],
15
);
EXPECT_EQ
(
out_ptr
[
3
],
12
);
EXPECT_EQ
(
out_ptr
[
4
],
17
);
EXPECT_EQ
(
out_ptr
[
5
],
22
);
EXPECT_EQ
(
out_ptr
[
6
],
15
);
EXPECT_EQ
(
out_ptr
[
7
],
22
);
EXPECT_EQ
(
out_ptr
[
8
],
29
);
delete
gpu_place
;
}
TEST
(
math_function
,
gemm_notrans_cublas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
input3_gpu
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
3
,
4
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
input2_gpu
.
CopyFrom
<
float
>
(
input2
,
*
gpu_place
,
context
);
input3_gpu
.
CopyFrom
<
float
>
(
input3
,
*
gpu_place
,
context
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
false
,
false
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
1
,
4
,
1
,
c
+
1
,
4
);
input3
.
CopyFrom
<
float
>
(
input3_gpu
,
*
cpu_place
,
context
);
// numpy code:
// a = np.arange(6).reshape(2, 3)
// b = np.arange(12).reshape(3, 4)[:, 1:]
// c = np.arange(8).reshape(2, 4)[:, 1:]
// out = np.arange(8).reshape(2, 4)
// out[:, 1:] = np.dot(a, b) + c
context
.
Wait
();
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
delete
gpu_place
;
}
TEST
(
math_function
,
gemm_trans_cublas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
input3_gpu
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
4
,
3
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
4
,
8
,
1
,
5
,
9
,
2
,
6
,
10
,
3
,
7
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
,
context
);
input2_gpu
.
CopyFrom
<
float
>
(
input2
,
*
gpu_place
,
context
);
input3_gpu
.
CopyFrom
<
float
>
(
input3
,
*
gpu_place
,
context
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
false
,
true
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
3
,
3
,
1
,
c
+
1
,
4
);
input3
.
CopyFrom
<
float
>
(
input3_gpu
,
*
cpu_place
,
context
);
context
.
Wait
();
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
delete
gpu_place
;
}
paddle/operators/math/selected_rows_functor.cc
0 → 100644
浏览文件 @
86acf39c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
struct
SelectedRowsAdd
<
platform
::
CPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input1
,
const
framework
::
SelectedRows
&
input2
,
framework
::
SelectedRows
*
output
)
{
auto
in1_height
=
input1
.
height
();
PADDLE_ENFORCE_EQ
(
in1_height
,
input2
.
height
());
output
->
set_height
(
in1_height
);
auto
&
in1_rows
=
input1
.
rows
();
auto
&
in2_rows
=
input2
.
rows
();
std
::
vector
<
int64_t
>
out_rows
;
out_rows
.
reserve
(
in1_rows
.
size
()
+
in2_rows
.
size
());
// concat rows
out_rows
.
insert
(
out_rows
.
end
(),
in1_rows
.
begin
(),
in1_rows
.
end
());
out_rows
.
insert
(
out_rows
.
end
(),
in2_rows
.
begin
(),
in2_rows
.
end
());
output
->
set_rows
(
out_rows
);
auto
*
out_value
=
output
->
mutable_value
();
auto
&
in1_value
=
input1
.
value
();
auto
&
in2_value
=
input2
.
value
();
auto
in1_row_numel
=
in1_value
.
numel
()
/
in1_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
in2_value
.
numel
()
/
in2_rows
.
size
());
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
out_value
->
numel
()
/
out_rows
.
size
());
auto
in1_place
=
input1
.
place
();
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
in1_place
));
auto
in2_place
=
input2
.
place
();
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
in2_place
));
auto
out_place
=
context
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
out_place
));
auto
*
out_data
=
out_value
->
data
<
T
>
();
auto
*
in1_data
=
in1_value
.
data
<
T
>
();
memory
::
Copy
(
boost
::
get
<
platform
::
CPUPlace
>
(
out_place
),
out_data
,
boost
::
get
<
platform
::
CPUPlace
>
(
in1_place
),
in1_data
,
in1_value
.
numel
()
*
sizeof
(
T
));
auto
*
in2_data
=
in2_value
.
data
<
T
>
();
memory
::
Copy
(
boost
::
get
<
platform
::
CPUPlace
>
(
out_place
),
out_data
+
in1_value
.
numel
(),
boost
::
get
<
platform
::
CPUPlace
>
(
in2_place
),
in2_data
,
in2_value
.
numel
()
*
sizeof
(
T
));
}
};
template
struct
SelectedRowsAdd
<
platform
::
CPUPlace
,
float
>;
template
<
typename
T
>
struct
SelectedRowsAddTensor
<
platform
::
CPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input1
,
const
framework
::
Tensor
&
input2
,
framework
::
Tensor
*
output
)
{
auto
in1_height
=
input1
.
height
();
auto
in2_dims
=
input2
.
dims
();
auto
out_dims
=
output
->
dims
();
PADDLE_ENFORCE_EQ
(
in1_height
,
in2_dims
[
0
]);
PADDLE_ENFORCE_EQ
(
in1_height
,
out_dims
[
0
]);
auto
&
in1_value
=
input1
.
value
();
auto
&
in1_rows
=
input1
.
rows
();
int64_t
in1_row_numel
=
in1_value
.
numel
()
/
in1_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
input2
.
numel
()
/
in1_height
);
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
output
->
numel
()
/
in1_height
);
SetConstant
<
platform
::
CPUPlace
,
T
>
functor
;
functor
(
context
,
output
,
0.0
);
auto
*
in1_data
=
in1_value
.
data
<
T
>
();
auto
*
out_data
=
output
->
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
in1_rows
.
size
();
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
in1_row_numel
;
j
++
)
{
out_data
[
in1_rows
[
i
]
*
in1_row_numel
+
j
]
+=
in1_data
[
i
*
in1_row_numel
+
j
];
}
}
auto
out_eigen
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
auto
in2_eigen
=
framework
::
EigenVector
<
T
>::
Flatten
(
input2
);
out_eigen
.
device
(
*
context
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
out_eigen
+
in2_eigen
;
}
};
template
struct
SelectedRowsAddTensor
<
platform
::
CPUPlace
,
float
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/selected_rows_functor.cu
0 → 100644
浏览文件 @
86acf39c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
struct
SelectedRowsAdd
<
platform
::
GPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input1
,
const
framework
::
SelectedRows
&
input2
,
framework
::
SelectedRows
*
output
)
{
auto
in1_height
=
input1
.
height
();
PADDLE_ENFORCE_EQ
(
in1_height
,
input2
.
height
());
output
->
set_height
(
in1_height
);
auto
&
in1_rows
=
input1
.
rows
();
auto
&
in2_rows
=
input2
.
rows
();
std
::
vector
<
int64_t
>
out_rows
;
out_rows
.
reserve
(
in1_rows
.
size
()
+
in2_rows
.
size
());
// concat rows
out_rows
.
insert
(
out_rows
.
end
(),
in1_rows
.
begin
(),
in1_rows
.
end
());
out_rows
.
insert
(
out_rows
.
end
(),
in2_rows
.
begin
(),
in2_rows
.
end
());
output
->
set_rows
(
out_rows
);
auto
*
out_value
=
output
->
mutable_value
();
auto
&
in1_value
=
input1
.
value
();
auto
&
in2_value
=
input2
.
value
();
auto
in1_row_numel
=
in1_value
.
numel
()
/
in1_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
in2_value
.
numel
()
/
in2_rows
.
size
());
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
out_value
->
numel
()
/
out_rows
.
size
());
auto
*
out_data
=
out_value
->
data
<
T
>
();
auto
*
in1_data
=
in1_value
.
data
<
T
>
();
auto
in1_place
=
input1
.
place
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
in1_place
));
auto
in2_place
=
input2
.
place
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
in2_place
));
auto
out_place
=
context
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
out_place
));
memory
::
Copy
(
boost
::
get
<
platform
::
GPUPlace
>
(
out_place
),
out_data
,
boost
::
get
<
platform
::
GPUPlace
>
(
in1_place
),
in1_data
,
in1_value
.
numel
()
*
sizeof
(
T
),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
());
auto
*
in2_data
=
in2_value
.
data
<
T
>
();
memory
::
Copy
(
boost
::
get
<
platform
::
GPUPlace
>
(
out_place
),
out_data
+
in1_value
.
numel
(),
boost
::
get
<
platform
::
GPUPlace
>
(
in2_place
),
in2_data
,
in2_value
.
numel
()
*
sizeof
(
T
),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
());
}
};
template
struct
SelectedRowsAdd
<
platform
::
GPUPlace
,
float
>;
namespace
{
template
<
typename
T
>
__global__
void
SelectedRowsAddTensorKernel
(
const
T
*
selected_rows
,
const
int64_t
*
rows
,
T
*
tensor_out
,
int64_t
row_numel
,
int
block_size
)
{
const
int
ty
=
blockIdx
.
y
;
int
tid
=
threadIdx
.
x
;
selected_rows
+=
ty
*
row_numel
;
tensor_out
+=
rows
[
ty
]
*
row_numel
;
for
(
int
index
=
tid
;
index
<
row_numel
;
index
+=
block_size
)
{
// Since index in rows of SelectedRows can be duplicate, we can not use
// tensor_out[index] += selected_rows[index]; Instead, we have to use
// AtomicAdd to avoid concurrent write error.
paddle
::
platform
::
CudaAtomicAdd
(
tensor_out
+
index
,
selected_rows
[
index
]);
}
}
}
// namespace
template
<
typename
T
>
struct
SelectedRowsAddTensor
<
platform
::
GPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input1
,
const
framework
::
Tensor
&
input2
,
framework
::
Tensor
*
output
)
{
auto
in1_height
=
input1
.
height
();
auto
in2_dims
=
input2
.
dims
();
auto
out_dims
=
output
->
dims
();
PADDLE_ENFORCE_EQ
(
in1_height
,
in2_dims
[
0
]);
PADDLE_ENFORCE_EQ
(
in1_height
,
out_dims
[
0
]);
auto
&
in1_value
=
input1
.
value
();
auto
&
in1_rows
=
input1
.
rows
();
int64_t
in1_row_numel
=
in1_value
.
numel
()
/
in1_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
input2
.
numel
()
/
in1_height
);
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
output
->
numel
()
/
in1_height
);
auto
*
in1_data
=
in1_value
.
data
<
T
>
();
auto
*
in2_data
=
input2
.
data
<
T
>
();
auto
*
out_data
=
output
->
data
<
T
>
();
SetConstant
<
platform
::
GPUPlace
,
T
>
functor
;
functor
(
context
,
output
,
0.0
);
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
in1_rows
.
size
());
SelectedRowsAddTensorKernel
<
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
in1_data
,
in1_rows
.
data
(),
out_data
,
in1_row_numel
,
block_size
);
auto
out_eigen
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
auto
in2_eigen
=
framework
::
EigenVector
<
T
>::
Flatten
(
input2
);
out_eigen
.
device
(
*
context
.
GetEigenDevice
<
platform
::
GPUPlace
>
())
=
out_eigen
+
in2_eigen
;
}
};
template
struct
SelectedRowsAddTensor
<
platform
::
GPUPlace
,
float
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/selected_rows_functor.h
0 → 100644
浏览文件 @
86acf39c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
// SelectedRows + SelectedRows will simplely concat value and rows.
// The real computation happens in dealing with LoDTensor.
template
<
typename
Place
,
typename
T
>
struct
SelectedRowsAdd
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input1
,
const
framework
::
SelectedRows
&
input2
,
framework
::
SelectedRows
*
output
);
};
template
<
typename
Place
,
typename
T
>
struct
SelectedRowsAddTensor
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input1
,
const
framework
::
Tensor
&
input2
,
framework
::
Tensor
*
output
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/selected_rows_functor_test.cc
0 → 100644
浏览文件 @
86acf39c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/selected_rows_functor.h"
#include "gtest/gtest.h"
#include "paddle/operators/math/math_function.h"
TEST
(
selected_rows_functor
,
cpu_add
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
using
namespace
paddle
::
operators
::
math
;
CPUPlace
cpu_place
;
CPUDeviceContext
ctx
(
cpu_place
);
SetConstant
<
CPUPlace
,
float
>
functor
;
int64_t
height
=
10
;
int64_t
row_numel
=
10
;
std
::
vector
<
int64_t
>
rows1
{
0
,
4
,
7
};
std
::
unique_ptr
<
SelectedRows
>
selected_rows1
{
new
SelectedRows
(
rows1
,
height
)};
auto
*
in1_value
=
selected_rows1
->
mutable_value
();
in1_value
->
mutable_data
<
float
>
(
make_ddim
({
static_cast
<
int64_t
>
(
rows1
.
size
()),
row_numel
}),
cpu_place
);
functor
(
ctx
,
in1_value
,
1.0
);
std
::
vector
<
int64_t
>
rows2
{
0
,
5
,
7
,
9
};
std
::
unique_ptr
<
SelectedRows
>
selected_rows2
{
new
SelectedRows
(
rows2
,
height
)};
auto
*
in2_value
=
selected_rows2
->
mutable_value
();
in2_value
->
mutable_data
<
float
>
(
make_ddim
({
static_cast
<
int64_t
>
(
rows2
.
size
()),
row_numel
}),
cpu_place
);
functor
(
ctx
,
in2_value
,
2.0
);
std
::
unique_ptr
<
SelectedRows
>
output
{
new
SelectedRows
()};
auto
*
out_value
=
output
->
mutable_value
();
// simplely concat two SelectedRows
out_value
->
mutable_data
<
float
>
(
make_ddim
({
7
,
10
}),
cpu_place
);
SelectedRowsAdd
<
CPUPlace
,
float
>
add_functor
;
add_functor
(
ctx
,
*
selected_rows1
,
*
selected_rows2
,
output
.
get
());
auto
out_height
=
output
->
height
();
EXPECT_EQ
(
out_height
,
height
);
auto
&
out_rows
=
output
->
rows
();
// input1 rows
EXPECT_EQ
(
out_rows
[
0
],
0
);
EXPECT_EQ
(
out_rows
[
1
],
4
);
EXPECT_EQ
(
out_rows
[
2
],
7
);
// input2 rows
EXPECT_EQ
(
out_rows
[
3
],
0
);
EXPECT_EQ
(
out_rows
[
4
],
5
);
EXPECT_EQ
(
out_rows
[
5
],
7
);
EXPECT_EQ
(
out_rows
[
6
],
9
);
auto
*
out_data
=
output
->
value
().
data
<
float
>
();
// input1 value
EXPECT_EQ
(
out_data
[
0
*
row_numel
+
0
],
1.0
);
EXPECT_EQ
(
out_data
[
0
*
row_numel
+
8
],
1.0
);
EXPECT_EQ
(
out_data
[
1
*
row_numel
+
1
],
1.0
);
EXPECT_EQ
(
out_data
[
2
*
row_numel
+
6
],
1.0
);
// input2 value
EXPECT_EQ
(
out_data
[
3
*
row_numel
+
3
],
2.0
);
EXPECT_EQ
(
out_data
[
3
*
row_numel
+
8
],
2.0
);
EXPECT_EQ
(
out_data
[
4
*
row_numel
+
4
],
2.0
);
EXPECT_EQ
(
out_data
[
5
*
row_numel
+
7
],
2.0
);
EXPECT_EQ
(
out_data
[
6
*
row_numel
+
9
],
2.0
);
std
::
unique_ptr
<
Tensor
>
tensor1
{
new
Tensor
()};
tensor1
->
mutable_data
<
float
>
(
make_ddim
({
height
,
row_numel
}),
cpu_place
);
functor
(
ctx
,
tensor1
.
get
(),
3.0
);
std
::
unique_ptr
<
Tensor
>
tensor2
{
new
Tensor
()};
tensor2
->
mutable_data
<
float
>
(
make_ddim
({
height
,
row_numel
}),
cpu_place
);
SelectedRowsAddTensor
<
CPUPlace
,
float
>
add_tensor_functor
;
add_tensor_functor
(
ctx
,
*
output
,
*
tensor1
,
tensor2
.
get
());
auto
*
tensor2_data
=
tensor2
->
data
<
float
>
();
// row0: 1.0 + 2.0 + 3.0
EXPECT_EQ
(
tensor2_data
[
0
*
row_numel
+
0
],
6.0
);
// row1: 3.0
EXPECT_EQ
(
tensor2_data
[
1
*
row_numel
+
1
],
3.0
);
// row4 : 1.0 + 3.0
EXPECT_EQ
(
tensor2_data
[
4
*
row_numel
+
6
],
4.0
);
// row5: 2.0 + 3.0
EXPECT_EQ
(
tensor2_data
[
5
*
row_numel
+
7
],
5.0
);
// row6: 3.0
EXPECT_EQ
(
tensor2_data
[
6
*
row_numel
+
1
],
3.0
);
// row7: 1.0 + 2.0 + 3.0
EXPECT_EQ
(
tensor2_data
[
7
*
row_numel
+
3
],
6.0
);
// row9: 2.0 + 3.0
EXPECT_EQ
(
tensor2_data
[
9
*
row_numel
+
6
],
5.0
);
}
paddle/operators/math/selected_rows_functor_test.cu
0 → 100644
浏览文件 @
86acf39c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
TEST
(
selected_rows_functor
,
gpu_add
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
using
namespace
paddle
::
operators
::
math
;
GPUPlace
gpu_place
(
0
);
CPUPlace
cpu_place
;
CUDADeviceContext
ctx
(
gpu_place
);
SetConstant
<
GPUPlace
,
float
>
functor
;
int64_t
height
=
10
;
int64_t
row_numel
=
10
;
std
::
vector
<
int64_t
>
rows1
{
0
,
4
,
7
};
std
::
unique_ptr
<
SelectedRows
>
selected_rows1
{
new
SelectedRows
(
rows1
,
height
)};
auto
*
in1_value
=
selected_rows1
->
mutable_value
();
in1_value
->
mutable_data
<
float
>
(
make_ddim
({
static_cast
<
int64_t
>
(
rows1
.
size
()),
row_numel
}),
gpu_place
);
functor
(
ctx
,
in1_value
,
1.0
);
std
::
vector
<
int64_t
>
rows2
{
0
,
5
,
7
,
9
};
std
::
unique_ptr
<
SelectedRows
>
selected_rows2
{
new
SelectedRows
(
rows2
,
height
)};
auto
*
in2_value
=
selected_rows2
->
mutable_value
();
in2_value
->
mutable_data
<
float
>
(
make_ddim
({
static_cast
<
int64_t
>
(
rows2
.
size
()),
row_numel
}),
gpu_place
);
functor
(
ctx
,
in2_value
,
2.0
);
std
::
unique_ptr
<
SelectedRows
>
output
{
new
SelectedRows
()};
auto
*
out_value
=
output
->
mutable_value
();
// simplely concat two SelectedRows
out_value
->
mutable_data
<
float
>
(
make_ddim
({
7
,
10
}),
gpu_place
);
SelectedRowsAdd
<
GPUPlace
,
float
>
add_functor
;
add_functor
(
ctx
,
*
selected_rows1
,
*
selected_rows2
,
output
.
get
());
auto
out_height
=
output
->
height
();
EXPECT_EQ
(
out_height
,
height
);
auto
&
out_rows
=
output
->
rows
();
// input1 rows
EXPECT_EQ
(
out_rows
[
0
],
0
);
EXPECT_EQ
(
out_rows
[
1
],
4
);
EXPECT_EQ
(
out_rows
[
2
],
7
);
// input2 rows
EXPECT_EQ
(
out_rows
[
3
],
0
);
EXPECT_EQ
(
out_rows
[
4
],
5
);
EXPECT_EQ
(
out_rows
[
5
],
7
);
EXPECT_EQ
(
out_rows
[
6
],
9
);
Tensor
out_cpu
;
out_cpu
.
CopyFrom
<
float
>
(
*
out_value
,
cpu_place
,
ctx
);
ctx
.
Wait
();
auto
*
out_cpu_data
=
out_cpu
.
data
<
float
>
();
// input1 value
EXPECT_EQ
(
out_cpu_data
[
0
*
row_numel
+
0
],
1.0
);
EXPECT_EQ
(
out_cpu_data
[
0
*
row_numel
+
8
],
1.0
);
EXPECT_EQ
(
out_cpu_data
[
1
*
row_numel
+
1
],
1.0
);
EXPECT_EQ
(
out_cpu_data
[
2
*
row_numel
+
6
],
1.0
);
// input2 value
EXPECT_EQ
(
out_cpu_data
[
3
*
row_numel
+
3
],
2.0
);
EXPECT_EQ
(
out_cpu_data
[
3
*
row_numel
+
8
],
2.0
);
EXPECT_EQ
(
out_cpu_data
[
4
*
row_numel
+
4
],
2.0
);
EXPECT_EQ
(
out_cpu_data
[
5
*
row_numel
+
7
],
2.0
);
EXPECT_EQ
(
out_cpu_data
[
6
*
row_numel
+
9
],
2.0
);
std
::
unique_ptr
<
Tensor
>
tensor1
{
new
Tensor
()};
tensor1
->
mutable_data
<
float
>
(
make_ddim
({
height
,
row_numel
}),
gpu_place
);
functor
(
ctx
,
tensor1
.
get
(),
3.0
);
std
::
unique_ptr
<
Tensor
>
tensor2
{
new
Tensor
()};
tensor2
->
mutable_data
<
float
>
(
make_ddim
({
height
,
row_numel
}),
gpu_place
);
SelectedRowsAddTensor
<
GPUPlace
,
float
>
add_tensor_functor
;
add_tensor_functor
(
ctx
,
*
output
,
*
tensor1
,
tensor2
.
get
());
Tensor
tensor2_cpu
;
tensor2_cpu
.
CopyFrom
<
float
>
(
*
tensor2
,
cpu_place
,
ctx
);
ctx
.
Wait
();
auto
*
tensor2_cpu_data
=
tensor2_cpu
.
data
<
float
>
();
// row0: 1.0 + 2.0 + 3.0
EXPECT_EQ
(
tensor2_cpu_data
[
0
*
row_numel
+
0
],
6.0
);
// row1: 3.0
EXPECT_EQ
(
tensor2_cpu_data
[
1
*
row_numel
+
1
],
3.0
);
// row4 : 1.0 + 3.0
EXPECT_EQ
(
tensor2_cpu_data
[
4
*
row_numel
+
6
],
4.0
);
// row5: 2.0 + 3.0
EXPECT_EQ
(
tensor2_cpu_data
[
5
*
row_numel
+
7
],
5.0
);
// row6: 3.0
EXPECT_EQ
(
tensor2_cpu_data
[
6
*
row_numel
+
1
],
3.0
);
// row7: 1.0 + 2.0 + 3.0
EXPECT_EQ
(
tensor2_cpu_data
[
7
*
row_numel
+
3
],
6.0
);
// row9: 2.0 + 3.0
EXPECT_EQ
(
tensor2_cpu_data
[
9
*
row_numel
+
6
],
5.0
);
}
paddle/operators/sequence_pool_op.h
浏览文件 @
86acf39c
...
...
@@ -111,7 +111,8 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
in_g
->
mutable_data
<
T
>
(
context
.
GetPlace
());
if
(
strategy
==
LAST
||
strategy
==
FIRST
)
{
// set X@Grad be zero at first when strategy is LAST/FIRST
math
::
SetConstant
<
Place
,
T
>
(
context
.
device_context
(),
in_g
,
0
);
math
::
SetConstant
<
Place
,
T
>
functor
;
functor
(
context
.
device_context
(),
in_g
,
0
);
}
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
i
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录