PaddlePaddle / Paddle

Commit 0a4d1999 (unverified)
Authored by freeliuzc on Jul 13, 2023; committed via GitHub on Jul 13, 2023.
[inference] Add FusedBiasActKernel (#55301)
* add init value for CudaSwishFunctor
* add new phi kernel fusedBiasActKernel
Parent: d12837d3
Showing 8 changed files with 1,829 additions and 1 deletion (+1829 −1).
paddle/phi/api/yaml/fused_ops.yaml                      +11   -0
paddle/phi/infermeta/multiary.cc                        +130  -0
paddle/phi/infermeta/multiary.h                         +15   -0
paddle/phi/kernels/funcs/activation_functor.h           +1    -1
paddle/phi/kernels/funcs/load_store_util.h              +221  -0
paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu  +540  -0
paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h    +129  -0
test/legacy_test/test_fused_bias_act_op.py              +782  -0
paddle/phi/api/yaml/fused_ops.yaml
@@ -63,6 +63,17 @@
    data_type : x
  optional : bias, x_max

- op : fused_bias_act
  args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_dtype = "default", int rows = -1, int cols = -1, float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0)
  output : Tensor(out)
  infer_meta :
    func : FusedBiasActInferMeta
  kernel :
    func : fused_bias_act
    data_type : x
  optional : bias, dequant_scales, shift, smooth
  support_dygraph_mode : true

- op : fused_dropout_add
  args : (Tensor x, Tensor y, Tensor seed_tensor, Scalar p, bool is_test, str mode, int seed = 0, bool fix_seed = false)
  optional : seed_tensor
...
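For orientation, here is a minimal dynamic-graph call matching this op signature — a sketch based on the Python wrapper added in test/legacy_test/test_fused_bias_act_op.py, assuming a CUDA build of Paddle; the tensor shapes are illustrative only:

import numpy as np
import paddle

# x is [rows, cols]; bias, dequant_scales, shift and smooth are optional (pass None).
x = paddle.to_tensor(np.random.rand(20, 512).astype('float32'))
bias = paddle.to_tensor(np.random.rand(512).astype('float32'))
out = paddle._C_ops.fused_bias_act(
    x, bias, None, None, None,   # x, bias, dequant_scales, shift, smooth
    'gelu', 'default', 20, 512,  # act_method, compute_dtype, rows, cols
    -1.0, 1, 127.0, -127.0,      # quant_scale, quant_round_type, quant_max_bound, quant_min_bound
)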
paddle/phi/infermeta/multiary.cc
@@ -1335,6 +1335,136 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
   sequencenum->set_dtype(DataType::FLOAT32);
 }

void FusedBiasActInferMeta(const MetaTensor& x,
                           const MetaTensor& bias,
                           const MetaTensor& dequant_scales,
                           const MetaTensor& shift,
                           const MetaTensor& smooth,
                           const std::string& act_method,
                           const std::string& compute_dtype,
                           int rows,
                           int cols,
                           float quant_scale,
                           int quant_round_type,
                           float quant_max_bound,
                           float quant_min_bound,
                           MetaTensor* out) {
  auto x_dims = x.dims();
  PADDLE_ENFORCE_EQ(x_dims.size(),
                    2,
                    phi::errors::InvalidArgument(
                        "The size of Input(x) must be 2: %s", x_dims));
  auto token_num = x_dims[0];
  auto dim = x_dims[1];

  PADDLE_ENFORCE_GT(
      rows, 0, phi::errors::InvalidArgument("The size of Attr(rows) must > 0"));
  PADDLE_ENFORCE_GT(
      cols, 0, phi::errors::InvalidArgument("The size of Attr(cols) must > 0"));

  if (act_method == "geglu" || act_method == "swiglu") {
    PADDLE_ENFORCE_EQ(
        dim % 2,
        0,
        phi::errors::InvalidArgument(
            "The seconde dimension of x must be even, but receive %d", dim));
    dim /= 2;
    out->set_dims(phi::make_ddim({token_num, dim}));
  } else if (act_method == "gelu") {
    out->set_dims(phi::make_ddim({token_num, dim}));
  } else {
    PADDLE_THROW(
        errors::InvalidArgument("act_method must be geglu, swiglu or gelu, "
                                "but get act_method (%s)",
                                act_method));
  }

  auto FBADtypeCheck = [](const MetaTensor& check_tensor,
                          const std::string& tensor_name,
                          const std::string& compute_dtype) {
    if (compute_dtype == "bf16") {
      PADDLE_ENFORCE_EQ(
          check_tensor.dtype(),
          phi::DataType::BFLOAT16,
          phi::errors::InvalidArgument(
              "Input(%s) dtype must be the same with Attr(compute_dtype)",
              tensor_name));
    } else if (compute_dtype == "fp16") {
      PADDLE_ENFORCE_EQ(
          check_tensor.dtype(),
          phi::DataType::FLOAT16,
          phi::errors::InvalidArgument(
              "Input(%s) dtype must be the same with Attr(compute_dtype)",
              tensor_name));
    } else if (compute_dtype == "fp32") {
      PADDLE_ENFORCE_EQ(
          check_tensor.dtype(),
          phi::DataType::FLOAT32,
          phi::errors::InvalidArgument(
              "Input(%s) dtype must be the same with Attr(compute_dtype)",
              tensor_name));
    }
  };

  // In the case of quantization enabled, the dtype for computation is
  // determined based on compute_dtype.
  if (x.dtype() == phi::DataType::INT32) {
    PADDLE_ENFORCE_NE(
        compute_dtype,
        "default",
        phi::errors::InvalidArgument(
            "If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
    if (bias) {
      FBADtypeCheck(bias, "bias", compute_dtype);
    }

    if (quant_scale > 0) {
      out->set_dtype(phi::DataType::INT8);
    } else {
      if (compute_dtype == "bf16") {
        out->set_dtype(phi::DataType::BFLOAT16);
      } else if (compute_dtype == "fp16") {
        out->set_dtype(phi::DataType::FLOAT16);
      } else if (compute_dtype == "fp32") {
        out->set_dtype(phi::DataType::FLOAT32);
      } else {
        PADDLE_THROW(phi::errors::InvalidArgument(
            "In the case of quantization enabled with Input(x) INT32, "
            "Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
            "but get compute_dtype (%s)",
            compute_dtype));
      }
    }
  } else {
    // x.dtype() != phi::DataType::INT32
    if (bias) {
      if (compute_dtype != "default") {
        FBADtypeCheck(bias, "bias", compute_dtype);
        FBADtypeCheck(x, "x", compute_dtype);
      } else {
        PADDLE_ENFORCE_EQ(x.dtype(),
                          bias.dtype(),
                          phi::errors::InvalidArgument(
                              "Input(x) and Input(bias) must be the "
                              "same dtype in this situation"));
      }
    } else {
      // bias not exist
      if (compute_dtype != "default") {
        FBADtypeCheck(x, "x", compute_dtype);
      }
    }

    if (quant_scale > 0) {
      out->set_dtype(phi::DataType::INT8);
    } else {
      out->set_dtype(x.dtype());
    }
  }

  out->set_layout(x.layout());
}

void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
                                      const MetaTensor& dout,
                                      const MetaTensor& dweight,
...
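The shape and dtype rules above are compact enough to restate; the following plain-Python sketch (an illustration only, not part of the commit) mirrors what FusedBiasActInferMeta enforces:

def fused_bias_act_out_meta(x_shape, x_dtype, act_method, compute_dtype, quant_scale):
    # x must be 2-D: [token_num, dim]
    token_num, dim = x_shape
    if act_method in ('geglu', 'swiglu'):
        assert dim % 2 == 0   # GLU variants split the last dimension in half
        dim //= 2
    elif act_method != 'gelu':
        raise ValueError('act_method must be geglu, swiglu or gelu')
    if quant_scale > 0:
        out_dtype = 'int8'            # quantized output
    elif x_dtype == 'int32':
        out_dtype = compute_dtype     # must be one of bf16/fp16/fp32 here
    else:
        out_dtype = x_dtype
    return [token_num, dim], out_dtype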
paddle/phi/infermeta/multiary.h
@@ -279,6 +279,21 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
                           MetaTensor* sequencenum,
                           MetaTensor* out);

void FusedBiasActInferMeta(const MetaTensor& x,
                           const MetaTensor& bias,
                           const MetaTensor& dequant_scales,
                           const MetaTensor& shift,
                           const MetaTensor& smooth,
                           const std::string& act_method,
                           const std::string& compute_dtype,
                           int rows,
                           int cols,
                           float quant_scale,
                           int quant_round_type,
                           float quant_max_bound,
                           float quant_min_bound,
                           MetaTensor* out);

void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
                                      const MetaTensor& dout,
                                      const MetaTensor& dweight,
...
paddle/phi/kernels/funcs/activation_functor.h
@@ -3923,7 +3923,7 @@ template <typename T>
 struct CudaSwishFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   MPType one = static_cast<MPType>(1.0f);
-  float beta;
+  float beta = 1.0;
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"beta", &beta}};
...
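Presumably the default value matters because the fused kernel added below default-constructs CudaSwishFunctor (via `Functor functor;`) without ever calling GetAttrs(), so beta would otherwise be left uninitialized on that path.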
paddle/phi/kernels/funcs/load_store_util.h
0 → 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {
namespace funcs {

#ifndef PADDLE_WITH_HIP
template <typename T>
__device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
  if (v > max) return max;
  if (v < min) return min;
  return v;
}

template <typename InType, typename OutType>
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
                                                   const float scale,
                                                   const int round_type,
                                                   const float max_bound,
                                                   const float min_bound) {
  float quant_value = max_bound * scale * input;

  if (round_type == 0) {
    quant_value = static_cast<float>(rint(quant_value));
  } else {
    quant_value = static_cast<float>(round(quant_value));
  }
  return static_cast<OutType>(
      ClipFunc<float>(quant_value, min_bound, max_bound));
}

template <typename T>
struct Load {
  explicit Load(const T *src) : src_(src) {}

  template <int VecSize>
  __device__ void load(phi::AlignedVector<T, VecSize> *dst, int idx) {
    phi::Load<T, VecSize>(src_ + idx, dst);
  }

  const T *src_;
};

template <typename T, bool Smooth = false>
struct Store {
  explicit Store(T *dst) : dst_(dst) {}

  template <int VecSize>
  __device__ void store(phi::AlignedVector<T, VecSize> &src, int idx) {
    phi::Store<T, VecSize>(src, dst_ + idx);
  }

  T *dst_;
};

template <typename T>
struct Store<T, true> {
  Store(T *dst, const T *shift, const T *smooth, const int cols)
      : dst_(dst), shift_(shift), smooth_(smooth), cols_(cols) {}

  template <int VecSize>
  __device__ void store(phi::AlignedVector<T, VecSize> &src, int idx) {
    using Vec = phi::AlignedVector<T, VecSize>;
    Vec shift_vec;
    Vec smooth_vec;

    phi::Load<T, VecSize>(shift_ + idx % cols_, &shift_vec);
    phi::Load<T, VecSize>(smooth_ + idx % cols_, &smooth_vec);
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      src[i] = (src[i] + shift_vec[i]) * smooth_vec[i];
    }
    phi::Store<T, VecSize>(src, dst_ + idx);
  }

  T *dst_;
  const T *shift_;
  const T *smooth_;
  const int cols_;
};

template <typename T>
struct DequantLoad {
  DequantLoad(const int32_t *src, const float *dequant_scales, const int cols)
      : src_(src), dequant_scales_(dequant_scales), cols_(cols) {}

  template <int VecSize>
  __device__ void load(phi::AlignedVector<T, VecSize> *dst, int idx) {
    using SrcVec = phi::AlignedVector<int32_t, VecSize>;
    using DstVec = phi::AlignedVector<T, VecSize>;
    using ScaleVec = phi::AlignedVector<float, VecSize>;

    SrcVec src_vec;
    DstVec dst_vec;
    ScaleVec scale_vec;

    phi::Load<int32_t, VecSize>(src_ + idx, &src_vec);
    phi::Load<float, VecSize>(dequant_scales_ + idx % cols_, &scale_vec);
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      dst_vec[i] =
          static_cast<T>(static_cast<float>(src_vec[i]) * scale_vec[i]);
    }
    *dst = dst_vec;
  }

  const int32_t *src_;
  const float *dequant_scales_;
  const int cols_;
};

template <typename T, bool Smooth = false>
struct QuantStore {
  QuantStore(int8_t *dst,
             const int quant_round_type,
             const float quant_scale,
             const float quant_max_bound,
             const float quant_min_bound)
      : dst_(dst),
        quant_round_type_(quant_round_type),
        quant_scale_(quant_scale),
        quant_max_bound_(quant_max_bound),
        quant_min_bound_(quant_min_bound) {}

  template <int VecSize>
  __device__ void store(phi::AlignedVector<T, VecSize> &src,  // NOLINT
                        int idx) {
    using DstVec = phi::AlignedVector<int8_t, VecSize>;

    DstVec dst_vec;
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      dst_vec[i] = QuantHelperFunc<float, int8_t>(static_cast<float>(src[i]),
                                                  quant_scale_,
                                                  quant_round_type_,
                                                  quant_max_bound_,
                                                  quant_min_bound_);
    }

    phi::Store<int8_t, VecSize>(dst_vec, dst_ + idx);
  }

  int8_t *dst_;
  const int quant_round_type_;
  const float quant_scale_;
  const float quant_max_bound_;
  const float quant_min_bound_;
};

template <typename T>
struct QuantStore<T, true> {
  QuantStore(int8_t *dst,
             const T *shift,
             const T *smooth,
             const int cols,
             const int quant_round_type,
             const float quant_scale,
             const float quant_max_bound,
             const float quant_min_bound)
      : dst_(dst),
        shift_(shift),
        smooth_(smooth),
        cols_(cols),
        quant_round_type_(quant_round_type),
        quant_scale_(quant_scale),
        quant_max_bound_(quant_max_bound),
        quant_min_bound_(quant_min_bound) {}

  template <int VecSize>
  __device__ void store(phi::AlignedVector<T, VecSize> &src,  // NOLINT
                        int idx) {
    using DstVec = phi::AlignedVector<int8_t, VecSize>;
    using Vec = phi::AlignedVector<T, VecSize>;

    DstVec dst_vec;
    Vec shift_vec;
    Vec smooth_vec;

    phi::Load<T, VecSize>(shift_ + idx % cols_, &shift_vec);
    phi::Load<T, VecSize>(smooth_ + idx % cols_, &smooth_vec);
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      src[i] = (src[i] + shift_vec[i]) * smooth_vec[i];
      dst_vec[i] = QuantHelperFunc<float, int8_t>(static_cast<float>(src[i]),
                                                  quant_scale_,
                                                  quant_round_type_,
                                                  quant_max_bound_,
                                                  quant_min_bound_);
    }

    phi::Store<int8_t, VecSize>(dst_vec, dst_ + idx);
  }

  int8_t *dst_;
  const int quant_round_type_;
  const float quant_scale_;
  const float quant_max_bound_;
  const float quant_min_bound_;
  const T *shift_;
  const T *smooth_;
  const int cols_;
};
#endif

}  // namespace funcs
}  // namespace phi
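A NumPy reference for the per-element behaviour of DequantLoad and QuantStore may help; this is a sketch of the same math the new unit test uses as its baseline (round_type 0 maps to rint, any other value to round-half-away-from-zero):

import numpy as np

def dequant(v_int32, dequant_scale):
    # DequantLoad: int32 accumulator * per-column float scale
    return v_int32.astype('float32') * dequant_scale

def quant(v, scale, round_type, max_bound=127.0, min_bound=-127.0):
    # QuantStore / QuantHelperFunc: scale, round, clip, cast to int8
    q = max_bound * scale * v
    q = np.rint(q) if round_type == 0 else np.sign(q) * np.floor(np.abs(q) + 0.5)
    return np.clip(q, min_bound, max_bound).astype(np.int8)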
paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu
0 → 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"

#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h"

PHI_DECLARE_bool(use_fast_math);

namespace phi {
namespace fusion {

#ifndef PADDLE_WITH_HIP

template <typename T,
          typename Functor,
          int VecSize,
          typename LoadFunc,
          typename StoreFunc>
__global__ void ActFFNGlu(const T *bias,
                          Functor act_functor,
                          const int token_num,
                          const int hid_dim,
                          const int elem_num,
                          LoadFunc load_func,
                          StoreFunc store_func) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  LoadT src_vec1;
  LoadT src_vec2;
  LoadT bias_vec1;
  LoadT bias_vec2;
  const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = global_tid * VecSize; i < elem_num;
       i += gridDim.x * blockDim.x * VecSize) {
    int bi = i / hid_dim;
    int idx = i % hid_dim;
    load_func.template load<VecSize>(&src_vec1, bi * hid_dim * 2 + idx);
    load_func.template load<VecSize>(&src_vec2, bi * hid_dim * 2 + idx + hid_dim);
    if (bias) {
      phi::Load<T, VecSize>(&bias[idx], &bias_vec1);
      phi::Load<T, VecSize>(&bias[idx + hid_dim], &bias_vec2);
    }
#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      if (bias) {
        src_vec1[j] += bias_vec1[j];
        src_vec2[j] += bias_vec2[j];
      }
      src_vec1[j] = act_functor(src_vec1[j]);
      src_vec1[j] *= src_vec2[j];
    }
    store_func.template store<VecSize>(src_vec1, bi * hid_dim + idx);
  }
}

template <typename T,
          typename Context,
          typename Functor,
          typename LoadFunc,
          typename StoreFunc,
          typename LoadT = T>
void LaunchActFFNGlu(const Context &dev_ctx,
                     const T *bias,
                     const int token_num,
                     const int hid_dim,
                     LoadFunc load_func,
                     StoreFunc store_func) {
  constexpr int VecSize = 16;
  constexpr int PackSize = VecSize / sizeof(LoadT);
  const int elem_cnt = token_num * hid_dim;
  const int blocksize = 128;
  int grid_size = 1;
  Functor functor;
  switch (hid_dim % PackSize) {
    case 0:
      GetNumBlocks(elem_cnt / PackSize, &grid_size);
      ActFFNGlu<T, Functor, PackSize>
          <<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
              bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func);
      break;
    default:
      GetNumBlocks(elem_cnt, &grid_size);
      ActFFNGlu<T, Functor, 1><<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
          bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func);
      break;
  }
}

template <typename T,
          typename Functor,
          int VecSize,
          typename LoadFunc,
          typename StoreFunc>
__global__ void BiasAct(const T *bias,
                        Functor act_functor,
                        const int rows,
                        const int cols,
                        const int elem_num,
                        LoadFunc load_func,
                        StoreFunc store_func) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  LoadT src_vec;
  LoadT bias_vec;

// Zero Initialize BiasVec.
#pragma unroll
  for (int unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
    bias_vec[unroll_idx] = 0;
  }

  const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = global_tid * VecSize; i < elem_num;
       i += gridDim.x * blockDim.x * VecSize) {
    int row_idx = i / cols;
    int col_idx = i % cols;
    int linear_idx = row_idx * cols + col_idx;
    load_func.template load<VecSize>(&src_vec, linear_idx);
    if (bias) {
      phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
    }
#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      if (bias) {
        src_vec[j] += bias_vec[j];
      }
      src_vec[j] = act_functor(src_vec[j]);
    }
    store_func.template store<VecSize>(src_vec, linear_idx);
  }
}

template <typename T,
          typename Context,
          typename Functor,
          typename LoadFunc,
          typename StoreFunc,
          typename LoadT = T>
void LaunchBiasAct(const Context &dev_ctx,
                   const T *bias,
                   const int token_num,
                   const int hid_dim,
                   LoadFunc load_func,
                   StoreFunc store_func) {
  constexpr int VecSize = 16;
  constexpr int PackSize = VecSize / sizeof(LoadT);
  const int elem_cnt = token_num * hid_dim;
  const int blocksize = 128;
  int grid_size = 1;
  Functor functor;
  switch (hid_dim % PackSize) {
    case 0:
      GetNumBlocks(elem_cnt / PackSize, &grid_size);
      BiasAct<T, Functor, PackSize><<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
          bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func);
      break;
    default:
      GetNumBlocks(elem_cnt, &grid_size);
      BiasAct<T, Functor, 1><<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
          bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func);
      break;
  }
}

template <typename T,
          typename Context,
          typename LoadFunc,
          typename StoreFunc,
          typename LoadT = T>
void ComputeImpl(const Context &dev_ctx,
                 const T *bias_data,
                 const std::string &act_method,
                 int rows,
                 int cols,
                 LoadFunc load_func,
                 StoreFunc store_func) {
  if (act_method == "geglu") {
    // Note(Zhengzekang): For GLU structure, we need divide the cols by 2.
    VLOG(8) << "Doing geglu";
    LaunchActFFNGlu<T, Context, GeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
        dev_ctx, bias_data, rows, cols / 2, load_func, store_func);
  } else if (act_method == "swiglu") {
    VLOG(8) << "Doing swiglu";
    LaunchActFFNGlu<T, Context, CudaSwishFunctor<T>, LoadFunc, StoreFunc, LoadT>(
        dev_ctx, bias_data, rows, cols / 2, load_func, store_func);
  } else if (act_method == "gelu") {
    if (FLAGS_use_fast_math) {
      VLOG(8) << "Doing Fast GELU";
      LaunchBiasAct<T, Context, FastGeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
          dev_ctx, bias_data, rows, cols, load_func, store_func);
    } else {
      VLOG(8) << "Doing GELU";
      LaunchBiasAct<T, Context, GeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
          dev_ctx, bias_data, rows, cols, load_func, store_func);
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Currently Only Support GeGLU, SwiGLU, GeLU"));
  }
}

template <typename T, typename Context>
void DispatchComputeImpl(const Context &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor *bias,
                         const DenseTensor *dequant_scales,
                         const std::string &act_method,
                         int rows,
                         int cols,
                         const float quant_scale,
                         const int quant_round_type,
                         const float quant_max_bound,
                         const float quant_min_bound,
                         DenseTensor *out) {
  const T *bias_data = bias == nullptr ? nullptr : bias->data<T>();
  if (dequant_scales != nullptr && quant_scale > 0) {
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    QuantStore<T> store_func(dev_ctx.template Alloc<int8_t>(out),
                             quant_round_type, quant_scale, quant_max_bound,
                             quant_min_bound);
    ComputeImpl<T, Context, DequantLoad<T>, QuantStore<T>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales == nullptr && quant_scale > 0) {
    Load<T> load_func(x.data<T>());
    QuantStore<T> store_func(dev_ctx.template Alloc<int8_t>(out),
                             quant_round_type, quant_scale, quant_max_bound,
                             quant_min_bound);
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales != nullptr && quant_scale <= 0) {
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    Store<T> store_func(dev_ctx.template Alloc<T>(out));
    ComputeImpl<T, Context, DequantLoad<T>, Store<T>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else {
    Load<T> load_func(x.data<T>());
    Store<T> store_func(dev_ctx.template Alloc<T>(out));
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  }
}

template <typename T, typename Context>
void DispatchComputeImpl(const Context &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor *bias,
                         const DenseTensor *dequant_scales,
                         const DenseTensor *shift,
                         const DenseTensor *smooth,
                         const std::string &act_method,
                         int rows,
                         int cols,
                         const float quant_scale,
                         const int quant_round_type,
                         const float quant_max_bound,
                         const float quant_min_bound,
                         DenseTensor *out) {
  bool use_glu = (act_method == "geglu" || act_method == "swiglu");
  const T *bias_data = bias == nullptr ? nullptr : bias->data<T>();

  if (dequant_scales != nullptr && quant_scale > 0) {
    int8_t *out_data = dev_ctx.template Alloc<int8_t>(out);
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    QuantStore<T, true> store_func(dev_ctx.template Alloc<int8_t>(out),
                                   shift->data<T>(), smooth->data<T>(),
                                   use_glu ? cols / 2 : cols, quant_round_type,
                                   quant_scale, quant_max_bound, quant_min_bound);
    ComputeImpl<T, Context, DequantLoad<T>, QuantStore<T, true>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales == nullptr && quant_scale > 0) {
    Load<T> load_func(x.data<T>());
    QuantStore<T, true> store_func(dev_ctx.template Alloc<int8_t>(out),
                                   shift->data<T>(), smooth->data<T>(),
                                   use_glu ? cols / 2 : cols, quant_round_type,
                                   quant_scale, quant_max_bound, quant_min_bound);
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales != nullptr && quant_scale <= 0) {
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    Store<T, true> store_func(dev_ctx.template Alloc<T>(out), shift->data<T>(),
                              smooth->data<T>(), use_glu ? cols / 2 : cols);
    ComputeImpl<T, Context, DequantLoad<T>, Store<T, true>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else {
    Load<T> load_func(x.data<T>());
    Store<T, true> store_func(dev_ctx.template Alloc<T>(out), shift->data<T>(),
                              smooth->data<T>(), use_glu ? cols / 2 : cols);
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  }
}

struct NormalVersion {};
struct UnusedVersion {};

template <typename T>
struct DispatchDtypeTrait {
  using FuncVersion = NormalVersion;
};

template <>
struct DispatchDtypeTrait<int32_t> {
  using FuncVersion = UnusedVersion;
};

template <typename T, typename Context>
void DispatchWithDtype(const Context &dev_ctx,
                       const DenseTensor &x,
                       const paddle::optional<DenseTensor> &bias,
                       const paddle::optional<DenseTensor> &dequant_scales,
                       const paddle::optional<DenseTensor> &shift,
                       const paddle::optional<DenseTensor> &smooth,
                       const std::string &act_method,
                       int rows,
                       int cols,
                       float quant_scale,
                       int quant_round_type,
                       float quant_max_bound,
                       float quant_min_bound,
                       DenseTensor *out,
                       NormalVersion) {
  auto *bias_p = bias.get_ptr();
  auto *dequant_scales_p = dequant_scales.get_ptr();
  auto *shift_p = shift.get_ptr();
  auto *smooth_p = smooth.get_ptr();
  if (dequant_scales_p != nullptr) {
    if (shift_p != nullptr) {
      DispatchComputeImpl<T>(dev_ctx, x, bias_p, dequant_scales_p, shift_p,
                             smooth_p, act_method, rows, cols, quant_scale,
                             quant_round_type, quant_max_bound,
                             quant_min_bound, out);
    } else {
      DispatchComputeImpl<T>(dev_ctx, x, bias_p, dequant_scales_p, act_method,
                             rows, cols, quant_scale, quant_round_type,
                             quant_max_bound, quant_min_bound, out);
    }
  } else {
    const T *bias_data = bias_p == nullptr ? nullptr : bias_p->data<T>();
    Load<T> load_func(x.data<T>());
    Store<T> store_func(dev_ctx.template Alloc<T>(out));
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  }
}

// (not use) only for registering int32_t
template <typename T, typename Context>
void DispatchWithDtype(const Context &dev_ctx,
                       const DenseTensor &x,
                       const paddle::optional<DenseTensor> &bias,
                       const paddle::optional<DenseTensor> &dequant_scales,
                       const paddle::optional<DenseTensor> &shift,
                       const paddle::optional<DenseTensor> &smooth,
                       const std::string &act_method,
                       int rows,
                       int cols,
                       float quant_scale,
                       int quant_round_type,
                       float quant_max_bound,
                       float quant_min_bound,
                       DenseTensor *out,
                       UnusedVersion) {}

#endif

template <typename T, typename Context>
void FusedBiasActKernel(const Context &dev_ctx,
                        const DenseTensor &x,
                        const paddle::optional<DenseTensor> &bias,
                        const paddle::optional<DenseTensor> &dequant_scales,
                        const paddle::optional<DenseTensor> &shift,
                        const paddle::optional<DenseTensor> &smooth,
                        const std::string &act_method,
                        const std::string &compute_dtype,
                        int rows,
                        int cols,
                        float quant_scale,
                        int quant_round_type,
                        float quant_max_bound,
                        float quant_min_bound,
                        DenseTensor *out) {
#ifndef PADDLE_WITH_HIP
  if (x.dtype() == phi::DataType::INT32) {
    if (compute_dtype == "bf16") {
      DispatchWithDtype<phi::dtype::bfloat16, Context>(
          dev_ctx, x, bias, dequant_scales, shift, smooth, act_method, rows,
          cols, quant_scale, quant_round_type, quant_max_bound,
          quant_min_bound, out,
          typename DispatchDtypeTrait<phi::dtype::bfloat16>::FuncVersion{});
    } else if (compute_dtype == "fp16") {
      DispatchWithDtype<phi::dtype::float16, Context>(
          dev_ctx, x, bias, dequant_scales, shift, smooth, act_method, rows,
          cols, quant_scale, quant_round_type, quant_max_bound,
          quant_min_bound, out,
          typename DispatchDtypeTrait<phi::dtype::float16>::FuncVersion{});
    } else if (compute_dtype == "fp32") {
      DispatchWithDtype<float, Context>(
          dev_ctx, x, bias, dequant_scales, shift, smooth, act_method, rows,
          cols, quant_scale, quant_round_type, quant_max_bound,
          quant_min_bound, out,
          typename DispatchDtypeTrait<float>::FuncVersion{});
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "In the case of quantization enabled with Input(x) INT32, "
          "Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
          "but get compute_dtype (%s)",
          compute_dtype));
    }
  } else {
    DispatchWithDtype<T, Context>(
        dev_ctx, x, bias, dequant_scales, shift, smooth, act_method, rows,
        cols, quant_scale, quant_round_type, quant_max_bound, quant_min_bound,
        out, typename DispatchDtypeTrait<T>::FuncVersion{});
  }
#endif
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(fused_bias_act,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedBiasActKernel,
                   float,
                   phi::dtype::bfloat16,
                   phi::dtype::float16,
                   int32_t) {}
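Stripped of the vectorized loads and stores, the two kernels compute simple elementwise math: BiasAct applies act(x + bias) over a [rows, cols] tensor, while ActFFNGlu splits each row in half and multiplies the activated first half by the second half. A NumPy sketch of the GLU path (illustration only; it matches the geglu/swiglu baselines in the test file below):

import numpy as np

def ffn_glu_reference(x, bias, act):
    # x: [token_num, 2 * hid_dim]; bias: [2 * hid_dim]; act: elementwise callable
    z = x + (bias if bias is not None else 0)
    hid_dim = z.shape[1] // 2
    return act(z[:, :hid_dim]) * z[:, hid_dim:]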
paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h
0 → 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <string>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/funcs/load_store_util.h"
#include "paddle/phi/kernels/gpu/gelu_funcs.h"
#endif

// for windows build
#define M_SQRT1_2 0.70710678118654752440

namespace phi {
namespace fusion {

#ifndef PADDLE_WITH_HIP

template <typename T>
struct GeluComputeType;

template <>
struct GeluComputeType<phi::dtype::bfloat16> {
  using Type = float;
};

template <>
struct GeluComputeType<phi::dtype::float16> {
  using Type = float;
};

template <>
struct GeluComputeType<float> {
  using Type = float;
};

template <typename T>
using GeluType = typename GeluComputeType<T>::Type;

using phi::funcs::DequantLoad;
using phi::funcs::Load;
using phi::funcs::QuantStore;
using phi::funcs::Store;

template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;
  using AttrPair = std::vector<std::pair<const char *, float *>>;

  AttrPair GetAttrs() { return AttrPair(); }
};

// For windows build
template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta = 1.0;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // swish(x) = x / (1 + exp(-beta * x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    return static_cast<T>(x / (one + exp(-b * x)));
  }
};

// TODO(lzc): transfer to phi::funcs
template <typename T>
struct GeluFunctor {
  inline __host__ __device__ T operator()(const T x) const {
    using U = GeluType<T>;
    const U casted_x = static_cast<U>(x);
    const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
    const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
    return static_cast<T>(out);
  }
};

template <typename T>
struct FastGeluFunctor {
  inline __device__ T operator()(const T x) const {
    return phi::GeluFwd<T, true>(x);
  }
};

inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
  constexpr int kBlockSize = 128;
  constexpr int kNumWaves = 16;

  const int device_id = phi::backends::gpu::GetCurrentDeviceId();
  const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id);
  const int max_thread_per_multiprocessor =
      phi::backends::gpu::GetGPUMultiProcessors(device_id);

  *num_blocks = std::max<int>(
      1,
      std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
                        sm_count * max_thread_per_multiprocessor / kBlockSize *
                            kNumWaves));
  return cudaSuccess;
}
#endif

}  // namespace fusion
}  // namespace phi
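For reference, the scalar formulas implemented by GeluFunctor and CudaSwishFunctor (with beta = 1.0) correspond to the NumPy/SciPy helpers used as baselines in the test below:

import numpy as np
from scipy.special import erf, expit

M_SQRT1_2 = 0.70710678118654752440

def gelu_ref(x):
    return 0.5 * x * (1.0 + erf(x * M_SQRT1_2))

def swish_ref(x, beta=1.0):
    # x / (1 + exp(-beta * x))
    return x * expit(beta * x)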
test/legacy_test/test_fused_bias_act_op.py
0 → 100644
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
from eager_op_test import convert_float_to_uint16
from scipy.special import erf, expit

import paddle
import paddle.nn.functional as F
from paddle.fluid import core


def round_type_1_process(val):
    dtype = type(val)
    if val >= 0:
        return dtype(np.floor(val + 0.5))
    return dtype(np.ceil(val - 0.5))


# rounding to nearest ties away from zero
round_type_1 = np.vectorize(round_type_1_process)

M_SQRT1_2 = 0.70710678118654752440


def gelu(x):
    out = (
        0.5 * x.astype('float32') * (1.0 + erf(x.astype('float32') * M_SQRT1_2))
    )
    return out.astype(x.dtype)


def swish(x):
    out = x.astype('float32') * expit(x.astype('float32'))
    return out.astype(x.dtype)


def fake_dequant(values, dequant_scales):
    out = values * dequant_scales.astype('float32')
    return out


def fake_quant(
    values, shift, smooth, quant_sacle, max_bound, min_bound, round_type
):
    values_tmp = (values + shift) * smooth
    values_tmp = max_bound * quant_sacle * values_tmp

    if round_type == 0:
        values_tmp = np.rint(values_tmp)
    elif round_type == 1:
        values_tmp = round_type_1(values_tmp)

    return np.clip(values_tmp, min_bound, max_bound).astype(np.int8)


def fused_act_bias_wrapper(
    x,
    bias=None,
    dequant_scales=None,
    shift=None,
    smooth=None,
    act_method='gelu',
    compute_dtype='default',
    rows=0,
    cols=0,
    quant_scale=-1,
    quant_round_type=0,
    quant_max_bound=0,
    quant_min_bound=0,
):
    return paddle._C_ops.fused_bias_act(
        x,
        bias,
        dequant_scales,
        shift,
        smooth,
        act_method,
        compute_dtype,
        rows,
        cols,
        quant_scale,
        quant_round_type,
        quant_max_bound,
        quant_min_bound,
    )


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFusedBiasActOp(unittest.TestCase):
    def setUp(self):
        paddle.seed(2017)
        np.random.seed(2017)

        self.op_type = "fused_bias_act"
        self.rtol = 1e-5
        self.atol = 1e-3

        self.rows = 20
        self.cols = 512

        self.dtype = 'float32'
        self.act_method = 'gelu'
        self.compute_dtype = 'default'

        self.use_glu = False

        self.init_test_case()
        self.generate_inputs()

    def init_test_case(self):
        pass

    def generate_inputs(self):
        self.x = (np.random.rand(self.rows, self.cols) * 16).astype(self.dtype)
        self.bias = np.random.rand(self.cols).astype(self.dtype)

    def compute_baseline_output(self):
        out = gelu(self.x + self.bias).astype(self.dtype)
        return out

    def compute_paddle_output(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = paddle.to_tensor(self.x)
        bias = paddle.to_tensor(self.bias)
        return fused_act_bias_wrapper(
            x=x, bias=bias, rows=self.rows, cols=self.cols,
            act_method=self.act_method, compute_dtype=self.compute_dtype,
        )

    def test_check_output(self):
        final_out_ref = self.compute_baseline_output()
        final_out = self.compute_paddle_output()
        np.testing.assert_allclose(
            final_out_ref, final_out, rtol=self.rtol, atol=self.atol
        )


class TestBaseFP16(TestFusedBiasActOp):
    def init_test_case(self):
        self.dtype = np.float16
        self.act_method = 'gelu'


class TestWithComTypeFP32(TestFusedBiasActOp):
    def init_test_case(self):
        self.dtype = 'float32'
        self.act_method = 'gelu'
        self.compute_dtype = 'fp32'


class TestWithComTypeFP16(TestFusedBiasActOp):
    def init_test_case(self):
        self.dtype = 'float16'
        self.act_method = 'gelu'
        self.compute_dtype = 'fp16'


class TestFastGeluFP16(TestFusedBiasActOp):
    def use_fast_math(self, enabled):
        paddle.set_flags({'FLAGS_use_fast_math': enabled})

    def init_test_case(self):
        self.dtype = np.float16
        self.act_method = 'gelu'

    def compute_baseline_output(self):
        out = F.gelu(
            paddle.to_tensor(self.x) + paddle.to_tensor(self.bias),
            approximate=True,
        )
        return out

    def compute_paddle_output(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = paddle.to_tensor(self.x)
        bias = paddle.to_tensor(self.bias)

        self.use_fast_math(True)
        out = fused_act_bias_wrapper(
            x=x, bias=bias, rows=self.rows, cols=self.cols,
            act_method=self.act_method,
        )
        self.use_fast_math(False)
        return out


class TestGegluFP16(TestFusedBiasActOp):
    def init_test_case(self):
        self.dtype = np.float16
        self.act_method = 'geglu'

    def compute_baseline_output(self):
        res_tmp = (self.x + self.bias).astype(self.dtype)
        res_tmp_head = res_tmp[:, : self.cols // 2]
        res_tmp_tail = res_tmp[:, self.cols // 2 :]
        res_tmp_head_act = gelu(res_tmp_head)
        out = res_tmp_head_act * res_tmp_tail
        return out


class TestSwigluFP16(TestFusedBiasActOp):
    def init_test_case(self):
        self.dtype = np.float16
        self.act_method = 'swiglu'

    def compute_baseline_output(self):
        res_tmp = (self.x + self.bias).astype(self.dtype)
        res_tmp_head = res_tmp[:, : self.cols // 2]
        res_tmp_tail = res_tmp[:, self.cols // 2 :]
        res_tmp_head_act = swish(res_tmp_head)
        out = res_tmp_head_act * res_tmp_tail
        return out


class TestQuantFP32(TestFusedBiasActOp):
    def init_test_case(self):
        self.atol = 1

        self.dtype = 'float32'
        self.compute_dtype = 'fp32'
        self.quant_scale = 0.5
        self.quant_round_type = 1
        self.quant_max_bound = 127.0
        self.quant_min_bound = -127.0

    def generate_inputs(self):
        self.x = np.random.randint(
            low=-16, high=16, size=(self.rows, self.cols)
        ).astype('int32')
        self.bias = np.random.rand(self.cols).astype(self.dtype)
        self.dequant_scales = np.random.rand(self.cols).astype('float32')

        quant_params_cols = self.cols // 2 if self.use_glu else self.cols
        self.shift = np.zeros(quant_params_cols).astype(self.dtype)
        self.smooth = np.ones(quant_params_cols).astype(self.dtype)

    def compute_baseline_output(self):
        input_dequanted = fake_dequant(self.x, self.dequant_scales)
        output_tmp = gelu(input_dequanted + self.bias).astype(self.dtype)
        out = fake_quant(output_tmp, self.shift, self.smooth, self.quant_scale,
                         self.quant_max_bound, self.quant_min_bound,
                         self.quant_round_type)
        return out

    def compute_paddle_output(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = paddle.to_tensor(self.x)
        bias = paddle.to_tensor(self.bias)
        dequant_scales = paddle.to_tensor(self.dequant_scales)
        shift = paddle.to_tensor(self.shift)
        smooth = paddle.to_tensor(self.smooth)

        out = fused_act_bias_wrapper(
            x=x, bias=bias, dequant_scales=dequant_scales, shift=shift,
            smooth=smooth, act_method=self.act_method,
            compute_dtype=self.compute_dtype, rows=self.rows, cols=self.cols,
            quant_scale=self.quant_scale,
            quant_round_type=self.quant_round_type,
            quant_max_bound=self.quant_max_bound,
            quant_min_bound=self.quant_min_bound,
        )
        return out


class TestDequantFP32(TestQuantFP32):
    def init_test_case(self):
        self.rows = 10
        self.cols = 10
        self.atol = 1

        self.dtype = 'float32'
        self.compute_dtype = 'fp32'
        self.quant_scale = 0.5
        self.quant_round_type = 1
        self.quant_max_bound = 127.0
        self.quant_min_bound = -127.0

    def generate_inputs(self):
        self.x = np.random.randint(
            low=-16, high=16, size=(self.rows, self.cols)
        ).astype('int32')
        self.bias = np.random.rand(self.cols).astype(self.dtype)
        self.dequant_scales = np.ones(self.cols).astype('float32')

    def compute_baseline_output(self):
        input_dequanted = fake_dequant(self.x, self.dequant_scales)
        out = gelu(input_dequanted + self.bias).astype(self.dtype)
        return out

    def compute_paddle_output(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = paddle.to_tensor(self.x)
        bias = paddle.to_tensor(self.bias)
        dequant_scales = paddle.to_tensor(self.dequant_scales)

        out = fused_act_bias_wrapper(
            x=x, bias=bias, dequant_scales=dequant_scales,
            act_method=self.act_method, compute_dtype=self.compute_dtype,
            rows=self.rows, cols=self.cols,
        )
        return out


class TestQuantFP16(TestQuantFP32):
    def init_test_case(self):
        self.atol = 1

        self.dtype = 'float16'
        self.compute_dtype = 'fp16'
        self.quant_scale = 0.5
        self.quant_round_type = 1
        self.quant_max_bound = 127.0
        self.quant_min_bound = -127.0


class TestDequantFP16(TestDequantFP32):
    def init_test_case(self):
        self.rows = 10
        self.cols = 10
        self.atol = 1

        self.dtype = 'float16'
        self.compute_dtype = 'fp16'
        self.quant_scale = 0.5
        self.quant_round_type = 1
        self.quant_max_bound = 127.0
        self.quant_min_bound = -127.0


class TestQuantGegluFP16(TestQuantFP32):
    def init_test_case(self):
        self.atol = 1

        self.dtype = 'float16'
        self.compute_dtype = 'fp16'
        self.act_method = 'geglu'
        self.quant_scale = 0.5
        self.quant_round_type = 1
        self.quant_max_bound = 127.0
        self.quant_min_bound = -127.0
        self.use_glu = True

    def compute_baseline_output(self):
        input_dequanted = fake_dequant(self.x, self.dequant_scales)
        tmp = (input_dequanted + self.bias).astype('float32')
        tmp_head = tmp[:, : self.cols // 2]
        tmp_tail = tmp[:, self.cols // 2 :]
        out_tmp = gelu(tmp_head).astype('float32') * tmp_tail
        out = fake_quant(out_tmp, self.shift, self.smooth, self.quant_scale,
                         self.quant_max_bound, self.quant_min_bound,
                         self.quant_round_type)
        return out


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestFusedBiasActOpBF16(unittest.TestCase):
    def setUp(self):
        paddle.seed(2019)
        np.random.seed(2019)

        self.op_type = "fused_bias_act"
        self.rtol = 1e-3
        self.atol = 1e-3

        self.rows = 20
        self.cols = 512

        self.act_method = 'gelu'
        self.compute_dtype = 'default'

        self.init_test_case()
        self.generate_inputs()

    def init_test_case(self):
        pass

    def generate_inputs(self):
        self.x = np.random.rand(self.rows, self.cols).astype('float32') * 16
        self.bias = np.random.rand(self.cols).astype('float32')

    def compute_baseline_output(self):
        out = gelu(self.x.astype('float32') + self.bias)
        return convert_float_to_uint16(out)

    def compute_paddle_output(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = paddle.to_tensor(convert_float_to_uint16(self.x))
        bias = paddle.to_tensor(convert_float_to_uint16(self.bias))

        out = fused_act_bias_wrapper(
            x=x, bias=bias, act_method=self.act_method,
            compute_dtype=self.compute_dtype, rows=self.rows, cols=self.cols,
        )
        return out

    def test_check_output(self):
        final_out_ref = self.compute_baseline_output()
        final_out = self.compute_paddle_output()
        np.testing.assert_allclose(
            final_out_ref, final_out, rtol=self.rtol, atol=self.atol
        )


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestWithComTypeBF16(unittest.TestCase):
    def init_test_case(self):
        self.act_method = 'geglu'
        self.compute_dtype = 'bf16'


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestGegluBF16(TestFusedBiasActOpBF16):
    def init_test_case(self):
        self.act_method = 'geglu'
        self.compute_dtype = 'default'

    def compute_baseline_output(self):
        res_tmp = self.x + self.bias
        res_tmp_head = res_tmp[:, : self.cols // 2]
        res_tmp_tail = res_tmp[:, self.cols // 2 :]
        res_tmp_head_act = gelu(res_tmp_head)
        out = res_tmp_head_act * res_tmp_tail
        return convert_float_to_uint16(out)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestSwigluBF16(TestFusedBiasActOpBF16):
    def init_test_case(self):
        self.act_method = 'swiglu'
        self.compute_dtype = 'default'

    def compute_baseline_output(self):
        res_tmp = self.x + self.bias
        res_tmp_head = res_tmp[:, : self.cols // 2]
        res_tmp_tail = res_tmp[:, self.cols // 2 :]
        res_tmp_head_act = swish(res_tmp_head)
        out = res_tmp_head_act * res_tmp_tail
        return convert_float_to_uint16(out)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestQuantBF16(TestFusedBiasActOpBF16):
    def init_test_case(self):
        self.atol = 1
        self.compute_dtype = 'bf16'
        self.act_method = 'gelu'
        self.quant_scale = 0.5
        self.quant_round_type = 1
        self.quant_max_bound = 127.0
        self.quant_min_bound = -127.0
        self.use_glu = False

    def generate_inputs(self):
        self.x = np.random.randint(
            low=-1000, high=1000, size=(self.rows, self.cols)
        ).astype('int32')
        self.bias = np.zeros(self.cols).astype('float32')
        self.dequant_scales = np.ones(self.cols).astype('float32')

        quant_params_cols = self.cols // 2 if self.use_glu else self.cols
        self.shift = np.zeros(quant_params_cols).astype('float32')
        self.smooth = np.ones(quant_params_cols).astype('float32')

    def compute_baseline_output(self):
        input_dequanted = fake_dequant(
            self.x.astype('float32'), self.dequant_scales
        )
        output_tmp = gelu(input_dequanted + self.bias)
        out = fake_quant(output_tmp, self.shift, self.smooth, self.quant_scale,
                         self.quant_max_bound, self.quant_min_bound,
                         self.quant_round_type)
        return out

    def compute_paddle_output(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = paddle.to_tensor(self.x)
        bias = paddle.to_tensor(convert_float_to_uint16(self.bias))
        dequant_scales = paddle.to_tensor(self.dequant_scales)
        shift = paddle.to_tensor(convert_float_to_uint16(self.shift))
        smooth = paddle.to_tensor(convert_float_to_uint16(self.smooth))

        out = fused_act_bias_wrapper(
            x=x, bias=bias, dequant_scales=dequant_scales, shift=shift,
            smooth=smooth, act_method=self.act_method,
            compute_dtype=self.compute_dtype, rows=self.rows, cols=self.cols,
            quant_scale=self.quant_scale,
            quant_round_type=self.quant_round_type,
            quant_max_bound=self.quant_max_bound,
            quant_min_bound=self.quant_min_bound,
        )
        return out


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestQuantGegluBF16(TestQuantBF16):
    def init_test_case(self):
        self.atol = 1
        self.compute_dtype = 'bf16'
        self.act_method = 'geglu'
        self.quant_scale = 0.5
        self.quant_round_type = 1
        self.quant_max_bound = 127.0
        self.quant_min_bound = -127.0
        self.use_glu = True

    def compute_baseline_output(self):
        input_dequanted = fake_dequant(
            self.x.astype('float32'), self.dequant_scales
        )
        tmp = (input_dequanted + self.bias).astype('float32')
        tmp_head = tmp[:, : self.cols // 2]
        tmp_tail = tmp[:, self.cols // 2 :]
        out_tmp = gelu(tmp_head).astype('float32') * tmp_tail
        out = fake_quant(out_tmp, self.shift, self.smooth, self.quant_scale,
                         self.quant_max_bound, self.quant_min_bound,
                         self.quant_round_type)
        return out


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestQuantSwigluBF16(TestQuantBF16):
    def init_test_case(self):
        self.atol = 1
        self.compute_dtype = 'bf16'
        self.act_method = 'swiglu'
        self.quant_scale = 0.5
        self.quant_round_type = 1
        self.quant_max_bound = 127.0
        self.quant_min_bound = -127.0
        self.use_glu = True

    def compute_baseline_output(self):
        input_dequanted = fake_dequant(
            self.x.astype('float32'), self.dequant_scales
        )
        tmp = (input_dequanted + self.bias).astype('float32')
        tmp_head = tmp[:, : self.cols // 2]
        tmp_tail = tmp[:, self.cols // 2 :]
        out_tmp = swish(tmp_head).astype('float32') * tmp_tail
        out = fake_quant(out_tmp, self.shift, self.smooth, self.quant_scale,
                         self.quant_max_bound, self.quant_min_bound,
                         self.quant_round_type)
        return out


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestAssert(unittest.TestCase):
    def setUp(self):
        self.rows = 20
        self.cols = 512
        self.dtype = 'float32'
        self.act_method = 'gelu'

    def test_assert_case1(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = np.random.randint(
            low=-16, high=16, size=(self.rows, self.cols)
        ).astype('int32')
        bias = np.random.rand(self.cols).astype(self.dtype)
        try:
            out = fused_act_bias_wrapper(
                x=paddle.to_tensor(x),
                bias=paddle.to_tensor(bias),
                rows=self.rows,
                cols=self.cols,
            )
        except ValueError as e:
            pass

    def test_assert_case2(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = np.random.randint(
            low=-16, high=16, size=(self.rows, self.cols)
        ).astype('int32')
        bias = np.random.rand(self.cols).astype(self.dtype)
        try:
            out = fused_act_bias_wrapper(
                x=paddle.to_tensor(x),
                bias=paddle.to_tensor(bias),
                rows=self.rows,
                cols=self.cols,
                compute_dtype='fp16',
            )
        except ValueError as e:
            pass

    def test_assert_case3(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = np.random.randint(
            low=-16, high=16, size=(self.rows, self.cols)
        ).astype('int32')
        bias = np.random.rand(self.cols).astype(self.dtype)
        act_method = "error_type"
        try:
            out = fused_act_bias_wrapper(
                x=paddle.to_tensor(x),
                bias=paddle.to_tensor(bias),
                rows=self.rows,
                cols=self.cols,
                compute_dtype='fp16',
                act_method=act_method,
            )
        except ValueError as e:
            pass


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestWithoutBias(unittest.TestCase):
    def setUp(self):
        paddle.seed(2017)
        np.random.seed(2017)

        self.op_type = "fused_bias_act"
        self.rtol = 1e-5
        self.atol = 1e-3

        self.rows = 20
        self.cols = 512

        self.dtype = 'float32'
        self.act_method = 'gelu'

        self.use_glu = False

        self.init_test_case()
        self.generate_inputs()

    def init_test_case(self):
        pass

    def generate_inputs(self):
        self.x = (np.random.rand(self.rows, self.cols) * 16).astype(self.dtype)
        # self.bias = np.random.rand(self.cols).astype(self.dtype)

    def compute_baseline_output(self):
        out = gelu(self.x).astype(self.dtype)
        return out

    def compute_paddle_output(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x = paddle.to_tensor(self.x)
        return fused_act_bias_wrapper(
            x=x, bias=None, rows=self.rows, cols=self.cols,
            act_method=self.act_method,
        )

    def test_check_output(self):
        final_out_ref = self.compute_baseline_output()
        final_out = self.compute_paddle_output()
        np.testing.assert_allclose(
            final_out_ref, final_out, rtol=self.rtol, atol=self.atol
        )


if __name__ == '__main__':
    unittest.main()