PaddlePaddle / Paddle
Commit f21b6f08 — Cherry pick fused linear (#53621)
Authored by limingshu on May 09, 2023; committed via GitHub on May 09, 2023.
Parent commit: 77eeb226

Showing 9 changed files with 602 additions and 337 deletions.
paddle/fluid/operators/fused/attn_gemm.h                   +14   -19
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc     +24   -38
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu     +13   -21
paddle/phi/kernels/autotune/cache.cc                        +1    -1
paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h            +433  -187
paddle/phi/kernels/funcs/common_shape.h                     +2    -1
paddle/phi/kernels/funcs/dropout_impl.cu.h                  +4    -2
paddle/phi/kernels/funcs/fused_gemm_epilogue.h            +109   -66
paddle/phi/kernels/gpu/cross_entropy_kernel.cu              +2    -2
paddle/fluid/operators/fused/attn_gemm.h  (+14 -19)

@@ -68,25 +68,20 @@ class AttnMatMul {
             "The output (= input * weight) is expected to be nullptr or the "
             "same as bias_out when fused is true."));
-      auto fused_impl = phi::funcs::MatmulPlanner(
-          vectorize(input->dims()),
-          vectorize(weight->dims()),
-          transA_,
-          transB_,
-          phi::CppTypeToDataType<T>::Type(),
-          phi::funcs::MatmulFusedType::kMatmulBias,
-          static_cast<const void*>(bias->data<T>()),
-          nullptr);
-      phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx_,
-                                             input->data<T>(),
-                                             weight->data<T>(),
-                                             bias_out->data<T>(),
-                                             bsz_seq_,      // M
-                                             output_size_,  // N
-                                             input_size_,   // K
-                                             transA_,
-                                             transB_,
-                                             &fused_impl);
+      phi::funcs::LinearWithCublasLt<T>::Run(
+          dev_ctx_,
+          input,                                      // x
+          weight,                                     // y
+          bias_out,                                   // out
+          static_cast<const void*>(bias->data<T>()),  // bias
+          nullptr,
+          bsz_seq_,      // M
+          output_size_,  // N
+          input_size_,   // K
+          transA_,
+          transB_,
+          phi::funcs::MatmulFusedType::kMatmulBias);
       return;
     }
 #endif
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc  (+24 -38)

@@ -36,7 +36,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto y_dims = ctx->GetInputDim("Y");
     auto bias_dims = ctx->GetInputDim("Bias");
     auto trans_x = ctx->Attrs().Get<bool>("trans_x");
     auto trans_y = ctx->Attrs().Get<bool>("trans_y");

@@ -88,27 +87,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
                           K_from_x,
                           K_from_y));
-    auto activation = ctx->Attrs().Get<std::string>("activation");
-
-    if (activation == "none" && ctx->HasOutput("ReserveSpace")) {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "The ReserveSpace would not be used when activation = \"none\""));
-    }
-
-    // cublasLt's restriction for auxiliary.
-    if (ctx->HasOutput("ReserveSpace") && activation != "none") {
-      int min_size_of_n = activation == "relu" ? 128 : 8;
-      int N_size = trans_y ? y_dims[0] : y_dims[1];
-      PADDLE_ENFORCE_EQ(N_size % min_size_of_n,
-                        0,
-                        platform::errors::InvalidArgument(
-                            "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
-                            "should be multiple of %d when auxiliary_key given "
-                            "and activation=%s, but got N = %d.",
-                            min_size_of_n,
-                            activation,
-                            N_size));
-    }
     std::vector<int64_t> out_dims;
     out_dims.reserve(static_cast<size_t>(x_dims.size()));
     if (trans_x) {

@@ -122,11 +100,29 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
     } else {
       out_dims.push_back(y_dims[1]);
     }
     ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
+    auto activation = ctx->Attrs().Get<std::string>("activation");
     if (ctx->HasOutput("ReserveSpace")) {
       ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims));
+      if (activation == "none") {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "The ReserveSpace would not be used when activation = \"none\""));
+      } else {
+        int min_size_of_n = activation == "relu" ? 128 : 8;
+        int N_size = trans_y ? y_dims[0] : y_dims[1];
+        PADDLE_ENFORCE_EQ(N_size % min_size_of_n,
+                          0,
+                          platform::errors::InvalidArgument(
+                              "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
+                              "should be multiple of %d when auxiliary_key given "
+                              "and activation=%s, but got N = %d.",
+                              min_size_of_n,
+                              activation,
+                              N_size));
+      }
     }
   }

@@ -202,7 +198,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
     auto dout_dims = ctx->GetInputDim("DOut");
     auto x_dims = ctx->GetInputDim("X");
     auto y_dims = ctx->GetInputDim("Y");
     auto trans_x = ctx->Attrs().Get<bool>("trans_x");
     auto trans_y = ctx->Attrs().Get<bool>("trans_y");

@@ -241,7 +236,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
         x_dims.size()));
     auto dout_mat_dims = phi::flatten_to_2d(dout_dims, dout_dims.size() - 1);
     auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1);
     PADDLE_ENFORCE_EQ(

@@ -268,25 +262,17 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
             false,
             platform::errors::InvalidArgument(
                 "The ReserveSpace should not be empty. "
-                "when activation_grad == {relu_grad, gelu_grad}."));
+                "when activation == {relu_grad, gelu_grad}."));
     }

     if (ctx->HasOutput("DX")) {
-      std::vector<int64_t> dx_dims;
-      dx_dims.reserve(static_cast<size_t>(x_dims.size()));
-      for (int i = 0; i < x_dims.size(); ++i) {
-        dx_dims.push_back(x_dims[i]);
-      }
-      ctx->SetOutputDim("DX", phi::make_ddim(dx_dims));
+      ctx->SetOutputDim("DX", x_dims);
     }
-    std::vector<int64_t> dy_dims(y_dims.Get(), y_dims.Get() + y_dims.size());
-    ctx->SetOutputDim("DY", phi::make_ddim(dy_dims));
+    ctx->SetOutputDim("DY", y_dims);

     if (ctx->HasOutput("DBias")) {
-      std::vector<int64_t> dbias_dims;
-      dbias_dims.push_back(trans_y ? y_dims[0] : y_dims[1]);
-      ctx->SetOutputDim("DBias", phi::make_ddim(dbias_dims));
+      int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1];
+      ctx->SetOutputDim("DBias", phi::make_ddim({dbias_dim}));
     }
   }
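For readers unfamiliar with the constraint kept (and relocated) by the shape-inference change above: when a ReserveSpace (auxiliary) output is requested, cublasLt only supports relu/gelu epilogues and requires the output dimension N to be a multiple of 128 (relu) or 8 (gelu). Below is a minimal standalone sketch of that check, using plain exceptions instead of Paddle's PADDLE_THROW / PADDLE_ENFORCE machinery; the function name and messages are illustrative only, not part of the patch.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

// Simplified stand-in for the check in FusedGemmEpilogueOp::InferShape above.
void CheckReserveSpaceAlignment(const std::string& activation, int64_t n) {
  if (activation == "none") {
    throw std::invalid_argument(
        "The ReserveSpace would not be used when activation = \"none\"");
  }
  const int64_t min_size_of_n = (activation == "relu") ? 128 : 8;
  if (n % min_size_of_n != 0) {
    throw std::invalid_argument("N = " + std::to_string(n) +
                                " should be a multiple of " +
                                std::to_string(min_size_of_n) +
                                " when activation = " + activation);
  }
}

int main() {
  CheckReserveSpaceAlignment("relu", 256);  // OK: 256 % 128 == 0
  try {
    CheckReserveSpaceAlignment("gelu", 30);  // throws: 30 % 8 != 0
  } catch (const std::invalid_argument& e) {
    std::cerr << e.what() << "\n";
  }
  return 0;
}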
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu  (+13 -21)

@@ -17,7 +17,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"
-#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
 #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"

 namespace paddle {

@@ -101,26 +100,19 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
             << ", activation=" << activation << ", fused_type=" << fused_type
             << ", reserve_space=" << reserve_space;

-    auto fused_impl =
-        phi::funcs::MatmulPlanner(vectorize(x->dims()),
-                                  vectorize(y->dims()),
-                                  trans_x,
-                                  trans_y,
-                                  phi::CppTypeToDataType<T>::Type(),
-                                  fused_type,
-                                  static_cast<const void*>(bias->data<T>()),
-                                  reserve_data);
-    phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx,
-                                           x->data<T>(),
-                                           y->data<T>(),
-                                           out->data<T>(),
-                                           M,
-                                           N,
-                                           K,
-                                           trans_x,
-                                           trans_y,
-                                           &fused_impl);
+    phi::funcs::LinearWithCublasLt<T>::Run(
+        dev_ctx,
+        x,
+        y,
+        out,
+        static_cast<const void*>(bias->data<T>()),
+        reserve_data,
+        M,
+        N,
+        K,
+        trans_x,
+        trans_y,
+        fused_type);
   }
 };
paddle/phi/kernels/autotune/cache.cc  (+1 -1)

@@ -25,7 +25,7 @@ size_t TransposeKey(const std::vector<int64_t>& x_dims,
                     const std::vector<int32_t>& perm,
                     phi::DataType dtype) {
   const auto rank = perm.size();
-  return GenKey(x_dims, perm, rank, static_cast<int64_t>(dtype));
+  return GenKey(x_dims, perm, rank, static_cast<int>(dtype));
 }

 std::string AlgorithmTypeString(int64_t algo_type) {
paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h  (+433 -187)

@@ -33,20 +33,87 @@ namespace phi {
 namespace funcs {

 #if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)

 // Set this enum according to
 // https://docs.nvidia.com/cuda/cublas/index.html#cublasltepilogue-t
+// While kMatmul, kMatmulGrad, kMatmulGradWithoutBias share the same
+// enum value, but if all elements for MatmulPlanner->GetKey() is same,
+// no matter forward or backward, they could share the same descriptor
+// cache, in that the descritpor is for decription of matmul operation.
 enum MatmulFusedType {
-  kMatmul = CUBLASLT_EPILOGUE_DEFAULT,
+  kMatmul = CUBLASLT_EPILOGUE_DEFAULT,  // No special postprocessing.
+  kMatmulGrad = CUBLASLT_EPILOGUE_DEFAULT,
+  kMatmulGradWithoutBias = CUBLASLT_EPILOGUE_DEFAULT,
   kMatmulBias = CUBLASLT_EPILOGUE_BIAS,
   kMatmulRelu = CUBLASLT_EPILOGUE_RELU,
-  kMatmulBiasRelu = CUBLASLT_EPILOGUE_RELU_BIAS,  // Apply bias and then ReLU transform.
-  kMatmulBiasGelu = CUBLASLT_EPILOGUE_GELU_BIAS,  // Apply Bias and then GELU transform.
-  kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS
+  kMatmulBiasRelu = CUBLASLT_EPILOGUE_RELU_BIAS,
+  kMatmulBiasGelu = CUBLASLT_EPILOGUE_GELU_BIAS,
   kMatmulBiasReluWithReservedData = CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
+  kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS,
+  kMatmulReluGrad = CUBLASLT_EPILOGUE_DRELU,
+  kMatmulGeluGrad = CUBLASLT_EPILOGUE_DGELU,
+  kMatmulBiasGradToA = CUBLASLT_EPILOGUE_BGRADA,
+  kMatmulBiasGradToB = CUBLASLT_EPILOGUE_BGRADB
 };

+enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
+
+template <bool TransX, bool TransY>
+struct FusedGEMMGradTrait;
+
+template <>
+struct FusedGEMMGradTrait<false, false> {
+  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
+  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
+  static constexpr auto kXGradATrans = false;
+  static constexpr auto kXGradBTrans = true;
+  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
+  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
+  static constexpr auto kYGradATrans = true;
+  static constexpr auto kYGradBTrans = false;
+};
+
+template <>
+struct FusedGEMMGradTrait<true, false> {
+  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
+  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
+  static constexpr auto kXGradATrans = false;
+  static constexpr auto kXGradBTrans = true;
+  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
+  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
+  static constexpr auto kYGradATrans = false;
+  static constexpr auto kYGradBTrans = false;
+};
+
+template <>
+struct FusedGEMMGradTrait<false, true> {
+  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
+  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
+  static constexpr auto kXGradATrans = false;
+  static constexpr auto kXGradBTrans = false;
+  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
+  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
+  static constexpr auto kYGradATrans = true;
+  static constexpr auto kYGradBTrans = false;
+};
+
+template <>
+struct FusedGEMMGradTrait<true, true> {
+  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
+  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
+  static constexpr auto kXGradATrans = true;
+  static constexpr auto kXGradBTrans = true;
+  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
+  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
+  static constexpr auto kYGradATrans = true;
+  static constexpr auto kYGradBTrans = true;
+};
+
+// To tell any matmul or fused matmul operation from each other.
 struct MatmulPlanner {
  public:
   const void* bias{nullptr};

@@ -60,23 +127,31 @@ struct MatmulPlanner {
                 phi::DataType dtype,
                 MatmulFusedType impl_type,
                 const void* bias_data = nullptr,
-                void* reserve_data = nullptr)
-      : bias(bias_data), aux_data(reserve_data) {
-    type = impl_type;
-    key = phi::autotune::GenKey(x_dims,
-                                y_dims,
-                                static_cast<int64_t>(trans_x),
-                                static_cast<int64_t>(trans_y),
-                                static_cast<int64_t>(dtype));
+                void* reserve_data = nullptr,  // Commonly for ReLu bit-mask.
+                bool use_addto = false,
+                bool no_exchange = true)
+      : bias(bias_data), aux_data(reserve_data), impl_type_(impl_type) {
+    use_addto_ = use_addto;
+    key_ = phi::autotune::GenKey(x_dims,
+                                 y_dims,
+                                 static_cast<int>(trans_x),
+                                 static_cast<int>(trans_y),
+                                 static_cast<int>(dtype),
+                                 static_cast<int>(no_exchange));
   }

-  MatmulFusedType ImplType() const { return type; }
-  size_t GetKey() const { return key; }
-  size_t GenSubKey(int idx) const { return phi::autotune::GenKey(key, idx); }
+  bool UseAddTo() const { return use_addto_; }
+  size_t GetKey() const { return key_; }
+  MatmulFusedType ImplType() const { return impl_type_; }
+  size_t GenSubKey(int idx) const {
+    return phi::autotune::GenKey(key_, static_cast<int>(use_addto_), idx);
+  }

  private:
-  MatmulFusedType type;
-  size_t key;
+  MatmulFusedType impl_type_;
+  bool use_addto_;
+  size_t key_;
 };

 template <typename T>

@@ -124,19 +199,19 @@ struct MatmulDescriptor {
   }

   // x_desc, y_desc, op_desc are allocated in heap memory.
-  template <typename T>
-  void Create(const int M,
-              const int N,
-              const int K,
+  template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
+  void Create(const int64_t M,
+              const int64_t N,
+              const int64_t K,
               const bool trans_x,
               const bool trans_y,
               phi::funcs::MatmulPlanner* planner,
               const int batch_size = 1,
-              int64_t stride_x = 0,
-              int64_t stride_y = 0,
-              int64_t stride_out = 0) {
+              const int64_t stride_x = 0,
+              const int64_t stride_y = 0,
+              const int64_t stride_out = 0,
+              bool grad_for_dx = true) {
     using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
     cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
     cublasComputeType_t compute_type = GetCudaComputeType<T>();

@@ -145,18 +220,7 @@ struct MatmulDescriptor {
     // details about defaults; just need to set the transforms for A and B
     PADDLE_ENFORCE_GPU_SUCCESS(
         dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
-    cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
-    cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        dynload::cublasLtMatmulDescSetAttribute(op_desc,
-                                                CUBLASLT_MATMUL_DESC_TRANSB,
-                                                &cublas_trans_x,
-                                                sizeof(cublas_trans_x)));
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        dynload::cublasLtMatmulDescSetAttribute(op_desc,
-                                                CUBLASLT_MATMUL_DESC_TRANSA,
-                                                &cublas_trans_y,
-                                                sizeof(cublas_trans_y)));
+    SetFusedEpilogueOpDescriptor(planner, trans_x, trans_y, N);

     // Create matrix descriptors
     CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x);

@@ -169,7 +233,6 @@ struct MatmulDescriptor {
       SetBatchAndStride(y_desc, batch_size, stride_y);
       SetBatchAndStride(out_desc, batch_size, stride_out);
     }
-    SetFusedEpilogueOpDescriptor(planner, N);
   }

   cublasLtMatmulAlgo_t* SetAlgo() {

@@ -188,7 +251,7 @@ struct MatmulDescriptor {
                                                 CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                                 &bias_data,
                                                 sizeof(bias_data)));
     }
     if (planner->aux_data != nullptr) {
       PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
           op_desc,

@@ -197,7 +260,6 @@ struct MatmulDescriptor {
           sizeof(planner->aux_data)));
       }
     }
-  }

   std::string GetDescResultString(std::string prefix, bool has_algo = true) const {

@@ -223,7 +285,42 @@ struct MatmulDescriptor {
     return out.str();
   }

- private:
+  void ExchangeXYDesc(bool no_exchange) {}
+
+ protected:
+  void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner,
+                                    const bool trans_x,
+                                    const bool trans_y,
+                                    int64_t lead_dim) {
+    cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
+    cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulDescSetAttribute(op_desc,
+                                                CUBLASLT_MATMUL_DESC_TRANSB,
+                                                &cublas_trans_x,
+                                                sizeof(cublas_trans_x)));
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulDescSetAttribute(op_desc,
+                                                CUBLASLT_MATMUL_DESC_TRANSA,
+                                                &cublas_trans_y,
+                                                sizeof(cublas_trans_y)));
+    if (planner->ImplType() != kMatmul) {
+      auto fused_type = static_cast<cublasLtEpilogue_t>(planner->ImplType());
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          dynload::cublasLtMatmulDescSetAttribute(op_desc,
+                                                  CUBLASLT_MATMUL_DESC_EPILOGUE,
+                                                  &fused_type,
+                                                  sizeof(fused_type)));
+    }
+    if (planner->aux_data) {
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          dynload::cublasLtMatmulDescSetAttribute(op_desc,
+                                                  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
+                                                  &lead_dim,
+                                                  sizeof(lead_dim)));
+    }
+  }
+
   void CreateMatrixLayout(cublasLtMatrixLayout_t* desc,
                           cudaDataType type,
                           uint64_t rows,

@@ -252,145 +349,62 @@ struct MatmulDescriptor {
                                                        &stride,
                                                        sizeof(stride)));
   }
-
-  void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner,
-                                    int64_t lead_dim) {
-    if (planner->bias) {
-      auto fuse_type = static_cast<cublasLtEpilogue_t>(planner->ImplType());
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          dynload::cublasLtMatmulDescSetAttribute(op_desc,
-                                                  CUBLASLT_MATMUL_DESC_EPILOGUE,
-                                                  &fuse_type,
-                                                  sizeof(fuse_type)));
-      if (planner->aux_data) {
-        PADDLE_ENFORCE_GPU_SUCCESS(
-            dynload::cublasLtMatmulDescSetAttribute(op_desc,
-                                                    CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
-                                                    &lead_dim,
-                                                    sizeof(lead_dim)));
-      }
-    }
-  }
 };

-template <typename T>
-struct DescriptorSetter {
-  MatmulDescriptor desc;
-  size_t sub_key{std::numeric_limits<size_t>::min()};
+struct MatmulGradDescriptor : MatmulDescriptor {
+ public:
+  MatmulGradDescriptor() {}

-  DescriptorSetter(phi::funcs::MatmulPlanner* planner,
-                   const int M,
-                   const int N,
-                   const int K,
+  template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
+  void Create(const int64_t M,
+              const int64_t N,
+              const int64_t K,
               const bool trans_x,
               const bool trans_y,
               phi::funcs::MatmulPlanner* planner,
               const int batch_size = 1,
               int64_t stride_x = 0,
               int64_t stride_y = 0,
-              int64_t stride_out = 0) {
-    if (planner != nullptr) {
-      sub_key = planner->GenSubKey(static_cast<size_t>(planner->ImplType()));
-    }
-
-    auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
-    if (mamtul_cache.FindSubKey(sub_key)) {
-      desc = *(reinterpret_cast<MatmulDescriptor*>(mamtul_cache.GetSubKey(sub_key)));
-      desc.SetFusedEpiloguePtr<T>(planner);
-      VLOG(6) << desc.GetDescResultString("[Heap MatmulDescriptor] ");
-    } else {
-      desc.Create<T>(M, N, K, trans_x, trans_y, planner,
-                     batch_size, stride_x, stride_y, stride_out);
-      if (planner != nullptr) {
-        desc.SetFusedEpiloguePtr<T>(planner);
-      }
-      VLOG(6) << desc.GetDescResultString("[Stack MatmulDescriptor] ", false);
-    }
-  }
-};
+              int64_t stride_out = 0,
+              bool grad_for_dx = true) {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
+    cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
+    cublasComputeType_t compute_type = GetCudaComputeType<T>();
+
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
+    this->SetFusedEpilogueOpDescriptor(planner, trans_x, trans_y, TransX ? M : K);
+
+    // Create operation desciriptor; see cublasLtMatmulDescAttributes_t for
+    // details about defaults; just need to set the transforms for A and B
+    this->CreateMatrixLayout(&x_desc, mat_type, N, M, true);
+    if (grad_for_dx) {
+      this->CreateMatrixLayout(&y_desc, mat_type, K, N, TransY);
+      this->CreateMatrixLayout(
+          &out_desc, phi::backends::gpu::ToCudaDataType<DXT>(), M, K, TransX);
+    } else {
+      this->CreateMatrixLayout(&y_desc, mat_type, M, K, TransX);
+      this->CreateMatrixLayout(
+          &out_desc, phi::backends::gpu::ToCudaDataType<DYT>(), K, N, TransY);
+    }
+  }
+
+  void ExchangeXYDesc(bool no_exchange) {
+    if (no_exchange) {
+      return;
+    }
+    auto* temp = y_desc;
+    y_desc = x_desc;
+    x_desc = temp;
+  }
+};

-template <typename T>
-struct MatmulWithCublasLt {
+template <typename T, typename OutT = T, class MatmulDescT = MatmulDescriptor>
+struct CublasLtBase {
  public:
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
-  static void Run(const phi::GPUContext& ctx,
-                  const T* x_data,
-                  const T* y_data,
-                  T* out_data,
-                  const int M,
-                  const int N,
-                  const int K,
-                  const bool trans_x,
-                  const bool trans_y,
-                  phi::funcs::MatmulPlanner* planner = nullptr) {
-    auto setter = DescriptorSetter<T>(planner, M, N, K, trans_x, trans_y);
-    RunImpl(ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
-  }
-
-  static void RunWithBatch(const phi::GPUContext& ctx,
-                           const T* x_data,
-                           const T* y_data,
-                           T* out_data,
-                           const int M,
-                           const int N,
-                           const int K,
-                           bool trans_x,
-                           bool trans_y,
-                           int batch_size,
-                           int64_t stride_x,
-                           int64_t stride_y,
-                           int64_t stride_out,
-                           phi::funcs::MatmulPlanner* planner = nullptr) {
-    auto setter = DescriptorSetter<T>(planner, M, N, K, trans_x, trans_y,
-                                      batch_size, stride_x, stride_y, stride_out);
-    RunImpl(ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
-  }
-
-  static void RunWithBatch(const phi::GPUContext& ctx,
-                           const T** x_data,
-                           const T** y_data,
-                           T** out_data,
-                           const int M,
-                           const int N,
-                           const int K,
-                           bool trans_x,
-                           bool trans_y,
-                           int batch_size,
-                           phi::funcs::MatmulPlanner* planner = nullptr) {
-    for (int i = 0; i < batch_size; ++i) {
-      Run(ctx, x_data[i], y_data[i], out_data[i], M, N, K, trans_x, trans_y, planner);
-    }
-  }
-
- private:
   static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx,
                                                     size_t workspace_size) {
     return phi::memory_utils::Alloc(

@@ -400,16 +414,19 @@ struct MatmulWithCublasLt {
   }

   static void RunImpl(const phi::GPUContext& ctx,
-                      MatmulDescriptor* desc,
+                      MatmulDescT* desc,
                       const size_t sub_key,
                       const T* x_ptr,
                       const T* y_ptr,
-                      T* out_ptr,
+                      OutT* out_ptr,
                       phi::funcs::MatmulPlanner* planner) {
     MT alpha = static_cast<MT>(1);
-    MT beta = static_cast<MT>(0);
+    MT beta = planner->UseAddTo() ? static_cast<MT>(1) : static_cast<MT>(0);

     cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle();
+    // NOTE(limingshu): As workspace_size varies from different DL framework,
+    // I wonder is there any smarter idea for workspace setting, currently I
+    // just followed the settings from the NVIDIA colleague`s setting.
     size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
     phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);

@@ -426,16 +443,16 @@ struct MatmulWithCublasLt {
                 out_ptr,
                 workspace->ptr(),
                 workspace_size);
-      MatmulDescriptor* best_desc = new MatmulDescriptor(*desc);
-      VLOG(6) << best_desc->GetDescResultString("[Searched MatmulDescriptor] ");
+      MatmulDescT* best_desc = new MatmulDescT(*desc);
+      VLOG(6) << best_desc->GetDescResultString("[Searched CublasltDescriptor] ");

       auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
       cache.SetSubKey(sub_key, reinterpret_cast<void*>(best_desc));
     }
   }

-    VLOG(6) << desc->GetDescResultString("[Impl MatmulDescriptor] ");
+    VLOG(6) << desc->GetDescResultString("[Impl CublasltDescriptor] ");
     PADDLE_ENFORCE_GPU_SUCCESS(
         dynload::cublasLtMatmul(cublaslt_handle,
                                 desc->op_desc,

@@ -457,7 +474,7 @@ struct MatmulWithCublasLt {
   static void SearchBestAlgo(const phi::GPUContext& ctx,
                              const cublasLtHandle_t& lt_handle,
-                             MatmulDescriptor* desc,
+                             MatmulDescT* desc,
                              const void* alpha,
                              const void* beta,
                              const void* y_data,

@@ -526,7 +543,7 @@ struct MatmulWithCublasLt {
       }
     }
     float time_cnt = (cur_time / (repeats - 1));
-    VLOG(4) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "]"
+    VLOG(6) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "]"
             << "is : " << time_cnt << " s";
     if (cur_time < min_time_cost) {

@@ -534,12 +551,241 @@ struct MatmulWithCublasLt {
       min_time_cost = cur_time;
     }
   }
-  VLOG(4) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx;
+  VLOG(6) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx;
   *best_algo = heuristic_results[best_algo_idx].algo;
   PADDLE_ENFORCE_GPU_SUCCESS(
       dynload::cublasLtMatmulPreferenceDestroy(preference));
   }
 };

+// To judge if desc is cached or not.
+template <class DescT,
+          typename T,
+          typename DXT = T,
+          typename DYT = T,
+          bool TransX = false,
+          bool TransY = false>
+struct DescriptorSetter {
+ public:
+  DescT desc;
+  size_t sub_key{std::numeric_limits<size_t>::min()};
+
+  DescriptorSetter(phi::funcs::MatmulPlanner* planner,
+                   const int64_t M,
+                   const int64_t N,
+                   const int64_t K,
+                   const bool trans_x,
+                   const bool trans_y,
+                   const int batch_size = 1,
+                   int64_t stride_x = 0,
+                   int64_t stride_y = 0,
+                   int64_t stride_out = 0,
+                   const bool no_exchange = true,
+                   bool grad_for_dx = true) {
+    if (planner != nullptr) {
+      sub_key = planner->GenSubKey(static_cast<size_t>(planner->ImplType()));
+    }
+
+    auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
+    if (mamtul_cache.FindSubKey(sub_key)) {
+      desc = *(reinterpret_cast<DescT*>(mamtul_cache.GetSubKey(sub_key)));
+      desc.template SetFusedEpiloguePtr<DYT>(planner);
+      VLOG(6) << desc.GetDescResultString("[Heap CublasltDescriptor] ");
+    } else {
+      desc.template Create<T, DXT, DYT, TransX, TransY>(M,
+                                                        N,
+                                                        K,
+                                                        trans_x,
+                                                        trans_y,
+                                                        planner,
+                                                        batch_size,
+                                                        stride_x,
+                                                        stride_y,
+                                                        stride_out,
+                                                        grad_for_dx);
+      desc.ExchangeXYDesc(no_exchange);
+      if (planner != nullptr) {
+        desc.template SetFusedEpiloguePtr<DYT>(planner);
+      }
+      VLOG(6) << desc.GetDescResultString("[Stack CublasltDescriptor] ", false);
+    }
+  }
+};
+
+// For matmul with kernels autotune
+template <typename T>
+struct MatmulWithCublasLt : public CublasLtBase<T> {
+ public:
+  static void Run(const phi::GPUContext& ctx,
+                  const T* x_data,
+                  const T* y_data,
+                  T* out_data,
+                  const int64_t M,
+                  const int64_t N,
+                  const int64_t K,
+                  const bool trans_x,
+                  const bool trans_y,
+                  phi::funcs::MatmulPlanner* planner = nullptr) {
+    auto setter =
+        DescriptorSetter<MatmulDescriptor, T>(planner, M, N, K, trans_x, trans_y);
+    CublasLtBase<T>::RunImpl(
+        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
+  }
+
+  static void RunWithBatch(const phi::GPUContext& ctx,
+                           const T* x_data,
+                           const T* y_data,
+                           T* out_data,
+                           const int64_t M,
+                           const int64_t N,
+                           const int64_t K,
+                           bool trans_x,
+                           bool trans_y,
+                           int batch_size,
+                           int64_t stride_x,
+                           int64_t stride_y,
+                           int64_t stride_out,
+                           phi::funcs::MatmulPlanner* planner = nullptr) {
+    auto setter = DescriptorSetter<MatmulDescriptor, T>(planner,
+                                                        M, N, K,
+                                                        trans_x, trans_y,
+                                                        batch_size,
+                                                        stride_x, stride_y, stride_out);
+    CublasLtBase<T>::RunImpl(
+        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
+  }
+
+  static void RunWithBatch(const phi::GPUContext& ctx,
+                           const T** x_data,
+                           const T** y_data,
+                           T** out_data,
+                           const int64_t M,
+                           const int64_t N,
+                           const int64_t K,
+                           bool trans_x,
+                           bool trans_y,
+                           int batch_size,
+                           phi::funcs::MatmulPlanner* planner = nullptr) {
+    for (int i = 0; i < batch_size; ++i) {
+      Run(ctx, x_data[i], y_data[i], out_data[i], M, N, K, trans_x, trans_y, planner);
+    }
+  }
+};
+
+// As for just Linear fused ephilogue below: out = matmul(x, y) + bias.
+template <typename T>
+struct LinearWithCublasLt : public CublasLtBase<T> {
+  static void Run(const phi::GPUContext& ctx,
+                  const phi::DenseTensor* x,
+                  const phi::DenseTensor* y,
+                  phi::DenseTensor* out,
+                  const void* bias_data,
+                  void* reserve_data,
+                  const int64_t M,
+                  const int64_t N,
+                  const int64_t K,
+                  const bool trans_x,
+                  const bool trans_y,
+                  const MatmulFusedType fused_type) {
+    auto planner = phi::funcs::MatmulPlanner(vectorize(x->dims()),
+                                             vectorize(y->dims()),
+                                             trans_x,
+                                             trans_y,
+                                             phi::CppTypeToDataType<T>::Type(),
+                                             fused_type,
+                                             bias_data,
+                                             reserve_data);
+    auto setter =
+        DescriptorSetter<MatmulDescriptor, T>(&planner, M, N, K, trans_x, trans_y);
+    CublasLtBase<T>::RunImpl(ctx,
+                             &setter.desc,
+                             setter.sub_key,
+                             x->data<T>(),
+                             y->data<T>(),
+                             out->data<T>(),
+                             &planner);
+  }
+};
+
+template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
+struct LinearGradWithCublasLt : public CublasLtBase<T> {
+  static void Run(const phi::GPUContext& ctx,
+                  const phi::DenseTensor* x,
+                  const phi::DenseTensor* y,
+                  phi::DenseTensor* out,
+                  const void* bias_data,
+                  void* reserve_data,
+                  const int64_t M,
+                  const int64_t N,
+                  const int64_t K,
+                  const MatmulFusedType fused_type,
+                  const bool trans_x,
+                  const bool trans_y,
+                  const bool use_addto,
+                  const bool no_exchange,  // exchange x_desc and y_desc for grad.
+                  bool grad_for_dx = true) {
+    auto planner = phi::funcs::MatmulPlanner(vectorize(x->dims()),
+                                             vectorize(y->dims()),
+                                             trans_x,
+                                             trans_y,
+                                             phi::CppTypeToDataType<T>::Type(),
+                                             fused_type,
+                                             bias_data,
+                                             reserve_data,
+                                             use_addto,
+                                             no_exchange);
+    auto setter =
+        DescriptorSetter<MatmulGradDescriptor, T, DXT, DYT, TransX, TransY>(
+            &planner,
+            M, N, K,
+            trans_x, trans_y,
+            /*batch_size=*/1,
+            /*stride_x=*/0,
+            /*stride_y=*/0,
+            /*stride_out=*/0,
+            /*exchange_x_y_desc=*/no_exchange,
+            /*grad_for_dx=*/grad_for_dx);
+
+    // To setting data type for different kinda out_data.
+    if (grad_for_dx) {
+      CublasLtBase<T, DXT, MatmulGradDescriptor>::RunImpl(
+          ctx,
+          &setter.desc,
+          setter.sub_key,
+          no_exchange ? x->data<T>() : y->data<T>(),
+          no_exchange ? y->data<T>() : x->data<T>(),
+          out->data<DXT>(),
+          &planner);
+    } else {
+      CublasLtBase<T, DYT, MatmulGradDescriptor>::RunImpl(
+          ctx,
+          &setter.desc,
+          setter.sub_key,
+          no_exchange ? x->data<T>() : y->data<T>(),
+          no_exchange ? y->data<T>() : x->data<T>(),
+          out->data<DYT>(),
+          &planner);
+    }
+  }
+};
+
 #else
 // A void structure just for successfully complile.
 struct MatmulPlanner {};
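For context on the DescriptorSetter / AutoTuneCache interplay above: a fully configured cublasLt descriptor is cached under a sub-key derived from the MatmulPlanner (shapes, transposes, dtype, fused epilogue type, addto flag) and reused on later calls. Below is a simplified, self-contained sketch of that keyed-reuse pattern, with std::unordered_map standing in for Paddle's AutoTuneCache and a plain struct standing in for the cublasLt descriptor; all names are illustrative, and unlike the real code (which inserts into the cache only after the best algorithm has been searched) this sketch caches immediately.

#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for a configured cublasLt matmul descriptor.
struct FakeDescriptor {
  std::string config;
};

// Combine hash values, similar in spirit to phi::autotune::GenKey.
static size_t HashCombine(size_t seed, size_t v) {
  return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

static size_t MakeSubKey(const std::vector<int64_t>& x_dims,
                         const std::vector<int64_t>& y_dims,
                         bool trans_x, bool trans_y,
                         int dtype, int fused_type, bool use_addto) {
  size_t key = 0;
  for (auto d : x_dims) key = HashCombine(key, std::hash<int64_t>{}(d));
  for (auto d : y_dims) key = HashCombine(key, std::hash<int64_t>{}(d));
  key = HashCombine(key, trans_x);
  key = HashCombine(key, trans_y);
  key = HashCombine(key, dtype);
  key = HashCombine(key, use_addto);
  return HashCombine(key, fused_type);
}

// Reuse a cached descriptor when the sub-key matches, otherwise build one,
// mirroring the FindSubKey / SetSubKey lookup shown in the hunks above.
static std::unordered_map<size_t, std::unique_ptr<FakeDescriptor>> g_cache;

FakeDescriptor* GetOrCreateDescriptor(size_t sub_key) {
  auto it = g_cache.find(sub_key);
  if (it != g_cache.end()) {
    std::cout << "[Heap descriptor] reuse\n";
    return it->second.get();
  }
  auto desc = std::make_unique<FakeDescriptor>();
  desc->config = "created for key " + std::to_string(sub_key);
  std::cout << "[Stack descriptor] create\n";
  return g_cache.emplace(sub_key, std::move(desc)).first->second.get();
}

int main() {
  size_t key = MakeSubKey({64, 128}, {128, 256}, false, false,
                          /*dtype=*/1, /*fused_type=*/2, /*use_addto=*/false);
  GetOrCreateDescriptor(key);  // creates
  GetOrCreateDescriptor(key);  // reuses
  return 0;
}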
paddle/phi/kernels/funcs/common_shape.h  (+2 -1)

@@ -52,6 +52,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
           "Axis should be less than or equal to %d, but received axis is %d.",
           max_dim,
           axis));
   if (x_dims.size() > y_dims.size()) {
     std::fill(y_dims_array, y_dims_array + axis, 1);
     if (axis + y_dims.size() < max_dim) {

@@ -68,7 +69,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
     std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
   }

-  for (int i = 0; i < max_dim; i++) {
+  for (int i = 0; i < max_dim; ++i) {
     PADDLE_ENFORCE_EQ(
         x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
             y_dims_array[i] <= 1,
paddle/phi/kernels/funcs/dropout_impl.cu.h  (+4 -2)

@@ -350,8 +350,10 @@ void DropoutFwGPUKernelDriver(
     auto dst_functor =
         DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);

-    std::vector<int64_t> out_dims = phi::vectorize<int64_t>(x.dims());
-    std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask->dims());
+    std::vector<int64_t> out_dims =
+        std::move(phi::vectorize<int64_t>(x.dims()));
+    std::vector<int64_t> in_dims =
+        std::move(phi::vectorize<int64_t>(mask->dims()));
     std::reverse(out_dims.begin(), out_dims.end());
     std::reverse(in_dims.begin(), in_dims.end());
     kps::details::BroadcastConfig broadcast_config(
paddle/phi/kernels/funcs/fused_gemm_epilogue.h  (+109 -66)

@@ -37,6 +37,7 @@ limitations under the License. */
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/scope_guard.h"
+#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
 #include "paddle/utils/optional.h"

 DECLARE_int64(cublaslt_exhaustive_search_times);

@@ -488,62 +489,103 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
       phi::dynload::cublasLtMatrixLayoutDestroy(out_desc));
 }

-enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
-
-template <bool TransX, bool TransY>
-struct FusedGEMMGradTrait;
-
-template <>
-struct FusedGEMMGradTrait<false, false> {
-  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
-  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
-  static constexpr auto kXGradATrans = false;
-  static constexpr auto kXGradBTrans = true;
-  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
-  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
-  static constexpr auto kYGradATrans = true;
-  static constexpr auto kYGradBTrans = false;
-};
-
-template <>
-struct FusedGEMMGradTrait<true, false> {
-  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
-  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
-  static constexpr auto kXGradATrans = false;
-  static constexpr auto kXGradBTrans = true;
-  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
-  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
-  static constexpr auto kYGradATrans = false;
-  static constexpr auto kYGradBTrans = false;
-};
-
-template <>
-struct FusedGEMMGradTrait<false, true> {
-  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
-  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
-  static constexpr auto kXGradATrans = false;
-  static constexpr auto kXGradBTrans = false;
-  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
-  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
-  static constexpr auto kYGradATrans = true;
-  static constexpr auto kYGradBTrans = false;
-};
-
-template <>
-struct FusedGEMMGradTrait<true, true> {
-  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
-  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
-  static constexpr auto kXGradATrans = true;
-  static constexpr auto kXGradBTrans = true;
-  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
-  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
-  static constexpr auto kYGradATrans = true;
-  static constexpr auto kYGradBTrans = true;
-};
+struct BwdFusedEpilogueSetter {
+ public:
+  static phi::funcs::MatmulFusedType SetForDx(
+      const std::string& activation_grad) {
+    if (activation_grad == "none") {
+      return kMatmulGrad;
+    } else if (activation_grad == "relu_grad") {
+      return kMatmulReluGrad;
+    } else if (activation_grad == "gelu_grad") {
+      return kMatmulGeluGrad;
+    } else {
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "Fued linear epilogue type should be one of {none, relu, gelu}."
+          "But received activation is %s, please check",
+          activation_grad));
+    }
+  }
+
+  template <typename DYT, bool TransY>
+  static phi::funcs::MatmulFusedType SetForDy(const phi::GPUContext& dev_ctx,
+                                              phi::DenseTensor* dbias) {
+    if (dbias != nullptr) {
+      dev_ctx.Alloc<DYT>(dbias, dbias->numel() * sizeof(DYT));
+      return TransY ? kMatmulBiasGradToB : kMatmulBiasGradToA;
+    } else {
+      return kMatmulGradWithoutBias;
+    }
+  }
+};
+
+template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
+void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
+                                          const phi::DenseTensor* dout,
+                                          const phi::DenseTensor* x,
+                                          const phi::DenseTensor* y,
+                                          const phi::DenseTensor* reserve_space,
+                                          int64_t M,
+                                          int64_t N,
+                                          int64_t K,
+                                          const std::string activation_grad,
+                                          phi::DenseTensor* dx,
+                                          phi::DenseTensor* dy,
+                                          phi::DenseTensor* dbias,
+                                          bool use_addto_dx,
+                                          bool use_addto_dy) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  static_assert(std::is_same<DXT, T>::value || std::is_same<DXT, MT>::value);
+  static_assert(std::is_same<DYT, T>::value || std::is_same<DYT, MT>::value);
+  using Trait = FusedGEMMGradTrait<TransX, TransY>;
+
+  if (dx) {
+    constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ);
+    auto fused_type = BwdFusedEpilogueSetter::SetForDx(activation_grad);
+    void* reserve_data = (fused_type == kMatmulGrad)
+                             ? nullptr
+                             : const_cast<void*>(reserve_space->data());
+    dev_ctx.Alloc<DXT>(dx, dx->numel() * sizeof(DXT));
+    phi::funcs::LinearGradWithCublasLt<T, DXT, DYT, TransX, TransY>::Run(
+        dev_ctx,
+        dout,
+        y,
+        dx,
+        nullptr,
+        reserve_data,
+        M,
+        N,
+        K,
+        fused_type,
+        Trait::kXGradATrans,
+        Trait::kXGradBTrans,
+        use_addto_dx,
+        kXGradAIsDZ);
+  }
+
+  if (dy) {
+    auto fused_type =
+        BwdFusedEpilogueSetter::SetForDy<DYT, TransY>(dev_ctx, dbias);
+    constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ);
+    // Caution: DYT is in front of DXT in this template arguments.
+    dev_ctx.Alloc<DYT>(dy, dy->numel() * sizeof(DYT));
+    phi::funcs::LinearGradWithCublasLt<T, DXT, DYT, TransX, TransY>::Run(
+        dev_ctx,
+        dout,
+        x,
+        dy,
+        dbias ? static_cast<const void*>(dbias->data<DYT>()) : nullptr,
+        nullptr,
+        M,
+        N,
+        K,
+        fused_type,
+        Trait::kYGradATrans,
+        Trait::kYGradBTrans,
+        use_addto_dy,
+        kYGradAIsDZ,
+        /*is_dx=*/false);
+  }
+}

 static constexpr auto BoolToCuBlasEnum(bool transpose) {
   return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;

@@ -567,7 +609,8 @@ static cublasLtEpilogue_t GetEpilogueGradType(
 }

 template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
-void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
+void ComputeFusedGemmEpilogueBackwardImplDev(const phi::GPUContext& dev_ctx,
                                           const phi::DenseTensor* dout,
                                           const phi::DenseTensor* x,
                                           const phi::DenseTensor* y,
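To summarize the dispatch introduced by BwdFusedEpilogueSetter above: the dX pass selects kMatmulGrad / kMatmulReluGrad / kMatmulGeluGrad from the activation_grad string, and the dY pass fuses the bias gradient only when a DBias output exists, picking BGRADA or BGRADB depending on whether Y is transposed. Here is a self-contained sketch of that selection logic, with a plain enum standing in for the CUBLASLT_EPILOGUE_*-backed MatmulFusedType; the enum values and helper names are illustrative only.

#include <iostream>
#include <stdexcept>
#include <string>

// Plain stand-in for MatmulFusedType (which aliases cublasLt epilogue enums).
enum class FusedType {
  kMatmulGrad,             // no activation gradient fused
  kMatmulReluGrad,         // CUBLASLT_EPILOGUE_DRELU
  kMatmulGeluGrad,         // CUBLASLT_EPILOGUE_DGELU
  kMatmulGradWithoutBias,  // dY without a fused bias gradient
  kMatmulBiasGradToA,      // CUBLASLT_EPILOGUE_BGRADA
  kMatmulBiasGradToB       // CUBLASLT_EPILOGUE_BGRADB
};

// dX path: choose the epilogue from the activation_grad attribute.
FusedType SelectForDx(const std::string& activation_grad) {
  if (activation_grad == "none") return FusedType::kMatmulGrad;
  if (activation_grad == "relu_grad") return FusedType::kMatmulReluGrad;
  if (activation_grad == "gelu_grad") return FusedType::kMatmulGeluGrad;
  throw std::invalid_argument(
      "activation_grad must be one of {none, relu_grad, gelu_grad}");
}

// dY path: fuse the bias gradient only when a DBias output is requested;
// which operand it reduces over depends on whether Y is transposed.
FusedType SelectForDy(bool has_dbias, bool trans_y) {
  if (!has_dbias) return FusedType::kMatmulGradWithoutBias;
  return trans_y ? FusedType::kMatmulBiasGradToB : FusedType::kMatmulBiasGradToA;
}

int main() {
  std::cout << static_cast<int>(SelectForDx("relu_grad")) << "\n";  // 1
  std::cout << static_cast<int>(SelectForDy(true, false)) << "\n";  // 4
  return 0;
}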
paddle/phi/kernels/gpu/cross_entropy_kernel.cu  (+2 -2)

@@ -559,7 +559,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss,
     // max index to read
     int idx_max = (i < local_batches) ? element_count : 0;
     int idx_max_v = idx_max / kVSize;

 #pragma unroll
     // read data
     for (int it = 0; it < kIterationsV; ++it) {
       int src_idx = threadIdx.x + it * kWarpSize;

@@ -659,7 +659,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss,
     // loss
     phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sumloss);

 #pragma unroll
     for (int i = 0; i < kBatchSize; i++) {
       if (i >= local_batches) break;
       loss[first_batch + i] = sumloss[i];