Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9ffa79cd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9ffa79cd
编写于
9月 20, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add gemm with stride
上级
d865b047
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
239 addition
and
0 deletion
+239
-0
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+26
-0
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+36
-0
paddle/operators/math/math_function.h
paddle/operators/math/math_function.h
+7
-0
paddle/operators/math/math_function_test.cc
paddle/operators/math/math_function_test.cc
+170
-0
未找到文件。
paddle/operators/math/math_function.cc
浏览文件 @
9ffa79cd
...
...
@@ -48,6 +48,32 @@ void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
)
{
cblas_sgemm
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
)
{
cblas_dgemm
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
matmul
<
platform
::
CPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
...
...
paddle/operators/math/math_function.cu
浏览文件 @
9ffa79cd
...
...
@@ -63,6 +63,42 @@ void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
gemm
<
platform
::
GPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
));
}
template
<
>
void
gemm
<
platform
::
GPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
));
}
template
<
>
void
matmul
<
platform
::
GPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
...
...
paddle/operators/math/math_function.h
浏览文件 @
9ffa79cd
...
...
@@ -70,6 +70,13 @@ void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
);
// gemm wrapper with stride args for matrix uncontinuous in memory
template
<
typename
Place
,
typename
T
>
void
gemm
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
);
// matrix multiply with continuous memory
template
<
typename
Place
,
typename
T
>
void
matmul
(
const
platform
::
DeviceContext
&
context
,
...
...
paddle/operators/math/math_function_test.cc
浏览文件 @
9ffa79cd
...
...
@@ -72,4 +72,174 @@ TEST(math_function, trans_mul_notrans) {
EXPECT_EQ
(
out_ptr
[
8
],
29
);
delete
gpu_place
;
}
TEST
(
math_function
,
gemm_notrans_cublas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
input3_gpu
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
3
,
4
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
input2_gpu
.
CopyFrom
<
float
>
(
input2
,
*
gpu_place
);
input3_gpu
.
CopyFrom
<
float
>
(
input3
,
*
gpu_place
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
false
,
false
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
1
,
4
,
1
,
c
+
1
,
4
);
input3
.
CopyFrom
<
float
>
(
input3_gpu
,
*
cpu_place
);
// numpy code:
// a = np.arange(6).reshape(2, 3)
// b = np.arange(12).reshape(3, 4)[:, 1:]
// c = np.arange(8).reshape(2, 4)[:, 1:]
// out = np.arange(8).reshape(2, 4)
// out[:, 1:] = np.dot(a, b) + c
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
delete
gpu_place
;
}
TEST
(
math_function
,
gemm_trans_cublas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
input3_gpu
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
4
,
3
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
4
,
8
,
1
,
5
,
9
,
2
,
6
,
10
,
3
,
7
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
input2_gpu
.
CopyFrom
<
float
>
(
input2
,
*
gpu_place
);
input3_gpu
.
CopyFrom
<
float
>
(
input3
,
*
gpu_place
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
false
,
true
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
3
,
3
,
1
,
c
+
1
,
4
);
input3
.
CopyFrom
<
float
>
(
input3_gpu
,
*
cpu_place
);
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
delete
gpu_place
;
}
#endif
TEST
(
math_function
,
gemm_notrans_cblas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
3
,
4
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
paddle
::
platform
::
CPUDeviceContext
context
(
*
cpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CPUPlace
,
float
>
(
context
,
false
,
false
,
m
,
n
,
k
,
1
,
input1_ptr
,
3
,
input2_ptr
+
1
,
4
,
1
,
input3_ptr
+
1
,
4
);
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
}
TEST
(
math_function
,
gemm_trans_clbas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
4
,
3
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
4
,
8
,
1
,
5
,
9
,
2
,
6
,
10
,
3
,
7
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
paddle
::
platform
::
CPUDeviceContext
context
(
*
cpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CPUPlace
,
float
>
(
context
,
false
,
true
,
m
,
n
,
k
,
1
,
input1_ptr
,
3
,
input2_ptr
+
3
,
3
,
1
,
input3_ptr
+
1
,
4
);
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录