Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c888e016
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c888e016
编写于
4月 28, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor GEMM in blas
上级
c93a624b
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
357 addition
and
335 deletion
+357
-335
paddle/fluid/operators/bilinear_tensor_product_op.h
paddle/fluid/operators/bilinear_tensor_product_op.h
+11
-12
paddle/fluid/operators/gru_unit_op.h
paddle/fluid/operators/gru_unit_op.h
+23
-29
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+145
-0
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+68
-0
paddle/fluid/operators/math/gru_compute.cc
paddle/fluid/operators/math/gru_compute.cc
+23
-27
paddle/fluid/operators/math/gru_compute.cu
paddle/fluid/operators/math/gru_compute.cu
+25
-26
paddle/fluid/operators/math/math_function.cc
paddle/fluid/operators/math/math_function.cc
+8
-74
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/math_function.cu
+6
-157
paddle/fluid/operators/math/math_function.h
paddle/fluid/operators/math/math_function.h
+45
-8
paddle/fluid/operators/math/matmul.h
paddle/fluid/operators/math/matmul.h
+3
-2
未找到文件。
paddle/fluid/operators/bilinear_tensor_product_op.h
浏览文件 @
c888e016
...
...
@@ -61,8 +61,8 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
auto
output_col_vec
=
output_mat
.
chip
(
i
,
1
);
Tensor
weight_mat
=
weight
->
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
make_ddim
({
x_dim
,
y_dim
}));
math
::
gemm
<
DeviceContext
,
T
>
(
dev_ctx
,
CblasNoTrans
,
CblasNoTrans
,
batch_size
,
y_dim
,
x_dim
,
1
,
x
->
data
<
T
>
(),
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
).
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
batch_size
,
y_dim
,
x_dim
,
1
,
x
->
data
<
T
>
(),
weight_mat
.
data
<
T
>
(),
0
,
left_mul
.
data
<
T
>
());
output_col_vec
.
device
(
place
)
=
(
left_mul_mat
*
y_mat
).
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
));
...
...
@@ -125,6 +125,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
set_zero
(
dev_ctx
,
d_y
,
static_cast
<
T
>
(
0
));
}
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
// Caculate the Output(X@Grad) and Output(Y@Grad).
if
(
d_x
||
d_y
)
{
Eigen
::
DSizes
<
int
,
2
>
bcast_for_x
(
1
,
y_dim
);
...
...
@@ -138,8 +140,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
))
.
broadcast
(
bcast_for_x
)
*
y_mat
;
math
::
gemm
<
DeviceContext
,
T
>
(
dev_ctx
,
CblasNoTrans
,
CblasTrans
,
batch_size
,
x_dim
,
y_dim
,
1
,
blas
.
GEMM
(
CblasNoTrans
,
CblasTrans
,
batch_size
,
x_dim
,
y_dim
,
1
,
y_scale
.
data
<
T
>
(),
weight_i
.
data
<
T
>
(),
1
,
d_x
->
data
<
T
>
());
}
if
(
d_y
)
{
...
...
@@ -147,8 +148,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
))
.
broadcast
(
bcast_for_y
)
*
x_mat
;
math
::
gemm
<
DeviceContext
,
T
>
(
dev_ctx
,
CblasNoTrans
,
CblasNoTrans
,
batch_size
,
y_dim
,
x_dim
,
1
,
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
batch_size
,
y_dim
,
x_dim
,
1
,
x_scale
.
data
<
T
>
(),
weight_i
.
data
<
T
>
(),
1
,
d_y
->
data
<
T
>
());
}
}
...
...
@@ -166,9 +166,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
))
.
broadcast
(
bcast_for_weight
)
*
x_mat
;
math
::
gemm
<
DeviceContext
,
T
>
(
dev_ctx
,
CblasTrans
,
CblasNoTrans
,
x_dim
,
y_dim
,
batch_size
,
1
,
x_scale
.
data
<
T
>
(),
y
->
data
<
T
>
(),
0
,
d_weight_i
.
data
<
T
>
());
blas
.
GEMM
(
CblasTrans
,
CblasNoTrans
,
x_dim
,
y_dim
,
batch_size
,
1
,
x_scale
.
data
<
T
>
(),
y
->
data
<
T
>
(),
0
,
d_weight_i
.
data
<
T
>
());
}
}
...
...
paddle/fluid/operators/gru_unit_op.h
浏览文件 @
c888e016
...
...
@@ -87,10 +87,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
const
T
*
weight_data
=
weight
->
data
<
T
>
();
T
*
gate_data
=
gate
->
data
<
T
>
();
T
*
reset_hidden_prev_data
=
reset_hidden_prev
->
data
<
T
>
();
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
false
,
false
,
batch_size
,
2
*
frame_size
,
frame_size
,
1
,
hidden_prev_data
,
frame_size
,
weight_data
,
frame_size
*
2
,
1
,
gate_data
,
frame_size
*
3
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
2
*
frame_size
,
frame_size
,
1
,
hidden_prev_data
,
frame_size
,
weight_data
,
frame_size
*
2
,
1
,
gate_data
,
frame_size
*
3
);
// calculate activited gate
Eigen
::
array
<
int
,
2
>
extents
({{
batch_size
,
frame_size
}});
...
...
@@ -103,10 +103,9 @@ class GRUUnitKernel : public framework::OpKernel<T> {
g
.
slice
(
r_offsets
,
extents
),
g
.
slice
(
r_offsets
,
extents
));
auto
r
=
g
.
slice
(
r_offsets
,
extents
);
// reset gate
r_h_p
.
device
(
place
)
=
r
*
h_p
;
// reset previous hidden state
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
reset_hidden_prev_data
,
frame_size
,
weight_data
+
frame_size
*
frame_size
*
2
,
frame_size
,
1
,
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
reset_hidden_prev_data
,
frame_size
,
weight_data
+
frame_size
*
frame_size
*
2
,
frame_size
,
1
,
gate_data
+
frame_size
*
2
,
frame_size
*
3
);
Eigen
::
array
<
int
,
2
>
c_offsets
({{
0
,
frame_size
*
2
}});
...
...
@@ -188,11 +187,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
ActGradCompute
(
context
.
Attr
<
int
>
(
"activation"
),
place
,
c
,
c
,
d_g
.
slice
(
c_offsets
,
extents
),
d_h
*
u
);
// backward for reset_hidden_prev
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
gate_grad_data
+
frame_size
*
2
,
frame_size
*
3
,
weight_data
+
frame_size
*
frame_size
*
2
,
frame_size
,
0
,
reset_hidden_prev_grad_data
,
frame_size
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
gate_grad_data
+
frame_size
*
2
,
frame_size
*
3
,
weight_data
+
frame_size
*
frame_size
*
2
,
frame_size
,
0
,
reset_hidden_prev_grad_data
,
frame_size
);
// backward for unactivated reset gate
ActGradCompute
(
context
.
Attr
<
int
>
(
"gate_activation"
),
place
,
r
,
r
,
d_g
.
slice
(
r_offsets
,
extents
),
d_r_h_p
*
h_p
);
...
...
@@ -200,18 +199,15 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
if
(
weight_grad
)
{
T
*
weight_grad_data
=
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// backward for state_weight
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
reset_hidden_prev_data
,
frame_size
,
gate_grad_data
+
frame_size
*
2
,
frame_size
*
3
,
0
,
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
reset_hidden_prev_data
,
frame_size
,
gate_grad_data
+
frame_size
*
2
,
frame_size
*
3
,
0
,
weight_grad_data
+
frame_size
*
frame_size
*
2
,
frame_size
);
// backward for update_gate_weight and reset_gate_weight
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
hidden_prev_data
,
frame_size
,
gate_grad_data
,
frame_size
*
3
,
0
,
weight_grad_data
,
frame_size
*
2
);
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
hidden_prev_data
,
frame_size
,
gate_grad_data
,
frame_size
*
3
,
0
,
weight_grad_data
,
frame_size
*
2
);
}
// backward for hidden_prev
if
(
hidden_prev_grad
)
{
...
...
@@ -219,11 +215,9 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
hidden_prev_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
d_h_p
=
EigenMatrix
<
T
>::
From
(
*
hidden_prev_grad
);
d_h_p
.
device
(
place
)
=
d_r_h_p
*
r
+
d_h
*
(
u
.
constant
(
T
(
1
))
-
u
);
math
::
gemm
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
gate_grad_data
,
frame_size
*
3
,
weight_data
,
frame_size
*
2
,
1
,
hidden_prev_grad_data
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
gate_grad_data
,
frame_size
*
3
,
weight_data
,
frame_size
*
2
,
1
,
hidden_prev_grad_data
,
frame_size
);
}
// backward for input
if
(
input_grad
)
{
...
...
paddle/fluid/operators/math/blas_impl.cu.h
0 → 100644
浏览文件 @
c888e016
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cublas.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
struct
CUBlas
;
template
<
>
struct
CUBlas
<
float
>
{
template
<
typename
...
ARGS
>
static
void
GEMM
(
ARGS
...
args
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
args
...));
}
};
template
<
>
struct
CUBlas
<
double
>
{
template
<
typename
...
ARGS
>
static
void
GEMM
(
ARGS
...
args
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
args
...));
}
};
template
<
>
struct
CUBlas
<
platform
::
float16
>
{
template
<
typename
...
ARGS
>
static
void
GEMM
(
ARGS
...
args
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
args
...));
}
};
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
);
}
template
<
>
template
<
>
inline
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
platform
::
float16
alpha
,
const
platform
::
float16
*
A
,
const
platform
::
float16
*
B
,
const
platform
::
float16
beta
,
platform
::
float16
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context_
.
GetComputeCapability
(),
53
,
"cublas fp16 gemm requires GPU compute capability >= 53"
);
#if CUDA_VERSION >= 8000
float
h_alpha
=
static_cast
<
float
>
(
alpha
);
float
h_beta
=
static_cast
<
float
>
(
beta
);
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
if
(
context_
.
GetComputeCapability
()
>=
70
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
context_
.
cublas_handle
(),
CUBLAS_TENSOR_OP_MATH
));
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
else
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
context_
.
cublas_handle
(),
CUBLAS_DEFAULT_MATH
));
}
#endif // CUDA_VERSION >= 9000
// cublasHgemm does true FP16 computation which is slow for non-Volta
// GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
B
,
CUDA_R_16F
,
ldb
,
A
,
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
,
algo
));
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
const
half
h_alpha
=
static_cast
<
const
half
>
(
alpha
);
const
half
h_beta
=
static_cast
<
const
half
>
(
beta
);
const
half
*
h_A
=
reinterpret_cast
<
const
half
*>
(
A
);
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
CUBlas
<
platform
::
float16
>
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
);
#endif // CUDA_VERSION >= 8000
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/blas_impl.h
0 → 100644
浏览文件 @
c888e016
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
struct
CBlas
;
template
<
>
struct
CBlas
<
float
>
{
static
constexpr
auto
GEMM
=
cblas_sgemm
;
};
template
<
>
struct
CBlas
<
double
>
{
static
constexpr
auto
GEMM
=
cblas_dgemm
;
};
template
<
>
struct
CBlas
<
platform
::
float16
>
{
void
GEMM
(...)
{
PADDLE_THROW
(
"float16 GEMM not supported on CPU"
);
}
};
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
)
const
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
const
{
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/gru_compute.cc
浏览文件 @
c888e016
...
...
@@ -25,21 +25,21 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
const
detail
::
ActivationType
active_node
,
const
detail
::
ActivationType
active_gate
)
{
#ifndef __NVCC__
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
1
,
value
.
prev_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_size
*
3
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
1
,
value
.
prev_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_size
*
3
);
}
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
frame_size
,
batch_size
,
active_gate
);
if
(
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
value
.
reset_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
value
.
reset_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
}
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
...
...
@@ -58,16 +58,15 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
#ifndef __NVCC__
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
active_node
);
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
&&
grad
.
prev_out_grad
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_weight
,
frame_size
,
0
,
grad
.
reset_output_grad
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_weight
,
frame_size
,
0
,
grad
.
reset_output_grad
,
frame_size
);
if
(
grad
.
state_weight_grad
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
reset_output_value
,
frame_size
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
...
...
@@ -76,18 +75,15 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
active_gate
);
if
(
grad
.
prev_out_grad
&&
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
grad
.
gate_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
grad
.
prev_out_grad
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
grad
.
gate_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
grad
.
prev_out_grad
,
frame_size
);
if
(
grad
.
gate_weight_grad
)
{
math
::
gemm
<
platform
::
CPUDeviceContext
,
T
>
(
context
,
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
value
.
prev_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_size
*
3
,
1
,
grad
.
gate_weight_grad
,
frame_size
*
2
);
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
value
.
prev_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_size
*
3
,
1
,
grad
.
gate_weight_grad
,
frame_size
*
2
);
}
}
#endif
...
...
paddle/fluid/operators/math/gru_compute.cu
浏览文件 @
c888e016
...
...
@@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/platform/device_context.h>
#include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/gru_compute.h"
...
...
@@ -36,12 +37,11 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame_size
+
32
-
1
)
/
32
,
(
batch_size
+
32
-
1
)
/
32
);
}
auto
blas
=
math
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
1
,
value
.
prev_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_size
*
3
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
1
,
value
.
prev_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_size
*
3
);
}
if
(
batch_size
==
1
)
{
...
...
@@ -61,10 +61,10 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
}
if
(
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
value
.
reset_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
blas
.
GEMM
(
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
1
,
value
.
reset_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
}
if
(
batch_size
==
1
)
{
...
...
@@ -121,15 +121,16 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
grad
.
output_grad
,
frame_size
,
batch_size
,
active_node
);
}
auto
blas
=
math
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
&&
grad
.
prev_out_grad
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_weight
,
frame_size
,
0
,
grad
.
reset_output_grad
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_weight
,
frame_size
,
0
,
grad
.
reset_output_grad
,
frame_size
);
if
(
grad
.
state_weight_grad
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
reset_output_value
,
frame_size
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
...
...
@@ -153,16 +154,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
}
if
(
grad
.
prev_out_grad
&&
value
.
prev_out_value
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
grad
.
gate_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
grad
.
prev_out_grad
,
frame_size
);
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
grad
.
gate_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
grad
.
prev_out_grad
,
frame_size
);
if
(
grad
.
gate_weight_grad
)
{
math
::
gemm
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
value
.
prev_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_size
*
3
,
1
,
grad
.
gate_weight_grad
,
frame_size
*
2
);
blas
.
GEMM
(
true
,
false
,
frame_size
,
frame_size
*
2
,
batch_size
,
1
,
value
.
prev_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_size
*
3
,
1
,
grad
.
gate_weight_grad
,
frame_size
*
2
);
}
}
}
...
...
paddle/fluid/operators/math/math_function.cc
浏览文件 @
c888e016
...
...
@@ -24,72 +24,6 @@ namespace math {
using
float16
=
paddle
::
platform
::
float16
;
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
float16
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
float16
*
B
,
const
float16
beta
,
float16
*
C
)
{
PADDLE_THROW
(
"float16 GEMM not supported on CPU"
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
float
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
cblas_sgemm
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
double
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
)
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
cblas_dgemm
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
float16
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
int
lda
,
const
float16
*
B
,
const
int
ldb
,
const
float16
beta
,
float16
*
C
,
const
int
ldc
)
{
PADDLE_THROW
(
"float16 GEMM not supported on CPU"
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
float
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
)
{
cblas_sgemm
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUDeviceContext
,
double
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
)
{
cblas_dgemm
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
matmul
<
platform
::
CPUDeviceContext
,
float16
>
(
const
platform
::
CPUDeviceContext
&
context
,
...
...
@@ -123,8 +57,8 @@ void matmul<platform::CPUDeviceContext, float>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CPUDeviceContext
,
float
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
());
}
...
...
@@ -152,8 +86,8 @@ void matmul<platform::CPUDeviceContext, double>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CPUDeviceContext
,
double
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
...
...
@@ -230,7 +164,7 @@ void batched_gemm<platform::CPUDeviceContext, float>(
const
float
*
Ak
=
&
A
[
k
*
strideA
];
const
float
*
Bk
=
&
B
[
k
*
strideB
];
float
*
Ck
=
&
C
[
k
*
M
*
N
];
gemm
<
platform
::
CPUDeviceContext
,
float
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
Ak
,
Bk
,
beta
,
Ck
);
}
}
...
...
@@ -246,7 +180,7 @@ void batched_gemm<platform::CPUDeviceContext, double>(
const
double
*
Ak
=
&
A
[
k
*
strideA
];
const
double
*
Bk
=
&
B
[
k
*
strideB
];
double
*
Ck
=
&
C
[
k
*
M
*
N
];
gemm
<
platform
::
CPUDeviceContext
,
double
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
Ak
,
Bk
,
beta
,
Ck
);
}
}
...
...
paddle/fluid/operators/math/math_function.cu
浏览文件 @
c888e016
...
...
@@ -25,157 +25,6 @@ namespace math {
using
float16
=
paddle
::
platform
::
float16
;
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
float16
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
float16
*
B
,
const
float16
beta
,
float16
*
C
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context
.
GetComputeCapability
(),
53
,
"cublas fp16 gemm requires GPU compute capability >= 53"
);
#if CUDA_VERSION >= 8000
float
h_alpha
=
static_cast
<
float
>
(
alpha
);
float
h_beta
=
static_cast
<
float
>
(
beta
);
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
if
(
context
.
GetComputeCapability
()
>=
70
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
context
.
cublas_handle
(),
CUBLAS_TENSOR_OP_MATH
));
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
else
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
context
.
cublas_handle
(),
CUBLAS_DEFAULT_MATH
));
}
#endif // CUDA_VERSION >= 9000
// cublasHgemm does true FP16 computation which is slow for non-Volta
// GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
B
,
CUDA_R_16F
,
ldb
,
A
,
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
,
algo
));
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
const
half
h_alpha
=
static_cast
<
const
half
>
(
alpha
);
const
half
h_beta
=
static_cast
<
const
half
>
(
beta
);
const
half
*
h_A
=
reinterpret_cast
<
const
half
*>
(
A
);
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
));
#endif // CUDA_VERSION >= 8000
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
float
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
double
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
float16
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
int
lda
,
const
float16
*
B
,
const
int
ldb
,
const
float16
beta
,
float16
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
const
half
h_alpha
=
static_cast
<
const
half
>
(
alpha
);
const
half
h_beta
=
static_cast
<
const
half
>
(
beta
);
const
half
*
h_A
=
reinterpret_cast
<
const
half
*>
(
A
);
const
half
*
h_B
=
reinterpret_cast
<
const
half
*>
(
B
);
half
*
h_C
=
reinterpret_cast
<
half
*>
(
C
);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context
.
GetComputeCapability
(),
53
,
"cublas Hgemm requires GPU compute capability >= 53"
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
ldc
));
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
float
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
));
}
template
<
>
void
gemm
<
platform
::
CUDADeviceContext
,
double
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
));
}
template
<
>
void
matmul
<
platform
::
CUDADeviceContext
,
float16
>
(
const
platform
::
CUDADeviceContext
&
context
,
...
...
@@ -200,8 +49,8 @@ void matmul<platform::CUDADeviceContext, float16>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CUDADeviceContext
,
float16
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float16
>
(),
Blas
<
platform
::
CUDADeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float16
>
(),
matrix_b
.
data
<
float16
>
(),
beta
,
matrix_out
->
data
<
float16
>
());
}
...
...
@@ -229,8 +78,8 @@ void matmul<platform::CUDADeviceContext, float>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CUDADeviceContext
,
float
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
Blas
<
platform
::
CUDADeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
());
}
...
...
@@ -258,8 +107,8 @@ void matmul<platform::CUDADeviceContext, double>(
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CUDADeviceContext
,
double
>
(
context
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
Blas
<
platform
::
CUDADeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
...
...
paddle/fluid/operators/math/math_function.h
浏览文件 @
c888e016
...
...
@@ -42,6 +42,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -56,17 +57,48 @@ namespace math {
// Then matrixA: M * K, matrixB: K * N, matrixC : M * N
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template
<
typename
DeviceContext
>
class
Blas
{
public:
explicit
Blas
(
const
DeviceContext
&
context
)
:
context_
(
context
)
{}
template
<
typename
T
>
void
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
)
const
;
template
<
typename
T
>
void
GEMM
(
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
const
;
private:
const
DeviceContext
&
context_
;
};
template
<
typename
DeviceContext
,
typename
T
>
void
gemm
(
const
DeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
);
class
BlasT
:
private
Blas
<
DeviceContext
>
{
public:
using
Blas
<
DeviceContext
>::
Blas
;
template
<
typename
...
ARGS
>
void
GEMM
(
ARGS
...
args
)
const
{
static_cast
<
const
Blas
<
DeviceContext
>*>
(
this
)
->
template
GEMM
<
T
>(
args
...);
}
};
// gemm wrapper with stride args for matrix uncontinuous in memory
template
<
typename
DeviceContext
,
typename
T
>
void
gemm
(
const
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
);
inline
BlasT
<
DeviceContext
,
T
>
GetBlas
(
const
framework
::
ExecutionContext
&
exe_ctx
)
{
return
BlasT
<
DeviceContext
,
T
>
(
exe_ctx
.
template
device_context
<
DeviceContext
>());
}
template
<
typename
DeviceContext
,
typename
T
>
inline
BlasT
<
DeviceContext
,
T
>
GetBlas
(
const
DeviceContext
&
dev_ctx
)
{
return
BlasT
<
DeviceContext
,
T
>
(
dev_ctx
);
}
// matrix multiply with continuous memory
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -137,3 +169,8 @@ struct RowwiseMean {
}
// namespace math
}
// namespace operators
}
// namespace paddle
#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif
paddle/fluid/operators/math/matmul.h
浏览文件 @
c888e016
...
...
@@ -131,8 +131,9 @@ class MatMulFunctor {
if
(
!
batchCount
)
{
// regular matrix multiplication
gemm
<
DeviceContext
,
T
>
(
context
,
transA
,
transB
,
M
,
N
,
kA
,
alpha
,
a
.
data
<
T
>
(),
b
.
data
<
T
>
(),
beta
,
out
->
data
<
T
>
());
Blas
<
DeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
kA
,
alpha
,
a
.
data
<
T
>
(),
b
.
data
<
T
>
(),
beta
,
out
->
data
<
T
>
());
}
else
{
// batched matrix multiplication
batched_gemm
<
DeviceContext
,
T
>
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录