Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2abcf379
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2abcf379
编写于
5月 03, 2018
作者:
Y
Yu Yang
提交者:
GitHub
5月 03, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10327 from reyoung/feature/clean_blas
Feature/clean blas
上级
54797abd
bc816035
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
398 addition
and
353 deletion
+398
-353
paddle/fluid/operators/bilinear_tensor_product_op.h
paddle/fluid/operators/bilinear_tensor_product_op.h
+11
-12
paddle/fluid/operators/gru_unit_op.h
paddle/fluid/operators/gru_unit_op.h
+23
-29
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+151
-0
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+74
-0
paddle/fluid/operators/math/gru_compute.cc
paddle/fluid/operators/math/gru_compute.cc
+23
-27
paddle/fluid/operators/math/gru_compute.cu
paddle/fluid/operators/math/gru_compute.cu
+25
-26
paddle/fluid/operators/math/math_function.cc
paddle/fluid/operators/math/math_function.cc
+8
-74
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/math_function.cu
+6
-157
paddle/fluid/operators/math/math_function.h
paddle/fluid/operators/math/math_function.h
+45
-8
paddle/fluid/operators/math/math_function_test.cc
paddle/fluid/operators/math/math_function_test.cc
+11
-6
paddle/fluid/operators/math/math_function_test.cu
paddle/fluid/operators/math/math_function_test.cu
+18
-12
paddle/fluid/operators/math/matmul.h
paddle/fluid/operators/math/matmul.h
+3
-2
未找到文件。
paddle/fluid/operators/bilinear_tensor_product_op.h
浏览文件 @
2abcf379
...
...
@@ -61,9 +61,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
auto
output_col_vec
=
output_mat
.
chip
(
i
,
1
);
Tensor
weight_mat
=
weight
->
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
make_ddim
({
x_dim
,
y_dim
}));
math
::
gemm
<
DeviceContext
,
T
>
(
dev_ctx
,
CblasNoTrans
,
CblasNoTrans
,
batch_size
,
y_dim
,
x_dim
,
1
,
x
->
data
<
T
>
(),
weight_mat
.
data
<
T
>
(),
0
,
left_mul
.
data
<
T
>
());
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
).
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
batch_size
,
y_dim
,
x_dim
,
1
,
x
->
data
<
T
>
(),
weight_mat
.
data
<
T
>
(),
0
,
left_mul
.
data
<
T
>
());
output_col_vec
.
device
(
place
)
=
(
left_mul_mat
*
y_mat
).
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
));
}
...
...
@@ -125,6 +125,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
set_zero
(
dev_ctx
,
d_y
,
static_cast
<
T
>
(
0
));
}
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
// Caculate the Output(X@Grad) and Output(Y@Grad).
if
(
d_x
||
d_y
)
{
Eigen
::
DSizes
<
int
,
2
>
bcast_for_x
(
1
,
y_dim
);
...
...
@@ -138,18 +140,16 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
))
.
broadcast
(
bcast_for_x
)
*
y_mat
;
math
::
gemm
<
DeviceContext
,
T
>
(
dev_ctx
,
CblasNoTrans
,
CblasTrans
,
batch_size
,
x_dim
,
y_dim
,
1
,
y_scale
.
data
<
T
>
(),
weight_i
.
data
<
T
>
(),
1
,
d_x
->
data
<
T
>
());
blas
.
GEMM
(
CblasNoTrans
,
CblasTrans
,
batch_size
,
x_dim
,
y_dim
,
1
,
y_scale
.
data
<
T
>
(),
weight_i
.
data
<
T
>
(),
1
,
d_x
->
data
<
T
>
());
}
if
(
d_y
)
{
x_scale_mat
.
device
(
place
)
=
output_vec
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
))
.
broadcast
(
bcast_for_y
)
*
x_mat
;
math
::
gemm
<
DeviceContext
,
T
>
(
dev_ctx
,
CblasNoTrans
,
CblasNoTrans
,
batch_size
,
y_dim
,
x_dim
,
1
,
x_scale
.
data
<
T
>
(),
weight_i
.
data
<
T
>
(),
1
,
d_y
->
data
<
T
>
());
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
batch_size
,
y_dim
,
x_dim
,
1
,
x_scale
.
data
<
T
>
(),
weight_i
.
data
<
T
>
(),
1
,
d_y
->
data
<
T
>
());
}
}
}
...
...
@@ -166,9 +166,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
))
.
broadcast
(
bcast_for_weight
)
*
x_mat
;
math
::
gemm
<
DeviceContext
,
T
>
(
dev_ctx
,
CblasTrans
,
CblasNoTrans
,
x_dim
,
y_dim
,
batch_size
,
1
,
x_scale
.
data
<
T
>
(),
y
->
data
<
T
>
(),
0
,
d_weight_i
.
data
<
T
>
());
blas
.
GEMM
(
CblasTrans
,
CblasNoTrans
,
x_dim
,
y_dim
,
batch_size
,
1
,
x_scale
.
data
<
T
>
(),
y
->
data
<
T
>
(),
0
,
d_weight_i
.
data
<
T
>
());
}
}
...
...
paddle/fluid/operators/gru_unit_op.h
浏览文件 @
2abcf379
...
...
@@ -87,10 +87,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
const
T
*
weight_data
=
weight
->
data
<
T
>
();
T
*
gate_data
=
gate
->
data
<
T
>
();
T
*
reset_hidden_prev_data
=
reset_hidden_prev
->
data
<
T
>
();
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
false
,
false
,
batch_size
,
2
*
frame_size
,
frame_size
,
1
,
hidden_prev_data
,
frame_size
,
weight_data
,
frame_size
*
2
,
1
,
gate_data
,
frame_size
*
3
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
2
*
frame_size
,
frame_size
,
1
,
hidden_prev_data
,
frame_size
,
weight_data
,
frame_size
*
2
,
1
,
gate_data
,
frame_size
*
3
);
// calculate activited gate
Eigen
::
array
<
int
,
2
>
extents
({{
batch_size
,
frame_size
}});
...
...
@@ -103,11 +103,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
g
.
slice
(
r_offsets
,
extents
),
g
.
slice
(
r_offsets
,
extents
));
auto
r
=
g
.
slice
(
r_offsets
,
extents
);
// reset gate
r_h_p
.
device
(
place
)
=
r
*
h_p
;
// reset previous hidden state
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
reset_hidden_prev_data
,
frame_size
,
weight_data
+
frame_size
*
frame_size
*
2
,
frame_size
,
1
,
gate_data
+
frame_size
*
2
,
frame_size
*
3
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
reset_hidden_prev_data
,
frame_size
,
weight_data
+
frame_size
*
frame_size
*
2
,
frame_size
,
1
,
gate_data
+
frame_size
*
2
,
frame_size
*
3
);
Eigen
::
array
<
int
,
2
>
c_offsets
({{
0
,
frame_size
*
2
}});
ActCompute
(
context
.
Attr
<
int
>
(
"activation"
),
place
,
...
...
@@ -188,11 +187,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
ActGradCompute
(
context
.
Attr
<
int
>
(
"activation"
),
place
,
c
,
c
,
d_g
.
slice
(
c_offsets
,
extents
),
d_h
*
u
);
// backward for reset_hidden_prev
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
gate_grad_data
+
frame_size
*
2
,
frame_size
*
3
,
weight_data
+
frame_size
*
frame_size
*
2
,
frame_size
,
0
,
reset_hidden_prev_grad_data
,
frame_size
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
gate_grad_data
+
frame_size
*
2
,
frame_size
*
3
,
weight_data
+
frame_size
*
frame_size
*
2
,
frame_size
,
0
,
reset_hidden_prev_grad_data
,
frame_size
);
// backward for unactivated reset gate
ActGradCompute
(
context
.
Attr
<
int
>
(
"gate_activation"
),
place
,
r
,
r
,
d_g
.
slice
(
r_offsets
,
extents
),
d_r_h_p
*
h_p
);
...
...
@@ -200,18 +199,15 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
if
(
weight_grad
)
{
T
*
weight_grad_data
=
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// backward for state_weight
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
reset_hidden_prev_data
,
frame_size
,
gate_grad_data
+
frame_size
*
2
,
frame_size
*
3
,
0
,
weight_grad_data
+
frame_size
*
frame_size
*
2
,
frame_size
);
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
reset_hidden_prev_data
,
frame_size
,
gate_grad_data
+
frame_size
*
2
,
frame_size
*
3
,
0
,
weight_grad_data
+
frame_size
*
frame_size
*
2
,
frame_size
);
// backward for update_gate_weight and reset_gate_weight
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
hidden_prev_data
,
frame_size
,
gate_grad_data
,
frame_size
*
3
,
0
,
weight_grad_data
,
frame_size
*
2
);
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
hidden_prev_data
,
frame_size
,
gate_grad_data
,
frame_size
*
3
,
0
,
weight_grad_data
,
frame_size
*
2
);
}
// backward for hidden_prev
if
(
hidden_prev_grad
)
{
...
...
@@ -219,11 +215,9 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
hidden_prev_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
d_h_p
=
EigenMatrix
<
T
>::
From
(
*
hidden_prev_grad
);
d_h_p
.
device
(
place
)
=
d_r_h_p
*
r
+
d_h
*
(
u
.
constant
(
T
(
1
))
-
u
);
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
gate_grad_data
,
frame_size
*
3
,
weight_data
,
frame_size
*
2
,
1
,
hidden_prev_grad_data
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
gate_grad_data
,
frame_size
*
3
,
weight_data
,
frame_size
*
2
,
1
,
hidden_prev_grad_data
,
frame_size
);
}
// backward for input
if
(
input_grad
)
{
...
...
paddle/fluid/operators/math/blas_impl.cu.h
0 → 100644
浏览文件 @
2abcf379
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cublas.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
struct
CUBlas
;
template
<
>
struct
CUBlas
<
float
>
{
template
<
typename
...
ARGS
>
static
void
GEMM
(
ARGS
...
args
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
args
...));
}
};
template
<
>
struct
CUBlas
<
double
>
{
template
<
typename
...
ARGS
>
static
void
GEMM
(
ARGS
...
args
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
args
...));
}
};
template
<
>
struct
CUBlas
<
platform
::
float16
>
{
using
float16
=
platform
::
float16
;
static
void
GEMM
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
float16
*
alpha
,
const
float16
*
A
,
int
lda
,
const
float16
*
B
,
int
ldb
,
const
float16
*
beta
,
float16
*
C
,
int
ldc
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
reinterpret_cast
<
const
__half
*>
(
alpha
),
reinterpret_cast
<
const
__half
*>
(
A
),
lda
,
reinterpret_cast
<
const
__half
*>
(
B
),
ldb
,
reinterpret_cast
<
const
__half
*>
(
beta
),
reinterpret_cast
<
__half
*>
(
C
),
ldc
));
}
};
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
);
}
template
<
>
template
<
>
inline
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
platform
::
float16
alpha
,
const
platform
::
float16
*
A
,
const
platform
::
float16
*
B
,
const
platform
::
float16
beta
,
platform
::
float16
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context_
.
GetComputeCapability
(),
53
,
"cublas fp16 gemm requires GPU compute capability >= 53"
);
#if CUDA_VERSION >= 8000
float
h_alpha
=
static_cast
<
float
>
(
alpha
);
float
h_beta
=
static_cast
<
float
>
(
beta
);
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
if
(
context_
.
GetComputeCapability
()
>=
70
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
context_
.
cublas_handle
(),
CUBLAS_TENSOR_OP_MATH
));
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
else
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
context_
.
cublas_handle
(),
CUBLAS_DEFAULT_MATH
));
}
#endif // CUDA_VERSION >= 9000
// cublasHgemm does true FP16 computation which is slow for non-Volta
// GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
B
,
CUDA_R_16F
,
ldb
,
A
,
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
,
algo
));
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
);
#endif // CUDA_VERSION >= 8000
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/blas_impl.h
0 → 100644
浏览文件 @
2abcf379
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
struct
CBlas
;
template
<
>
struct
CBlas
<
float
>
{
template
<
typename
...
ARGS
>
static
void
GEMM
(
ARGS
...
args
)
{
cblas_sgemm
(
args
...);
}
};
template
<
>
struct
CBlas
<
double
>
{
template
<
typename
...
ARGS
>
static
void
GEMM
(
ARGS
...
args
)
{
cblas_dgemm
(
args
...);
}
};
template
<
>
struct
CBlas
<
platform
::
float16
>
{
static
void
GEMM
(...)
{
PADDLE_THROW
(
"float16 GEMM not supported on CPU"
);
}
};
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
)
const
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
const
{
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/gru_compute.cc
浏览文件 @
2abcf379
...
...
@@ -25,21 +25,21 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_gate
)
{
#ifndef __NVCC__
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
1
,
value
.
prev_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_size
*
3
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
1
,
value
.
prev_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_size
*
3
);
}
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
frame_size
,
batch_size
,
active_gate
);
if
(
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
value
.
reset_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
value
.
reset_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
}
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
...
...
@@ -58,36 +58,32 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
#ifndef __NVCC__
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
active_node
);
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
&&
grad
.
prev_out_grad
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_weight
,
frame_size
,
0
,
grad
.
reset_output_grad
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_weight
,
frame_size
,
0
,
grad
.
reset_output_grad
,
frame_size
);
if
(
grad
.
state_weight_grad
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
reset_output_value
,
frame_size
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
reset_output_value
,
frame_size
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
active_gate
);
if
(
grad
.
prev_out_grad
&&
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
grad
.
gate_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
grad
.
prev_out_grad
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
grad
.
gate_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
grad
.
prev_out_grad
,
frame_size
);
if
(
grad
.
gate_weight_grad
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
value
.
prev_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_size
*
3
,
1
,
grad
.
gate_weight_grad
,
frame_size
*
2
);
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
value
.
prev_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_size
*
3
,
1
,
grad
.
gate_weight_grad
,
frame_size
*
2
);
}
}
#endif
...
...
paddle/fluid/operators/math/gru_compute.cu
浏览文件 @
2abcf379
...
...
@@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/platform/device_context.h>
#include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/gru_compute.h"
...
...
@@ -36,12 +37,11 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame_size
+
32
-
1
)
/
32
,
(
batch_size
+
32
-
1
)
/
32
);
}
auto
blas
=
math
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
1
,
value
.
prev_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_size
*
3
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
1
,
value
.
prev_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_size
*
3
);
}
if
(
batch_size
==
1
)
{
...
...
@@ -61,10 +61,10 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
}
if
(
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
value
.
reset_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
value
.
reset_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
}
if
(
batch_size
==
1
)
{
...
...
@@ -121,18 +121,19 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
grad
.
output_grad
,
frame_size
,
batch_size
,
active_node
);
}
auto
blas
=
math
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
&&
grad
.
prev_out_grad
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_weight
,
frame_size
,
0
,
grad
.
reset_output_grad
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_weight
,
frame_size
,
0
,
grad
.
reset_output_grad
,
frame_size
);
if
(
grad
.
state_weight_grad
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
reset_output_value
,
frame_size
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
reset_output_value
,
frame_size
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
...
...
@@ -153,16 +154,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
}
if
(
grad
.
prev_out_grad
&&
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
grad
.
gate_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
grad
.
prev_out_grad
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
grad
.
gate_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
grad
.
prev_out_grad
,
frame_size
);
if
(
grad
.
gate_weight_grad
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
value
.
prev_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_size
*
3
,
1
,
grad
.
gate_weight_grad
,
frame_size
*
2
);
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
value
.
prev_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_size
*
3
,
1
,
grad
.
gate_weight_grad
,
frame_size
*
2
);
}
}
}
...
...
paddle/fluid/operators/math/math_function.cc
浏览文件 @
2abcf379
...
...
@@ -24,72 +24,6 @@ namespace math {
using
float16
=
paddle
::
platform
::
float16
;
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
float16
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
float16
*
B
,
const
float16
beta
,
float16
*
C
)
{
PADDLE_THROW
(
"float16 GEMM not supported on CPU"
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
float
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
cblas_sgemm
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
double
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
)
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
cblas_dgemm
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
float16
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
int
lda
,
const
float16
*
B
,
const
int
ldb
,
const
float16
beta
,
float16
*
C
,
const
int
ldc
)
{
PADDLE_THROW
(
"float16 GEMM not supported on CPU"
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
float
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
)
{
cblas_sgemm
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
double
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
)
{
cblas_dgemm
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
matmul
<
platform
::
CPUDeviceContext
,
float16
>
(
const
platform
::
CPUDeviceContext
&
context
,
...
...
@@ -123,8 +57,8 @@ void matmul<platform::CPUDeviceContext, float>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CPUDeviceContext
,
float
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
());
}
...
...
@@ -152,8 +86,8 @@ void matmul<platform::CPUDeviceContext, double>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CPUDeviceContext
,
double
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
...
...
@@ -230,8 +164,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
const
float
*
Ak
=
&
A
[
k
*
strideA
];
const
float
*
Bk
=
&
B
[
k
*
strideB
];
float
*
Ck
=
&
C
[
k
*
M
*
N
];
gemm
<
platform
::
CPUDeviceContext
,
float
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
Ak
,
Bk
,
beta
,
Ck
);
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
Ak
,
Bk
,
beta
,
Ck
);
}
}
...
...
@@ -246,8 +180,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
const
double
*
Ak
=
&
A
[
k
*
strideA
];
const
double
*
Bk
=
&
B
[
k
*
strideB
];
double
*
Ck
=
&
C
[
k
*
M
*
N
];
gemm
<
platform
::
CPUDeviceContext
,
double
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
Ak
,
Bk
,
beta
,
Ck
);
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
Ak
,
Bk
,
beta
,
Ck
);
}
}
#endif
...
...
paddle/fluid/operators/math/math_function.cu
浏览文件 @
2abcf379
...
...
@@ -25,157 +25,6 @@ namespace math {
using
float16
=
paddle
::
platform
::
float16
;
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
float16
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
float16
*
B
,
const
float16
beta
,
float16
*
C
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context
.
GetComputeCapability
(),
53
,
"cublas fp16 gemm requires GPU compute capability >= 53"
);
#if CUDA_VERSION >= 8000
float
h_alpha
=
static_cast
<
float
>
(
alpha
);
float
h_beta
=
static_cast
<
float
>
(
beta
);
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
if
(
context
.
GetComputeCapability
()
>=
70
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
context
.
cublas_handle
(),
CUBLAS_TENSOR_OP_MATH
));
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
else
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
context
.
cublas_handle
(),
CUBLAS_DEFAULT_MATH
));
}
#endif // CUDA_VERSION >= 9000
// cublasHgemm does true FP16 computation which is slow for non-Volta
// GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
B
,
CUDA_R_16F
,
ldb
,
A
,
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
,
algo
));
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
const
half
h_alpha
=
static_cast
<
const
half
>
(
alpha
);
const
half
h_beta
=
static_cast
<
const
half
>
(
beta
);
const
half
*
h_A
=
reinterpret_cast
<
const
half
*>
(
A
);
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
));
#endif // CUDA_VERSION >= 8000
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
float
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
double
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
float16
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
int
lda
,
const
float16
*
B
,
const
int
ldb
,
const
float16
beta
,
float16
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
const
half
h_alpha
=
static_cast
<
const
half
>
(
alpha
);
const
half
h_beta
=
static_cast
<
const
half
>
(
beta
);
const
half
*
h_A
=
reinterpret_cast
<
const
half
*>
(
A
);
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context
.
GetComputeCapability
(),
53
,
"cublas Hgemm requires GPU compute capability >= 53"
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
ldc
));
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
float
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
));
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
double
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
));
}
template
<
>
void
matmul
<
platform
::
CUDADeviceContext
,
float16
>
(
const
platform
::
CUDADeviceContext
&
context
,
...
...
@@ -200,8 +49,8 @@ void matmul<platform::CUDADeviceContext, float16>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CUDADeviceContext
,
float16
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float16
>
(),
Blas
<
platform
::
CUDADeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float16
>
(),
matrix_b
.
data
<
float16
>
(),
beta
,
matrix_out
->
data
<
float16
>
());
}
...
...
@@ -229,8 +78,8 @@ void matmul<platform::CUDADeviceContext, float>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CUDADeviceContext
,
float
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
Blas
<
platform
::
CUDADeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
());
}
...
...
@@ -258,8 +107,8 @@ void matmul<platform::CUDADeviceContext, double>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CUDADeviceContext
,
double
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
Blas
<
platform
::
CUDADeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
...
...
paddle/fluid/operators/math/math_function.h
浏览文件 @
2abcf379
...
...
@@ -42,6 +42,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -56,17 +57,48 @@ namespace math {
// Then matrixA: M * K, matrixB: K * N, matrixC : M * N
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template
<
typename
DeviceContext
>
class
Blas
{
public:
explicit
Blas
(
const
DeviceContext
&
context
)
:
context_
(
context
)
{}
template
<
typename
T
>
void
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
)
const
;
template
<
typename
T
>
void
GEMM
(
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
const
;
private:
const
DeviceContext
&
context_
;
};
template
<
typename
DeviceContext
,
typename
T
>
void
gemm
(
const
DeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
);
class
BlasT
:
private
Blas
<
DeviceContext
>
{
public:
using
Blas
<
DeviceContext
>::
Blas
;
template
<
typename
...
ARGS
>
void
GEMM
(
ARGS
...
args
)
const
{
static_cast
<
const
Blas
<
DeviceContext
>*>
(
this
)
->
template
GEMM
<
T
>(
args
...);
}
};
// gemm wrapper with stride args for matrix uncontinuous in memory
template
<
typename
DeviceContext
,
typename
T
>
void
gemm
(
const
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
);
inline
BlasT
<
DeviceContext
,
T
>
GetBlas
(
const
framework
::
ExecutionContext
&
exe_ctx
)
{
return
BlasT
<
DeviceContext
,
T
>
(
exe_ctx
.
template
device_context
<
DeviceContext
>());
}
template
<
typename
DeviceContext
,
typename
T
>
inline
BlasT
<
DeviceContext
,
T
>
GetBlas
(
const
DeviceContext
&
dev_ctx
)
{
return
BlasT
<
DeviceContext
,
T
>
(
dev_ctx
);
}
// matrix multiply with continuous memory
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -137,3 +169,8 @@ struct RowwiseMean {
}
// namespace math
}
// namespace operators
}
// namespace paddle
#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif
paddle/fluid/operators/math/math_function_test.cc
浏览文件 @
2abcf379
...
...
@@ -14,6 +14,13 @@
#include "paddle/fluid/operators/math/math_function.h"
#include "gtest/gtest.h"
template
<
typename
T
>
inline
paddle
::
operators
::
math
::
BlasT
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
GetBlas
(
const
paddle
::
platform
::
CPUDeviceContext
&
context
)
{
return
paddle
::
operators
::
math
::
GetBlas
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
(
context
);
}
TEST
(
math_function
,
gemm_notrans_cblas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
...
...
@@ -34,9 +41,8 @@ TEST(math_function, gemm_notrans_cblas) {
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
paddle
::
platform
::
CPUDeviceContext
context
(
*
cpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
(
context
,
false
,
false
,
m
,
n
,
k
,
1
,
input1_ptr
,
3
,
input2_ptr
+
1
,
4
,
1
,
input3_ptr
+
1
,
4
);
GetBlas
<
float
>
(
context
).
GEMM
(
false
,
false
,
m
,
n
,
k
,
1
,
input1_ptr
,
3
,
input2_ptr
+
1
,
4
,
1
,
input3_ptr
+
1
,
4
);
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
...
...
@@ -68,9 +74,8 @@ TEST(math_function, gemm_trans_clbas) {
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
paddle
::
platform
::
CPUDeviceContext
context
(
*
cpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
(
context
,
false
,
true
,
m
,
n
,
k
,
1
,
input1_ptr
,
3
,
input2_ptr
+
3
,
3
,
1
,
input3_ptr
+
1
,
4
);
GetBlas
<
float
>
(
context
).
GEMM
(
false
,
true
,
m
,
n
,
k
,
1
,
input1_ptr
,
3
,
input2_ptr
+
3
,
3
,
1
,
input3_ptr
+
1
,
4
);
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
...
...
paddle/fluid/operators/math/math_function_test.cu
浏览文件 @
2abcf379
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
void
fill_fp16_data
(
paddle
::
platform
::
float16
*
in_ptr
,
size_t
size
,
const
std
::
vector
<
float
>&
data
)
{
...
...
@@ -178,6 +179,13 @@ TEST(math_function, trans_mul_notrans_fp16) {
EXPECT_EQ
(
static_cast
<
float
>
(
out_ptr
[
8
]),
29
);
}
template
<
typename
T
>
inline
paddle
::
operators
::
math
::
BlasT
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
GetBlas
(
const
paddle
::
platform
::
CUDADeviceContext
&
context
)
{
return
paddle
::
operators
::
math
::
GetBlas
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
(
context
);
}
TEST
(
math_function
,
gemm_notrans_cublas_fp32
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
...
...
@@ -210,8 +218,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
(
context
,
false
,
false
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
1
,
4
,
1
,
c
+
1
,
4
);
GetBlas
<
float
>
(
context
).
GEMM
(
false
,
false
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
1
,
4
,
1
,
c
+
1
,
4
);
paddle
::
framework
::
TensorCopySync
(
input3_gpu
,
cpu_place
,
&
input3
);
...
...
@@ -271,10 +279,9 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
paddle
::
platform
::
float16
*
c
=
input3_gpu
.
mutable_data
<
paddle
::
platform
::
float16
>
(
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
(
context
,
false
,
false
,
m
,
n
,
k
,
paddle
::
platform
::
float16
(
1
),
a
,
3
,
b
+
1
,
4
,
paddle
::
platform
::
float16
(
1
),
c
+
1
,
4
);
GetBlas
<
paddle
::
platform
::
float16
>
(
context
).
GEMM
(
false
,
false
,
m
,
n
,
k
,
static_cast
<
paddle
::
platform
::
float16
>
(
1
),
a
,
3
,
b
+
1
,
4
,
static_cast
<
paddle
::
platform
::
float16
>
(
1
),
c
+
1
,
4
);
paddle
::
framework
::
TensorCopySync
(
input3_gpu
,
cpu_place
,
&
input3
);
...
...
@@ -327,8 +334,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
(
context
,
false
,
true
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
3
,
3
,
1
,
c
+
1
,
4
);
GetBlas
<
float
>
(
context
).
GEMM
(
false
,
true
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
3
,
3
,
1
,
c
+
1
,
4
);
paddle
::
framework
::
TensorCopySync
(
input3_gpu
,
cpu_place
,
&
input3
);
...
...
@@ -382,10 +389,9 @@ TEST(math_function, gemm_trans_cublas_fp16) {
paddle
::
platform
::
float16
*
c
=
input3_gpu
.
mutable_data
<
paddle
::
platform
::
float16
>
(
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
(
context
,
false
,
true
,
m
,
n
,
k
,
paddle
::
platform
::
float16
(
1
),
a
,
3
,
b
+
3
,
3
,
paddle
::
platform
::
float16
(
1
),
c
+
1
,
4
);
GetBlas
<
paddle
::
platform
::
float16
>
(
context
).
GEMM
(
false
,
true
,
m
,
n
,
k
,
static_cast
<
paddle
::
platform
::
float16
>
(
1
),
a
,
3
,
b
+
3
,
3
,
static_cast
<
paddle
::
platform
::
float16
>
(
1
),
c
+
1
,
4
);
paddle
::
framework
::
TensorCopySync
(
input3_gpu
,
cpu_place
,
&
input3
);
...
...
paddle/fluid/operators/math/matmul.h
浏览文件 @
2abcf379
...
...
@@ -131,8 +131,9 @@ class MatMulFunctor {
if
(
!
batchCount
)
{
// regular matrix multiplication
gemm
<
DeviceContext
,
T
>
(
context
,
transA
,
transB
,
M
,
N
,
kA
,
alpha
,
a
.
data
<
T
>
(),
b
.
data
<
T
>
(),
beta
,
out
->
data
<
T
>
());
Blas
<
DeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
kA
,
alpha
,
a
.
data
<
T
>
(),
b
.
data
<
T
>
(),
beta
,
out
->
data
<
T
>
());
}
else
{
// batched matrix multiplication
batched_gemm
<
DeviceContext
,
T
>
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录