Unverified commit 0a4d1999, authored by F freeliuzc, committed by GitHub

[inference] Add FusedBiasActKernel (#55301)

* add init value for CudaSwishFunctor

* add new phi kernel fusedBiasActKernel
Parent d12837d3
......@@ -63,6 +63,17 @@
data_type : x
optional : bias, x_max
- op : fused_bias_act
args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_dtype = "default", int rows = -1, int cols = -1, float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0)
output : Tensor(out)
infer_meta :
func: FusedBiasActInferMeta
kernel :
func : fused_bias_act
data_type : x
optional : bias, dequant_scales, shift, smooth
support_dygraph_mode : true
- op : fused_dropout_add
args : (Tensor x, Tensor y, Tensor seed_tensor, Scalar p, bool is_test, str mode, int seed = 0, bool fix_seed = false)
optional : seed_tensor
......
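Note: the new op can be exercised from dygraph mode through the generated C++ binding. Below is a minimal sketch; it mirrors the fused_act_bias_wrapper helper in the unit test added by this commit, leaves every optional input as None, and assumes a CUDA build, since the kernel is registered for GPU only.

import paddle

x = paddle.rand([20, 512], dtype='float32')
bias = paddle.rand([512], dtype='float32')
out = paddle._C_ops.fused_bias_act(
    x, bias, None, None, None,  # x, bias, dequant_scales, shift, smooth
    'gelu', 'default',          # act_method, compute_dtype
    20, 512,                    # rows, cols
    -1.0, 1, 127.0, -127.0,     # quant_scale, quant_round_type, quant_max_bound, quant_min_bound
)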
......@@ -1335,6 +1335,136 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
sequencenum->set_dtype(DataType::FLOAT32);
}
void FusedBiasActInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& dequant_scales,
const MetaTensor& shift,
const MetaTensor& smooth,
const std::string& act_method,
const std::string& compute_dtype,
int rows,
int cols,
float quant_scale,
int quant_round_type,
float quant_max_bound,
float quant_min_bound,
MetaTensor* out) {
auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(x_dims.size(),
2,
phi::errors::InvalidArgument(
"The rank of Input(x) must be 2, but received %s", x_dims));
auto token_num = x_dims[0];
auto dim = x_dims[1];
PADDLE_ENFORCE_GT(
rows, 0, phi::errors::InvalidArgument("Attr(rows) must be greater than 0"));
PADDLE_ENFORCE_GT(
cols, 0, phi::errors::InvalidArgument("Attr(cols) must be greater than 0"));
if (act_method == "geglu" || act_method == "swiglu") {
PADDLE_ENFORCE_EQ(
dim % 2,
0,
phi::errors::InvalidArgument(
"The seconde dimension of x must be even, but receive %d", dim));
dim /= 2;
out->set_dims(phi::make_ddim({token_num, dim}));
} else if (act_method == "gelu") {
out->set_dims(phi::make_ddim({token_num, dim}));
} else {
PADDLE_THROW(
errors::InvalidArgument("act_method must be geglu, swiglu or gelu, "
"but get act_method (%s)",
act_method));
}
auto FBADtypeCheck = [](const MetaTensor& check_tensor,
const std::string& tensor_name,
const std::string& compute_dtype) {
if (compute_dtype == "bf16") {
PADDLE_ENFORCE_EQ(
check_tensor.dtype(),
phi::DataType::BFLOAT16,
phi::errors::InvalidArgument(
"Input(%s) dtype must be the same with Attr(compute_dtype)",
tensor_name));
} else if (compute_dtype == "fp16") {
PADDLE_ENFORCE_EQ(
check_tensor.dtype(),
phi::DataType::FLOAT16,
phi::errors::InvalidArgument(
"Input(%s) dtype must be the same with Attr(compute_dtype)",
tensor_name));
} else if (compute_dtype == "fp32") {
PADDLE_ENFORCE_EQ(
check_tensor.dtype(),
phi::DataType::FLOAT32,
phi::errors::InvalidArgument(
"Input(%s) dtype must be the same with Attr(compute_dtype)",
tensor_name));
}
};
// When quantization is enabled (Input(x) is INT32), the dtype used for
// computation is determined by compute_dtype.
if (x.dtype() == phi::DataType::INT32) {
PADDLE_ENFORCE_NE(
compute_dtype,
"default",
phi::errors::InvalidArgument(
"If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
if (bias) {
FBADtypeCheck(bias, "bias", compute_dtype);
}
if (quant_scale > 0) {
out->set_dtype(phi::DataType::INT8);
} else {
if (compute_dtype == "bf16") {
out->set_dtype(phi::DataType::BFLOAT16);
} else if (compute_dtype == "fp16") {
out->set_dtype(phi::DataType::FLOAT16);
} else if (compute_dtype == "fp32") {
out->set_dtype(phi::DataType::FLOAT32);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"In the case of quantization enabled with Input(x) INT32, "
"Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
"but get compute_dtype (%s)",
compute_dtype));
}
}
} else {
// x.dtype() != phi::DataType::INT32
if (bias) {
if (compute_dtype != "default") {
FBADtypeCheck(bias, "bias", compute_dtype);
FBADtypeCheck(x, "x", compute_dtype);
} else {
PADDLE_ENFORCE_EQ(
x.dtype(),
bias.dtype(),
phi::errors::InvalidArgument("Input(x) and Input(bias) must be the "
"same dtype in this situation"));
}
} else {
// bias does not exist
if (compute_dtype != "default") {
FBADtypeCheck(x, "x", compute_dtype);
}
}
if (quant_scale > 0) {
out->set_dtype(phi::DataType::INT8);
} else {
out->set_dtype(x.dtype());
}
}
out->set_layout(x.layout());
}
void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
......
......@@ -279,6 +279,21 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
MetaTensor* sequencenum,
MetaTensor* out);
void FusedBiasActInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& dequant_scales,
const MetaTensor& shift,
const MetaTensor& smooth,
const std::string& act_method,
const std::string& compute_dtype,
int rows,
int cols,
float quant_scale,
int quant_round_type,
float quant_max_bound,
float quant_min_bound,
MetaTensor* out);
void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
......
......@@ -3923,7 +3923,7 @@ template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float beta = 1.0;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {
namespace funcs {
#ifndef PADDLE_WITH_HIP
template <typename T>
__device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
if (v > max) return max;
if (v < min) return min;
return v;
}
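// Quantize a floating-point value to OutType:
//   q = Clip(round(max_bound * scale * input), min_bound, max_bound),
// where round_type 0 uses rint (round half to even) and any other value
// uses round (round half away from zero).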
template <typename InType, typename OutType>
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
const float scale,
const int round_type,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * input;
if (round_type == 0) {
quant_value = static_cast<float>(rint(quant_value));
} else {
quant_value = static_cast<float>(round(quant_value));
}
return static_cast<OutType>(
ClipFunc<float>(quant_value, min_bound, max_bound));
}
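// The functors below abstract the element-wise I/O of the fused kernels:
// Load / DequantLoad read VecSize elements (DequantLoad additionally converts
// INT32 input to T using a per-column dequant scale), while Store / QuantStore
// write results back (the Smooth specializations first apply the per-column
// shift/smooth transform, and QuantStore quantizes to INT8).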
template <typename T>
struct Load {
explicit Load(const T *src) : src_(src) {}
template <int VecSize>
__device__ void load(phi::AlignedVector<T, VecSize> *dst, int idx) {
phi::Load<T, VecSize>(src_ + idx, dst);
}
const T *src_;
};
template <typename T, bool Smooth = false>
struct Store {
explicit Store(T *dst) : dst_(dst) {}
template <int VecSize>
__device__ void store(phi::AlignedVector<T, VecSize> &src, int idx) {
phi::Store<T, VecSize>(src, dst_ + idx);
}
T *dst_;
};
template <typename T>
struct Store<T, true> {
Store(T *dst, const T *shift, const T *smooth, const int cols)
: dst_(dst), shift_(shift), smooth_(smooth), cols_(cols) {}
template <int VecSize>
__device__ void store(phi::AlignedVector<T, VecSize> &src, int idx) {
using Vec = phi::AlignedVector<T, VecSize>;
Vec shift_vec;
Vec smooth_vec;
phi::Load<T, VecSize>(shift_ + idx % cols_, &shift_vec);
phi::Load<T, VecSize>(smooth_ + idx % cols_, &smooth_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src[i] = (src[i] + shift_vec[i]) * smooth_vec[i];
}
phi::Store<T, VecSize>(src, dst_ + idx);
}
T *dst_;
const T *shift_;
const T *smooth_;
const int cols_;
};
template <typename T>
struct DequantLoad {
DequantLoad(const int32_t *src, const float *dequant_scales, const int cols)
: src_(src), dequant_scales_(dequant_scales), cols_(cols) {}
template <int VecSize>
__device__ void load(phi::AlignedVector<T, VecSize> *dst, int idx) {
using SrcVec = phi::AlignedVector<int32_t, VecSize>;
using DstVec = phi::AlignedVector<T, VecSize>;
using ScaleVec = phi::AlignedVector<float, VecSize>;
SrcVec src_vec;
DstVec dst_vec;
ScaleVec scale_vec;
phi::Load<int32_t, VecSize>(src_ + idx, &src_vec);
phi::Load<float, VecSize>(dequant_scales_ + idx % cols_, &scale_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
dst_vec[i] =
static_cast<T>(static_cast<float>(src_vec[i]) * scale_vec[i]);
}
*dst = dst_vec;
}
const int32_t *src_;
const float *dequant_scales_;
const int cols_;
};
template <typename T, bool Smooth = false>
struct QuantStore {
QuantStore(int8_t *dst,
const int quant_round_type,
const float quant_scale,
const float quant_max_bound,
const float quant_min_bound)
: dst_(dst),
quant_round_type_(quant_round_type),
quant_scale_(quant_scale),
quant_max_bound_(quant_max_bound),
quant_min_bound_(quant_min_bound) {}
template <int VecSize>
__device__ void store(phi::AlignedVector<T, VecSize> &src, // NOLINT
int idx) { // NOLINT
using DstVec = phi::AlignedVector<int8_t, VecSize>;
DstVec dst_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
dst_vec[i] = QuantHelperFunc<float, int8_t>(static_cast<float>(src[i]),
quant_scale_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_);
}
phi::Store<int8_t, VecSize>(dst_vec, dst_ + idx);
}
int8_t *dst_;
const int quant_round_type_;
const float quant_scale_;
const float quant_max_bound_;
const float quant_min_bound_;
};
template <typename T>
struct QuantStore<T, true> {
QuantStore(int8_t *dst,
const T *shift,
const T *smooth,
const int cols,
const int quant_round_type,
const float quant_scale,
const float quant_max_bound,
const float quant_min_bound)
: dst_(dst),
shift_(shift),
smooth_(smooth),
cols_(cols),
quant_round_type_(quant_round_type),
quant_scale_(quant_scale),
quant_max_bound_(quant_max_bound),
quant_min_bound_(quant_min_bound) {}
template <int VecSize>
__device__ void store(phi::AlignedVector<T, VecSize> &src, // NOLINT
int idx) { // NOLINT
using DstVec = phi::AlignedVector<int8_t, VecSize>;
using Vec = phi::AlignedVector<T, VecSize>;
DstVec dst_vec;
Vec shift_vec;
Vec smooth_vec;
phi::Load<T, VecSize>(shift_ + idx % cols_, &shift_vec);
phi::Load<T, VecSize>(smooth_ + idx % cols_, &smooth_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src[i] = (src[i] + shift_vec[i]) * smooth_vec[i];
dst_vec[i] = QuantHelperFunc<float, int8_t>(static_cast<float>(src[i]),
quant_scale_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_);
}
phi::Store<int8_t, VecSize>(dst_vec, dst_ + idx);
}
int8_t *dst_;
const int quant_round_type_;
const float quant_scale_;
const float quant_max_bound_;
const float quant_min_bound_;
const T *shift_;
const T *smooth_;
const int cols_;
};
#endif
} // namespace funcs
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h"
PHI_DECLARE_bool(use_fast_math);
namespace phi {
namespace fusion {
#ifndef PADDLE_WITH_HIP
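// ActFFNGlu: each input row has 2 * hid_dim columns; the first half is passed
// through the activation functor and multiplied element-wise by the second
// half (GLU), producing hid_dim output columns per row. The optional bias is
// added before the activation.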
template <typename T,
typename Functor,
int VecSize,
typename LoadFunc,
typename StoreFunc>
__global__ void ActFFNGlu(const T *bias,
Functor act_functor,
const int token_num,
const int hid_dim,
const int elem_num,
LoadFunc load_func,
StoreFunc store_func) {
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec1;
LoadT src_vec2;
LoadT bias_vec1;
LoadT bias_vec2;
const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = global_tid * VecSize; i < elem_num;
i += gridDim.x * blockDim.x * VecSize) {
int bi = i / hid_dim;
int idx = i % hid_dim;
load_func.template load<VecSize>(&src_vec1, bi * hid_dim * 2 + idx);
load_func.template load<VecSize>(&src_vec2,
bi * hid_dim * 2 + idx + hid_dim);
if (bias) {
phi::Load<T, VecSize>(&bias[idx], &bias_vec1);
phi::Load<T, VecSize>(&bias[idx + hid_dim], &bias_vec2);
}
#pragma unroll
for (int j = 0; j < VecSize; j++) {
if (bias) {
src_vec1[j] += bias_vec1[j];
src_vec2[j] += bias_vec2[j];
}
src_vec1[j] = act_functor(src_vec1[j]);
src_vec1[j] *= src_vec2[j];
}
store_func.template store<VecSize>(src_vec1, bi * hid_dim + idx);
}
}
template <typename T,
typename Context,
typename Functor,
typename LoadFunc,
typename StoreFunc,
typename LoadT = T>
void LaunchActFFNGlu(const Context &dev_ctx,
const T *bias,
const int token_num,
const int hid_dim,
LoadFunc load_func,
StoreFunc store_func) {
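// Vectorize in 16-byte packs: PackSize is the number of LoadT elements per
// 128-bit access; fall back to element-wise processing (vector size 1) when
// hid_dim is not a multiple of PackSize.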
constexpr int VecSize = 16;
constexpr int PackSize = VecSize / sizeof(LoadT);
const int elem_cnt = token_num * hid_dim;
const int blocksize = 128;
int grid_size = 1;
Functor functor;
switch (hid_dim % PackSize) {
case 0:
GetNumBlocks(elem_cnt / PackSize, &grid_size);
ActFFNGlu<T, Functor, PackSize>
<<<grid_size, blocksize, 0, dev_ctx.stream()>>>(bias,
functor,
token_num,
hid_dim,
elem_cnt,
load_func,
store_func);
break;
default:
GetNumBlocks(elem_cnt, &grid_size);
ActFFNGlu<T, Functor, 1><<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func);
break;
}
}
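// BiasAct: adds the optional per-column bias to every element and applies the
// activation in place; the output layout matches the input.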
template <typename T,
typename Functor,
int VecSize,
typename LoadFunc,
typename StoreFunc>
__global__ void BiasAct(const T *bias,
Functor act_functor,
const int rows,
const int cols,
const int elem_num,
LoadFunc load_func,
StoreFunc store_func) {
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec;
LoadT bias_vec;
// Zero-initialize bias_vec.
#pragma unroll
for (int unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
bias_vec[unroll_idx] = 0;
}
const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = global_tid * VecSize; i < elem_num;
i += gridDim.x * blockDim.x * VecSize) {
int row_idx = i / cols;
int col_idx = i % cols;
int linear_idx = row_idx * cols + col_idx;
load_func.template load<VecSize>(&src_vec, linear_idx);
if (bias) {
phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
}
#pragma unroll
for (int j = 0; j < VecSize; j++) {
if (bias) {
src_vec[j] += bias_vec[j];
}
src_vec[j] = act_functor(src_vec[j]);
}
store_func.template store<VecSize>(src_vec, linear_idx);
}
}
template <typename T,
typename Context,
typename Functor,
typename LoadFunc,
typename StoreFunc,
typename LoadT = T>
void LaunchBiasAct(const Context &dev_ctx,
const T *bias,
const int token_num,
const int hid_dim,
LoadFunc load_func,
StoreFunc store_func) {
constexpr int VecSize = 16;
constexpr int PackSize = VecSize / sizeof(LoadT);
const int elem_cnt = token_num * hid_dim;
const int blocksize = 128;
int grid_size = 1;
Functor functor;
switch (hid_dim % PackSize) {
case 0:
GetNumBlocks(elem_cnt / PackSize, &grid_size);
BiasAct<T, Functor, PackSize>
<<<grid_size, blocksize, 0, dev_ctx.stream()>>>(bias,
functor,
token_num,
hid_dim,
elem_cnt,
load_func,
store_func);
break;
default:
GetNumBlocks(elem_cnt, &grid_size);
BiasAct<T, Functor, 1><<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func);
break;
}
}
template <typename T,
typename Context,
typename LoadFunc,
typename StoreFunc,
typename LoadT = T>
void ComputeImpl(const Context &dev_ctx,
const T *bias_data,
const std::string &act_method,
int rows,
int cols,
LoadFunc load_func,
StoreFunc store_func) {
if (act_method == "geglu") {
// Note(Zhengzekang): For the GLU structure, we need to divide cols by 2.
VLOG(8) << "Doing geglu";
LaunchActFFNGlu<T, Context, GeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
dev_ctx, bias_data, rows, cols / 2, load_func, store_func);
} else if (act_method == "swiglu") {
VLOG(8) << "Doing swiglu";
LaunchActFFNGlu<T,
Context,
CudaSwishFunctor<T>,
LoadFunc,
StoreFunc,
LoadT>(
dev_ctx, bias_data, rows, cols / 2, load_func, store_func);
} else if (act_method == "gelu") {
if (FLAGS_use_fast_math) {
VLOG(8) << "Doing Fast GELU";
LaunchBiasAct<T, Context, FastGeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
dev_ctx, bias_data, rows, cols, load_func, store_func);
} else {
VLOG(8) << "Doing GELU";
LaunchBiasAct<T, Context, GeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
dev_ctx, bias_data, rows, cols, load_func, store_func);
}
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Currently only GeGLU, SwiGLU and GeLU are supported."));
}
}
template <typename T, typename Context>
void DispatchComputeImpl(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor *bias,
const DenseTensor *dequant_scales,
const std::string &act_method,
int rows,
int cols,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor *out) {
const T *bias_data = bias == nullptr ? nullptr : bias->data<T>();
if (dequant_scales != nullptr && quant_scale > 0) {
DequantLoad<T> load_func(
x.data<int32_t>(), dequant_scales->data<float>(), cols);
QuantStore<T> store_func(dev_ctx.template Alloc<int8_t>(out),
quant_round_type,
quant_scale,
quant_max_bound,
quant_min_bound);
ComputeImpl<T, Context, DequantLoad<T>, QuantStore<T>, int32_t>(
dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
} else if (dequant_scales == nullptr && quant_scale > 0) {
Load<T> load_func(x.data<T>());
QuantStore<T> store_func(dev_ctx.template Alloc<int8_t>(out),
quant_round_type,
quant_scale,
quant_max_bound,
quant_min_bound);
ComputeImpl<T>(
dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
} else if (dequant_scales != nullptr && quant_scale <= 0) {
DequantLoad<T> load_func(
x.data<int32_t>(), dequant_scales->data<float>(), cols);
Store<T> store_func(dev_ctx.template Alloc<T>(out));
ComputeImpl<T, Context, DequantLoad<T>, Store<T>, int32_t>(
dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
} else {
Load<T> load_func(x.data<T>());
Store<T> store_func(dev_ctx.template Alloc<T>(out));
ComputeImpl<T>(
dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
}
}
template <typename T, typename Context>
void DispatchComputeImpl(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor *bias,
const DenseTensor *dequant_scales,
const DenseTensor *shift,
const DenseTensor *smooth,
const std::string &act_method,
int rows,
int cols,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor *out) {
bool use_glu = (act_method == "geglu" || act_method == "swiglu");
const T *bias_data = bias == nullptr ? nullptr : bias->data<T>();
if (dequant_scales != nullptr && quant_scale > 0) {
int8_t *out_data = dev_ctx.template Alloc<int8_t>(out);
DequantLoad<T> load_func(
x.data<int32_t>(), dequant_scales->data<float>(), cols);
QuantStore<T, true> store_func(out_data,
shift->data<T>(),
smooth->data<T>(),
use_glu ? cols / 2 : cols,
quant_round_type,
quant_scale,
quant_max_bound,
quant_min_bound);
ComputeImpl<T, Context, DequantLoad<T>, QuantStore<T, true>, int32_t>(
dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
} else if (dequant_scales == nullptr && quant_scale > 0) {
Load<T> load_func(x.data<T>());
QuantStore<T, true> store_func(dev_ctx.template Alloc<int8_t>(out),
shift->data<T>(),
smooth->data<T>(),
use_glu ? cols / 2 : cols,
quant_round_type,
quant_scale,
quant_max_bound,
quant_min_bound);
ComputeImpl<T>(
dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
} else if (dequant_scales != nullptr && quant_scale <= 0) {
DequantLoad<T> load_func(
x.data<int32_t>(), dequant_scales->data<float>(), cols);
Store<T, true> store_func(dev_ctx.template Alloc<T>(out),
shift->data<T>(),
smooth->data<T>(),
use_glu ? cols / 2 : cols);
ComputeImpl<T, Context, DequantLoad<T>, Store<T, true>, int32_t>(
dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
} else {
Load<T> load_func(x.data<T>());
Store<T, true> store_func(dev_ctx.template Alloc<T>(out),
shift->data<T>(),
smooth->data<T>(),
use_glu ? cols / 2 : cols);
ComputeImpl<T>(
dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
}
}
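// Tag types for compile-time dispatch: NormalVersion selects the real
// implementation, while UnusedVersion selects an empty stub that exists only
// so the kernel can be registered for int32_t inputs (the int32_t path always
// re-dispatches on compute_dtype before reaching here).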
struct NormalVersion {};
struct UnusedVersion {};
template <typename T>
struct DispatchDtypeTrait {
using FuncVersion = NormalVersion;
};
template <>
struct DispatchDtypeTrait<int32_t> {
using FuncVersion = UnusedVersion;
};
template <typename T, typename Context>
void DispatchWithDtype(const Context &dev_ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &bias,
const paddle::optional<DenseTensor> &dequant_scales,
const paddle::optional<DenseTensor> &shift,
const paddle::optional<DenseTensor> &smooth,
const std::string &act_method,
int rows,
int cols,
float quant_scale,
int quant_round_type,
float quant_max_bound,
float quant_min_bound,
DenseTensor *out,
NormalVersion) {
auto *bias_p = bias.get_ptr();
auto *dequant_scales_p = dequant_scales.get_ptr();
auto *shift_p = shift.get_ptr();
auto *smooth_p = smooth.get_ptr();
if (dequant_scales_p != nullptr) {
if (shift_p != nullptr) {
DispatchComputeImpl<T>(dev_ctx,
x,
bias_p,
dequant_scales_p,
shift_p,
smooth_p,
act_method,
rows,
cols,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
out);
} else {
DispatchComputeImpl<T>(dev_ctx,
x,
bias_p,
dequant_scales_p,
act_method,
rows,
cols,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
out);
}
} else {
const T *bias_data = bias_p == nullptr ? nullptr : bias_p->data<T>();
Load<T> load_func(x.data<T>());
Store<T> store_func(dev_ctx.template Alloc<T>(out));
ComputeImpl<T>(
dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
}
}
// (unused) stub overload, kept only so the kernel can be registered for int32_t
template <typename T, typename Context>
void DispatchWithDtype(const Context &dev_ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &bias,
const paddle::optional<DenseTensor> &dequant_scales,
const paddle::optional<DenseTensor> &shift,
const paddle::optional<DenseTensor> &smooth,
const std::string &act_method,
int rows,
int cols,
float quant_scale,
int quant_round_type,
float quant_max_bound,
float quant_min_bound,
DenseTensor *out,
UnusedVersion) {}
#endif
template <typename T, typename Context>
void FusedBiasActKernel(const Context &dev_ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &bias,
const paddle::optional<DenseTensor> &dequant_scales,
const paddle::optional<DenseTensor> &shift,
const paddle::optional<DenseTensor> &smooth,
const std::string &act_method,
const std::string &compute_dtype,
int rows,
int cols,
float quant_scale,
int quant_round_type,
float quant_max_bound,
float quant_min_bound,
DenseTensor *out) {
#ifndef PADDLE_WITH_HIP
if (x.dtype() == phi::DataType::INT32) {
if (compute_dtype == "bf16") {
DispatchWithDtype<phi::dtype::bfloat16, Context>(
dev_ctx,
x,
bias,
dequant_scales,
shift,
smooth,
act_method,
rows,
cols,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
out,
typename DispatchDtypeTrait<phi::dtype::bfloat16>::FuncVersion{});
} else if (compute_dtype == "fp16") {
DispatchWithDtype<phi::dtype::float16, Context>(
dev_ctx,
x,
bias,
dequant_scales,
shift,
smooth,
act_method,
rows,
cols,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
out,
typename DispatchDtypeTrait<phi::dtype::float16>::FuncVersion{});
} else if (compute_dtype == "fp32") {
DispatchWithDtype<float, Context>(
dev_ctx,
x,
bias,
dequant_scales,
shift,
smooth,
act_method,
rows,
cols,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
out,
typename DispatchDtypeTrait<float>::FuncVersion{});
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"In the case of quantization enabled with Input(x) INT32, "
"Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
"but get compute_dtype (%s)",
compute_dtype));
}
} else {
DispatchWithDtype<T, Context>(
dev_ctx,
x,
bias,
dequant_scales,
shift,
smooth,
act_method,
rows,
cols,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
out,
typename DispatchDtypeTrait<T>::FuncVersion{});
}
#endif
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_bias_act,
GPU,
ALL_LAYOUT,
phi::fusion::FusedBiasActKernel,
float,
phi::dtype::bfloat16,
phi::dtype::float16,
int32_t) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/funcs/load_store_util.h"
#include "paddle/phi/kernels/gpu/gelu_funcs.h"
#endif
// For the Windows build, where M_SQRT1_2 may not be defined by <cmath>.
#ifndef M_SQRT1_2
#define M_SQRT1_2 0.70710678118654752440
#endif
namespace phi {
namespace fusion {
#ifndef PADDLE_WITH_HIP
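// GeluComputeType promotes half-precision element types to float for the GELU
// math; GeluType<T> is a shorthand for the promoted type.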
template <typename T>
struct GeluComputeType;
template <>
struct GeluComputeType<phi::dtype::bfloat16> {
using Type = float;
};
template <>
struct GeluComputeType<phi::dtype::float16> {
using Type = float;
};
template <>
struct GeluComputeType<float> {
using Type = float;
};
template <typename T>
using GeluType = typename GeluComputeType<T>::Type;
using phi::funcs::DequantLoad;
using phi::funcs::Load;
using phi::funcs::QuantStore;
using phi::funcs::Store;
template <typename T>
struct BaseActivationFunctor {
using ELEMENT_TYPE = T;
using AttrPair = std::vector<std::pair<const char *, float *>>;
AttrPair GetAttrs() { return AttrPair(); }
};
// For windows build
template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta = 1.0;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// swish(x) = x / (1 + exp(-beta * x))
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
return static_cast<T>(x / (one + exp(-b * x)));
}
};
// TODO(lzc): transfer to phi::funcs
template <typename T>
struct GeluFunctor {
inline __host__ __device__ T operator()(const T x) const {
using U = GeluType<T>;
const U casted_x = static_cast<U>(x);
const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
return static_cast<T>(out);
}
};
template <typename T>
struct FastGeluFunctor {
inline __device__ T operator()(const T x) const {
return phi::GeluFwd<T, true>(x);
}
};
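// Pick a grid size large enough to cover all elements, but capped at roughly
// kNumWaves waves of kBlockSize-thread blocks across the device's SMs.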
inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
constexpr int kBlockSize = 128;
constexpr int kNumWaves = 16;
const int device_id = phi::backends::gpu::GetCurrentDeviceId();
const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id);
const int max_thread_per_multiprocessor =
phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor(device_id);
*num_blocks =
std::max<int>(1,
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * max_thread_per_multiprocessor /
kBlockSize * kNumWaves));
return cudaSuccess;
}
#endif
} // namespace fusion
} // namespace phi
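Note: for the GLU branch (act_method geglu or swiglu) the kernel splits each row of width 2 * hid_dim in half, applies the activation to the first half and multiplies it element-wise by the second half. A minimal NumPy sketch of that reference computation follows; glu_reference and act are illustrative names, and the unit test below implements the same baseline inline.

import numpy as np

def glu_reference(x, bias, act):
    # x: (token_num, 2 * hid_dim); bias: (2 * hid_dim,) or None; act: gelu or swish
    if bias is not None:
        x = x + bias
    hid_dim = x.shape[1] // 2
    head, tail = x[:, :hid_dim], x[:, hid_dim:]
    return act(head) * tail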
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import convert_float_to_uint16
from scipy.special import erf, expit
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
def round_type_1_process(val):
dtype = type(val)
if val >= 0:
return dtype(np.floor(val + 0.5))
return dtype(np.ceil(val - 0.5))
# rounding to nearest ties away from zero
round_type_1 = np.vectorize(round_type_1_process)
M_SQRT1_2 = 0.70710678118654752440
def gelu(x):
out = (
0.5 * x.astype('float32') * (1.0 + erf(x.astype('float32') * M_SQRT1_2))
)
return out.astype(x.dtype)
def swish(x):
out = x.astype('float32') * expit(x.astype('float32'))
return out.astype(x.dtype)
def fake_dequant(values, dequant_scales):
out = values * dequant_scales.astype('float32')
return out
def fake_quant(
values, shift, smooth, quant_scale, max_bound, min_bound, round_type
):
values_tmp = (values + shift) * smooth
values_tmp = max_bound * quant_scale * values_tmp
if round_type == 0:
values_tmp = np.rint(values_tmp)
elif round_type == 1:
values_tmp = round_type_1(values_tmp)
return np.clip(values_tmp, min_bound, max_bound).astype(np.int8)
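# Example: fake_quant(np.array([2.0]), 0.0, 1.0, 0.5, 127.0, -127.0, 1)
# -> (2.0 + 0.0) * 1.0 = 2.0; 127.0 * 0.5 * 2.0 = 127.0; rounded and clipped to 127 (int8).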
def fused_act_bias_wrapper(
x,
bias=None,
dequant_scales=None,
shift=None,
smooth=None,
act_method='gelu',
compute_dtype='default',
rows=0,
cols=0,
quant_scale=-1,
quant_round_type=0,
quant_max_bound=0,
quant_min_bound=0,
):
return paddle._C_ops.fused_bias_act(
x,
bias,
dequant_scales,
shift,
smooth,
act_method,
compute_dtype,
rows,
cols,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
)
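# rows/cols describe the 2-D shape of x; for geglu/swiglu the kernel treats
# cols as 2 * hid_dim and halves it internally.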
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFusedBiasActOp(unittest.TestCase):
def setUp(self):
paddle.seed(2017)
np.random.seed(2017)
self.op_type = "fused_bias_act"
self.rtol = 1e-5
self.atol = 1e-3
self.rows = 20
self.cols = 512
self.dtype = 'float32'
self.act_method = 'gelu'
self.compute_dtype = 'default'
self.use_glu = False
self.init_test_case()
self.generate_inputs()
def init_test_case(self):
pass
def generate_inputs(self):
self.x = (np.random.rand(self.rows, self.cols) * 16).astype(self.dtype)
self.bias = np.random.rand(self.cols).astype(self.dtype)
def compute_baseline_output(self):
out = gelu(self.x + self.bias).astype(self.dtype)
return out
def compute_paddle_output(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = paddle.to_tensor(self.x)
bias = paddle.to_tensor(self.bias)
return fused_act_bias_wrapper(
x=x,
bias=bias,
rows=self.rows,
cols=self.cols,
act_method=self.act_method,
compute_dtype=self.compute_dtype,
)
def test_check_output(self):
final_out_ref = self.compute_baseline_output()
final_out = self.compute_paddle_output()
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
class TestBaseFP16(TestFusedBiasActOp):
def init_test_case(self):
self.dtype = np.float16
self.act_method = 'gelu'
class TestWithComTypeFP32(TestFusedBiasActOp):
def init_test_case(self):
self.dtype = 'float32'
self.act_method = 'gelu'
self.compute_dtype = 'fp32'
class TestWithComTypeFP16(TestFusedBiasActOp):
def init_test_case(self):
self.dtype = 'float16'
self.act_method = 'gelu'
self.compute_dtype = 'fp16'
class TestFastGeluFP16(TestFusedBiasActOp):
def use_fast_math(self, enabled):
paddle.set_flags({'FLAGS_use_fast_math': enabled})
def init_test_case(self):
self.dtype = np.float16
self.act_method = 'gelu'
def compute_baseline_output(self):
out = F.gelu(
paddle.to_tensor(self.x) + paddle.to_tensor(self.bias),
approximate=True,
)
return out
def compute_paddle_output(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = paddle.to_tensor(self.x)
bias = paddle.to_tensor(self.bias)
self.use_fast_math(True)
out = fused_act_bias_wrapper(
x=x,
bias=bias,
rows=self.rows,
cols=self.cols,
act_method=self.act_method,
)
self.use_fast_math(False)
return out
class TestGegluFP16(TestFusedBiasActOp):
def init_test_case(self):
self.dtype = np.float16
self.act_method = 'geglu'
def compute_baseline_output(self):
res_tmp = (self.x + self.bias).astype(self.dtype)
res_tmp_head = res_tmp[:, : self.cols // 2]
res_tmp_tail = res_tmp[:, self.cols // 2 :]
res_tmp_head_act = gelu(res_tmp_head)
out = res_tmp_head_act * res_tmp_tail
return out
class TestSwigluFP16(TestFusedBiasActOp):
def init_test_case(self):
self.dtype = np.float16
self.act_method = 'swiglu'
def compute_baseline_output(self):
res_tmp = (self.x + self.bias).astype(self.dtype)
res_tmp_head = res_tmp[:, : self.cols // 2]
res_tmp_tail = res_tmp[:, self.cols // 2 :]
res_tmp_head_act = swish(res_tmp_head)
out = res_tmp_head_act * res_tmp_tail
return out
class TestQuantFP32(TestFusedBiasActOp):
def init_test_case(self):
self.atol = 1
self.dtype = 'float32'
self.compute_dtype = 'fp32'
self.quant_scale = 0.5
self.quant_round_type = 1
self.quant_max_bound = 127.0
self.quant_min_bound = -127.0
def generate_inputs(self):
self.x = np.random.randint(
low=-16, high=16, size=(self.rows, self.cols)
).astype('int32')
self.bias = np.random.rand(self.cols).astype(self.dtype)
self.dequant_scales = np.random.rand(self.cols).astype('float32')
quant_params_cols = self.cols // 2 if self.use_glu else self.cols
self.shift = np.zeros(quant_params_cols).astype(self.dtype)
self.smooth = np.ones(quant_params_cols).astype(self.dtype)
def compute_baseline_output(self):
input_dequanted = fake_dequant(self.x, self.dequant_scales)
output_tmp = gelu(input_dequanted + self.bias).astype(self.dtype)
out = fake_quant(
output_tmp,
self.shift,
self.smooth,
self.quant_scale,
self.quant_max_bound,
self.quant_min_bound,
self.quant_round_type,
)
return out
def compute_paddle_output(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = paddle.to_tensor(self.x)
bias = paddle.to_tensor(self.bias)
dequant_scales = paddle.to_tensor(self.dequant_scales)
shift = paddle.to_tensor(self.shift)
smooth = paddle.to_tensor(self.smooth)
out = fused_act_bias_wrapper(
x=x,
bias=bias,
dequant_scales=dequant_scales,
shift=shift,
smooth=smooth,
act_method=self.act_method,
compute_dtype=self.compute_dtype,
rows=self.rows,
cols=self.cols,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
return out
class TestDequantFP32(TestQuantFP32):
def init_test_case(self):
self.rows = 10
self.cols = 10
self.atol = 1
self.dtype = 'float32'
self.compute_dtype = 'fp32'
self.quant_scale = 0.5
self.quant_round_type = 1
self.quant_max_bound = 127.0
self.quant_min_bound = -127.0
def generate_inputs(self):
self.x = np.random.randint(
low=-16, high=16, size=(self.rows, self.cols)
).astype('int32')
self.bias = np.random.rand(self.cols).astype(self.dtype)
self.dequant_scales = np.ones(self.cols).astype('float32')
def compute_baseline_output(self):
input_dequanted = fake_dequant(self.x, self.dequant_scales)
out = gelu(input_dequanted + self.bias).astype(self.dtype)
return out
def compute_paddle_output(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = paddle.to_tensor(self.x)
bias = paddle.to_tensor(self.bias)
dequant_scales = paddle.to_tensor(self.dequant_scales)
out = fused_act_bias_wrapper(
x=x,
bias=bias,
dequant_scales=dequant_scales,
act_method=self.act_method,
compute_dtype=self.compute_dtype,
rows=self.rows,
cols=self.cols,
)
return out
class TestQuantFP16(TestQuantFP32):
def init_test_case(self):
self.atol = 1
self.dtype = 'float16'
self.compute_dtype = 'fp16'
self.quant_scale = 0.5
self.quant_round_type = 1
self.quant_max_bound = 127.0
self.quant_min_bound = -127.0
class TestDequantFP16(TestDequantFP32):
def init_test_case(self):
self.rows = 10
self.cols = 10
self.atol = 1
self.dtype = 'float16'
self.compute_dtype = 'fp16'
self.quant_scale = 0.5
self.quant_round_type = 1
self.quant_max_bound = 127.0
self.quant_min_bound = -127.0
class TestQuantGegluFP16(TestQuantFP32):
def init_test_case(self):
self.atol = 1
self.dtype = 'float16'
self.compute_dtype = 'fp16'
self.act_method = 'geglu'
self.quant_scale = 0.5
self.quant_round_type = 1
self.quant_max_bound = 127.0
self.quant_min_bound = -127.0
self.use_glu = True
def compute_baseline_output(self):
input_dequanted = fake_dequant(self.x, self.dequant_scales)
tmp = (input_dequanted + self.bias).astype('float32')
tmp_head = tmp[:, : self.cols // 2]
tmp_tail = tmp[:, self.cols // 2 :]
out_tmp = gelu(tmp_head).astype('float32') * tmp_tail
out = fake_quant(
out_tmp,
self.shift,
self.smooth,
self.quant_scale,
self.quant_max_bound,
self.quant_min_bound,
self.quant_round_type,
)
return out
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFusedBiasActOpBF16(unittest.TestCase):
def setUp(self):
paddle.seed(2019)
np.random.seed(2019)
self.op_type = "fused_bias_act"
self.rtol = 1e-3
self.atol = 1e-3
self.rows = 20
self.cols = 512
self.act_method = 'gelu'
self.compute_dtype = 'default'
self.init_test_case()
self.generate_inputs()
def init_test_case(self):
pass
def generate_inputs(self):
self.x = np.random.rand(self.rows, self.cols).astype('float32') * 16
self.bias = np.random.rand(self.cols).astype('float32')
def compute_baseline_output(self):
out = gelu(self.x.astype('float32') + self.bias)
return convert_float_to_uint16(out)
def compute_paddle_output(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = paddle.to_tensor(convert_float_to_uint16(self.x))
bias = paddle.to_tensor(convert_float_to_uint16(self.bias))
out = fused_act_bias_wrapper(
x=x,
bias=bias,
act_method=self.act_method,
compute_dtype=self.compute_dtype,
rows=self.rows,
cols=self.cols,
)
return out
def test_check_output(self):
final_out_ref = self.compute_baseline_output()
final_out = self.compute_paddle_output()
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestWithComTypeBF16(unittest.TestCase):
def init_test_case(self):
self.act_method = 'geglu'
self.compute_dtype = 'bf16'
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestGegluBF16(TestFusedBiasActOpBF16):
def init_test_case(self):
self.act_method = 'geglu'
self.compute_dtype = 'default'
def compute_baseline_output(self):
res_tmp = self.x + self.bias
res_tmp_head = res_tmp[:, : self.cols // 2]
res_tmp_tail = res_tmp[:, self.cols // 2 :]
res_tmp_head_act = gelu(res_tmp_head)
out = res_tmp_head_act * res_tmp_tail
return convert_float_to_uint16(out)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16 ",
)
class TestSwigluBF16(TestFusedBiasActOpBF16):
def init_test_case(self):
self.act_method = 'swiglu'
self.compute_dtype = 'default'
def compute_baseline_output(self):
res_tmp = self.x + self.bias
res_tmp_head = res_tmp[:, : self.cols // 2]
res_tmp_tail = res_tmp[:, self.cols // 2 :]
res_tmp_head_act = swish(res_tmp_head)
out = res_tmp_head_act * res_tmp_tail
return convert_float_to_uint16(out)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestQuantBF16(TestFusedBiasActOpBF16):
def init_test_case(self):
self.atol = 1
self.compute_dtype = 'bf16'
self.act_method = 'gelu'
self.quant_scale = 0.5
self.quant_round_type = 1
self.quant_max_bound = 127.0
self.quant_min_bound = -127.0
self.use_glu = False
def generate_inputs(self):
self.x = np.random.randint(
low=-1000, high=1000, size=(self.rows, self.cols)
).astype('int32')
self.bias = np.zeros(self.cols).astype('float32')
self.dequant_scales = np.ones(self.cols).astype('float32')
quant_params_cols = self.cols // 2 if self.use_glu else self.cols
self.shift = np.zeros(quant_params_cols).astype('float32')
self.smooth = np.ones(quant_params_cols).astype('float32')
def compute_baseline_output(self):
input_dequanted = fake_dequant(
self.x.astype('float32'), self.dequant_scales
)
output_tmp = gelu(input_dequanted + self.bias)
out = fake_quant(
output_tmp,
self.shift,
self.smooth,
self.quant_scale,
self.quant_max_bound,
self.quant_min_bound,
self.quant_round_type,
)
return out
def compute_paddle_output(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = paddle.to_tensor(self.x)
bias = paddle.to_tensor(convert_float_to_uint16(self.bias))
dequant_scales = paddle.to_tensor(self.dequant_scales)
shift = paddle.to_tensor(convert_float_to_uint16(self.shift))
smooth = paddle.to_tensor(convert_float_to_uint16(self.smooth))
out = fused_act_bias_wrapper(
x=x,
bias=bias,
dequant_scales=dequant_scales,
shift=shift,
smooth=smooth,
act_method=self.act_method,
compute_dtype=self.compute_dtype,
rows=self.rows,
cols=self.cols,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
return out
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestQuantGegluBF16(TestQuantBF16):
def init_test_case(self):
self.atol = 1
self.compute_dtype = 'bf16'
self.act_method = 'geglu'
self.quant_scale = 0.5
self.quant_round_type = 1
self.quant_max_bound = 127.0
self.quant_min_bound = -127.0
self.use_glu = True
def compute_baseline_output(self):
input_dequanted = fake_dequant(
self.x.astype('float32'), self.dequant_scales
)
tmp = (input_dequanted + self.bias).astype('float32')
tmp_head = tmp[:, : self.cols // 2]
tmp_tail = tmp[:, self.cols // 2 :]
out_tmp = gelu(tmp_head).astype('float32') * tmp_tail
out = fake_quant(
out_tmp,
self.shift,
self.smooth,
self.quant_scale,
self.quant_max_bound,
self.quant_min_bound,
self.quant_round_type,
)
return out
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestQuantSwigluBF16(TestQuantBF16):
def init_test_case(self):
self.atol = 1
self.compute_dtype = 'bf16'
self.act_method = 'swiglu'
self.quant_scale = 0.5
self.quant_round_type = 1
self.quant_max_bound = 127.0
self.quant_min_bound = -127.0
self.use_glu = True
def compute_baseline_output(self):
input_dequanted = fake_dequant(
self.x.astype('float32'), self.dequant_scales
)
tmp = (input_dequanted + self.bias).astype('float32')
tmp_head = tmp[:, : self.cols // 2]
tmp_tail = tmp[:, self.cols // 2 :]
out_tmp = swish(tmp_head).astype('float32') * tmp_tail
out = fake_quant(
out_tmp,
self.shift,
self.smooth,
self.quant_scale,
self.quant_max_bound,
self.quant_min_bound,
self.quant_round_type,
)
return out
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestAssert(unittest.TestCase):
def setUp(self):
self.rows = 20
self.cols = 512
self.dtype = 'float32'
self.act_method = 'gelu'
def test_assert_case1(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = np.random.randint(
low=-16, high=16, size=(self.rows, self.cols)
).astype('int32')
bias = np.random.rand(self.cols).astype(self.dtype)
with self.assertRaises(ValueError):
fused_act_bias_wrapper(
x=paddle.to_tensor(x),
bias=paddle.to_tensor(bias),
rows=self.rows,
cols=self.cols,
)
def test_assert_case2(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = np.random.randint(
low=-16, high=16, size=(self.rows, self.cols)
).astype('int32')
bias = np.random.rand(self.cols).astype(self.dtype)
with self.assertRaises(ValueError):
fused_act_bias_wrapper(
x=paddle.to_tensor(x),
bias=paddle.to_tensor(bias),
rows=self.rows,
cols=self.cols,
compute_dtype='fp16',
)
def test_assert_case3(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = np.random.randint(
low=-16, high=16, size=(self.rows, self.cols)
).astype('int32')
bias = np.random.rand(self.cols).astype(self.dtype)
act_method = "error_type"
with self.assertRaises(ValueError):
fused_act_bias_wrapper(
x=paddle.to_tensor(x),
bias=paddle.to_tensor(bias),
rows=self.rows,
cols=self.cols,
compute_dtype='fp16',
act_method=act_method,
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestWithoutBias(unittest.TestCase):
def setUp(self):
paddle.seed(2017)
np.random.seed(2017)
self.op_type = "fused_bias_act"
self.rtol = 1e-5
self.atol = 1e-3
self.rows = 20
self.cols = 512
self.dtype = 'float32'
self.act_method = 'gelu'
self.use_glu = False
self.init_test_case()
self.generate_inputs()
def init_test_case(self):
pass
def generate_inputs(self):
self.x = (np.random.rand(self.rows, self.cols) * 16).astype(self.dtype)
# bias is intentionally left unset; the kernel is called with bias=None
def compute_baseline_output(self):
out = gelu(self.x).astype(self.dtype)
return out
def compute_paddle_output(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x = paddle.to_tensor(self.x)
return fused_act_bias_wrapper(
x=x,
bias=None,
rows=self.rows,
cols=self.cols,
act_method=self.act_method,
)
def test_check_output(self):
final_out_ref = self.compute_baseline_output()
final_out = self.compute_paddle_output()
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
if __name__ == '__main__':
unittest.main()