Unverified commit 0a4d1999, authored by freeliuzc, committed by GitHub

[inference] Add FusedBiasActKernel (#55301)

* add init value for CudaSwishFunctor

* add new phi kernel fusedBiasActKernel
Parent: d12837d3
......@@ -63,6 +63,17 @@
data_type : x
optional : bias, x_max
- op : fused_bias_act
args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_dtype = "default", int rows = -1, int cols = -1, float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0)
output : Tensor(out)
infer_meta :
func: FusedBiasActInferMeta
kernel :
func : fused_bias_act
data_type : x
optional : bias, dequant_scales, shift, smooth
support_dygraph_mode : true
- op : fused_dropout_add
args : (Tensor x, Tensor y, Tensor seed_tensor, Scalar p, bool is_test, str mode, int seed = 0, bool fix_seed = false)
optional : seed_tensor
......
......@@ -1335,6 +1335,136 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
sequencenum->set_dtype(DataType::FLOAT32);
}
void FusedBiasActInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& dequant_scales,
const MetaTensor& shift,
const MetaTensor& smooth,
const std::string& act_method,
const std::string& compute_dtype,
int rows,
int cols,
float quant_scale,
int quant_round_type,
float quant_max_bound,
float quant_min_bound,
MetaTensor* out) {
auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(x_dims.size(),
2,
                    phi::errors::InvalidArgument(
                        "The rank of Input(x) must be 2, but received dims: %s", x_dims));
auto token_num = x_dims[0];
auto dim = x_dims[1];
  PADDLE_ENFORCE_GT(
      rows, 0, phi::errors::InvalidArgument("The value of Attr(rows) must be > 0"));
  PADDLE_ENFORCE_GT(
      cols, 0, phi::errors::InvalidArgument("The value of Attr(cols) must be > 0"));
if (act_method == "geglu" || act_method == "swiglu") {
PADDLE_ENFORCE_EQ(
dim % 2,
0,
        phi::errors::InvalidArgument(
            "The second dimension of x must be even, but received %d", dim));
dim /= 2;
out->set_dims(phi::make_ddim({token_num, dim}));
} else if (act_method == "gelu") {
out->set_dims(phi::make_ddim({token_num, dim}));
} else {
PADDLE_THROW(
errors::InvalidArgument("act_method must be geglu, swiglu or gelu, "
"but get act_method (%s)",
act_method));
}
auto FBADtypeCheck = [](const MetaTensor& check_tensor,
const std::string& tensor_name,
const std::string& compute_dtype) {
if (compute_dtype == "bf16") {
PADDLE_ENFORCE_EQ(
check_tensor.dtype(),
phi::DataType::BFLOAT16,
phi::errors::InvalidArgument(
"Input(%s) dtype must be the same with Attr(compute_dtype)",
tensor_name));
} else if (compute_dtype == "fp16") {
PADDLE_ENFORCE_EQ(
check_tensor.dtype(),
phi::DataType::FLOAT16,
phi::errors::InvalidArgument(
"Input(%s) dtype must be the same with Attr(compute_dtype)",
tensor_name));
} else if (compute_dtype == "fp32") {
PADDLE_ENFORCE_EQ(
check_tensor.dtype(),
phi::DataType::FLOAT32,
phi::errors::InvalidArgument(
"Input(%s) dtype must be the same with Attr(compute_dtype)",
tensor_name));
}
};
  // When quantization is enabled (Input(x) is INT32), the dtype used for
  // computation is determined by Attr(compute_dtype).
if (x.dtype() == phi::DataType::INT32) {
PADDLE_ENFORCE_NE(
compute_dtype,
"default",
phi::errors::InvalidArgument(
"If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
if (bias) {
FBADtypeCheck(bias, "bias", compute_dtype);
}
if (quant_scale > 0) {
out->set_dtype(phi::DataType::INT8);
} else {
if (compute_dtype == "bf16") {
out->set_dtype(phi::DataType::BFLOAT16);
} else if (compute_dtype == "fp16") {
out->set_dtype(phi::DataType::FLOAT16);
} else if (compute_dtype == "fp32") {
out->set_dtype(phi::DataType::FLOAT32);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"In the case of quantization enabled with Input(x) INT32, "
"Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
"but get compute_dtype (%s)",
compute_dtype));
}
}
} else {
// x.dtype() != phi::DataType::INT32
if (bias) {
if (compute_dtype != "default") {
FBADtypeCheck(bias, "bias", compute_dtype);
FBADtypeCheck(x, "x", compute_dtype);
} else {
PADDLE_ENFORCE_EQ(
x.dtype(),
bias.dtype(),
phi::errors::InvalidArgument("Input(x) and Input(bias) must be the "
"same dtype in this situation"));
}
} else {
      // bias does not exist
if (compute_dtype != "default") {
FBADtypeCheck(x, "x", compute_dtype);
}
}
if (quant_scale > 0) {
out->set_dtype(phi::DataType::INT8);
} else {
out->set_dtype(x.dtype());
}
}
out->set_layout(x.layout());
}
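As a side note, here is a minimal host-side sketch (not part of this commit; the helper name is hypothetical) of the shape rule enforced above: the gated activations geglu and swiglu split the last dimension in half, while gelu keeps the input shape unchanged.

#include <cassert>
#include <cstdint>
#include <string>
#include <utility>

// Hypothetical helper mirroring FusedBiasActInferMeta's output-shape logic.
std::pair<int64_t, int64_t> FusedBiasActOutShape(int64_t token_num,
                                                 int64_t dim,
                                                 const std::string &act_method) {
  if (act_method == "geglu" || act_method == "swiglu") {
    assert(dim % 2 == 0);         // the InferMeta above enforces an even last dim
    return {token_num, dim / 2};  // gated activations halve the last dim
  }
  return {token_num, dim};        // "gelu" keeps the shape
}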
void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
......
......@@ -279,6 +279,21 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
MetaTensor* sequencenum,
MetaTensor* out);
void FusedBiasActInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& dequant_scales,
const MetaTensor& shift,
const MetaTensor& smooth,
const std::string& act_method,
const std::string& compute_dtype,
int rows,
int cols,
float quant_scale,
int quant_round_type,
float quant_max_bound,
float quant_min_bound,
MetaTensor* out);
void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
......
......@@ -3923,7 +3923,7 @@ template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float beta = 1.0;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {
namespace funcs {
#ifndef PADDLE_WITH_HIP
template <typename T>
__device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
if (v > max) return max;
if (v < min) return min;
return v;
}
template <typename InType, typename OutType>
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
const float scale,
const int round_type,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * input;
if (round_type == 0) {
quant_value = static_cast<float>(rint(quant_value));
} else {
quant_value = static_cast<float>(round(quant_value));
}
return static_cast<OutType>(
ClipFunc<float>(quant_value, min_bound, max_bound));
}
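For reference, a host-side illustration of the int8 quantization path above (an assumed mirror of QuantHelperFunc, not code from this commit): scale the value into the int8 range, round according to round_type, then clip to [min_bound, max_bound].

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical scalar reference for QuantHelperFunc.
int8_t QuantRef(float v, float scale, float max_bound, float min_bound,
                int round_type) {
  float q = max_bound * scale * v;                       // scale into int8 range
  q = (round_type == 0) ? std::rint(q) : std::round(q);  // rint vs. round-half-away
  q = std::fmax(min_bound, std::fmin(max_bound, q));     // clip to bounds
  return static_cast<int8_t>(q);
}

int main() {
  // With scale = 1 / absmax and absmax = 4.0, an activation of 2.0 maps to ~64.
  std::printf("%d\n", QuantRef(2.0f, 1.0f / 4.0f, 127.0f, -127.0f, 1));
}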
template <typename T>
struct Load {
explicit Load(const T *src) : src_(src) {}
template <int VecSize>
__device__ void load(phi::AlignedVector<T, VecSize> *dst, int idx) {
phi::Load<T, VecSize>(src_ + idx, dst);
}
const T *src_;
};
template <typename T, bool Smooth = false>
struct Store {
explicit Store(T *dst) : dst_(dst) {}
template <int VecSize>
__device__ void store(phi::AlignedVector<T, VecSize> &src, int idx) {
phi::Store<T, VecSize>(src, dst_ + idx);
}
T *dst_;
};
template <typename T>
struct Store<T, true> {
Store(T *dst, const T *shift, const T *smooth, const int cols)
: dst_(dst), shift_(shift), smooth_(smooth), cols_(cols) {}
template <int VecSize>
__device__ void store(phi::AlignedVector<T, VecSize> &src, int idx) {
using Vec = phi::AlignedVector<T, VecSize>;
Vec shift_vec;
Vec smooth_vec;
phi::Load<T, VecSize>(shift_ + idx % cols_, &shift_vec);
phi::Load<T, VecSize>(smooth_ + idx % cols_, &smooth_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src[i] = (src[i] + shift_vec[i]) * smooth_vec[i];
}
phi::Store<T, VecSize>(src, dst_ + idx);
}
T *dst_;
const T *shift_;
const T *smooth_;
const int cols_;
};
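The smooth store above applies a per-column affine transform before writing: element idx of the flattened [rows, cols] buffer uses column idx % cols of the shift and smooth vectors, so both are broadcast across rows. A scalar host-side reference loop (illustrative only; names are hypothetical):

#include <cstddef>
#include <vector>

// Hypothetical scalar reference for Store<T, true>: out = (x + shift) * smooth,
// with shift/smooth vectors of length cols broadcast across rows.
void ShiftSmoothRef(std::vector<float> *data,
                    const std::vector<float> &shift,
                    const std::vector<float> &smooth,
                    int cols) {
  for (std::size_t idx = 0; idx < data->size(); ++idx) {
    const std::size_t col = idx % static_cast<std::size_t>(cols);
    (*data)[idx] = ((*data)[idx] + shift[col]) * smooth[col];
  }
}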
template <typename T>
struct DequantLoad {
DequantLoad(const int32_t *src, const float *dequant_scales, const int cols)
: src_(src), dequant_scales_(dequant_scales), cols_(cols) {}
template <int VecSize>
__device__ void load(phi::AlignedVector<T, VecSize> *dst, int idx) {
using SrcVec = phi::AlignedVector<int32_t, VecSize>;
using DstVec = phi::AlignedVector<T, VecSize>;
using ScaleVec = phi::AlignedVector<float, VecSize>;
SrcVec src_vec;
DstVec dst_vec;
ScaleVec scale_vec;
phi::Load<int32_t, VecSize>(src_ + idx, &src_vec);
phi::Load<float, VecSize>(dequant_scales_ + idx % cols_, &scale_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
dst_vec[i] =
static_cast<T>(static_cast<float>(src_vec[i]) * scale_vec[i]);
}
*dst = dst_vec;
}
const int32_t *src_;
const float *dequant_scales_;
const int cols_;
};
template <typename T, bool Smooth = false>
struct QuantStore {
QuantStore(int8_t *dst,
const int quant_round_type,
const float quant_scale,
const float quant_max_bound,
const float quant_min_bound)
: dst_(dst),
quant_round_type_(quant_round_type),
quant_scale_(quant_scale),
quant_max_bound_(quant_max_bound),
quant_min_bound_(quant_min_bound) {}
template <int VecSize>
__device__ void store(phi::AlignedVector<T, VecSize> &src, // NOLINT
int idx) { // NOLINT
using DstVec = phi::AlignedVector<int8_t, VecSize>;
DstVec dst_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
dst_vec[i] = QuantHelperFunc<float, int8_t>(static_cast<float>(src[i]),
quant_scale_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_);
}
phi::Store<int8_t, VecSize>(dst_vec, dst_ + idx);
}
int8_t *dst_;
const int quant_round_type_;
const float quant_scale_;
const float quant_max_bound_;
const float quant_min_bound_;
};
template <typename T>
struct QuantStore<T, true> {
QuantStore(int8_t *dst,
const T *shift,
const T *smooth,
const int cols,
const int quant_round_type,
const float quant_scale,
const float quant_max_bound,
const float quant_min_bound)
: dst_(dst),
shift_(shift),
smooth_(smooth),
cols_(cols),
quant_round_type_(quant_round_type),
quant_scale_(quant_scale),
quant_max_bound_(quant_max_bound),
quant_min_bound_(quant_min_bound) {}
template <int VecSize>
__device__ void store(phi::AlignedVector<T, VecSize> &src, // NOLINT
int idx) { // NOLINT
using DstVec = phi::AlignedVector<int8_t, VecSize>;
using Vec = phi::AlignedVector<T, VecSize>;
DstVec dst_vec;
Vec shift_vec;
Vec smooth_vec;
phi::Load<T, VecSize>(shift_ + idx % cols_, &shift_vec);
phi::Load<T, VecSize>(smooth_ + idx % cols_, &smooth_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src[i] = (src[i] + shift_vec[i]) * smooth_vec[i];
dst_vec[i] = QuantHelperFunc<float, int8_t>(static_cast<float>(src[i]),
quant_scale_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_);
}
phi::Store<int8_t, VecSize>(dst_vec, dst_ + idx);
}
int8_t *dst_;
const int quant_round_type_;
const float quant_scale_;
const float quant_max_bound_;
const float quant_min_bound_;
const T *shift_;
const T *smooth_;
const int cols_;
};
#endif
} // namespace funcs
} // namespace phi
This diff has been collapsed.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/funcs/load_store_util.h"
#include "paddle/phi/kernels/gpu/gelu_funcs.h"
#endif
// for windows build
#define M_SQRT1_2 0.70710678118654752440
namespace phi {
namespace fusion {
#ifndef PADDLE_WITH_HIP
template <typename T>
struct GeluComputeType;
template <>
struct GeluComputeType<phi::dtype::bfloat16> {
using Type = float;
};
template <>
struct GeluComputeType<phi::dtype::float16> {
using Type = float;
};
template <>
struct GeluComputeType<float> {
using Type = float;
};
template <typename T>
using GeluType = typename GeluComputeType<T>::Type;
using phi::funcs::DequantLoad;
using phi::funcs::Load;
using phi::funcs::QuantStore;
using phi::funcs::Store;
template <typename T>
struct BaseActivationFunctor {
using ELEMENT_TYPE = T;
using AttrPair = std::vector<std::pair<const char *, float *>>;
AttrPair GetAttrs() { return AttrPair(); }
};
// For windows build
template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta = 1.0;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// swish(x) = x / (1 + exp(-beta * x))
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
return static_cast<T>(x / (one + exp(-b * x)));
}
};
// TODO(lzc): transfer to phi::funcs
template <typename T>
struct GeluFunctor {
inline __host__ __device__ T operator()(const T x) const {
using U = GeluType<T>;
const U casted_x = static_cast<U>(x);
const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
return static_cast<T>(out);
}
};
template <typename T>
struct FastGeluFunctor {
inline __device__ T operator()(const T x) const {
return phi::GeluFwd<T, true>(x);
}
};
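The swish and erf-based gelu formulas used by the functors above can be sanity-checked on the host with plain <cmath>. The snippet below is a sketch for that purpose, not part of the commit; the function names are hypothetical.

#include <cmath>
#include <cstdio>

// swish(x) = x / (1 + exp(-beta * x)), matching CudaSwishFunctor above.
float SwishRef(float x, float beta = 1.0f) {
  return x / (1.0f + std::exp(-beta * x));
}

// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), matching GeluFunctor above.
float GeluRef(float x) {
  return 0.5f * x * (1.0f + std::erf(x * 0.70710678f));
}

int main() {
  std::printf("swish(1) = %f, gelu(1) = %f\n", SwishRef(1.0f), GeluRef(1.0f));
}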
inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
constexpr int kBlockSize = 128;
constexpr int kNumWaves = 16;
const int device_id = phi::backends::gpu::GetCurrentDeviceId();
const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id);
  const int max_thread_per_multiprocessor =
      phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor(device_id);
*num_blocks =
std::max<int>(1,
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * max_thread_per_multiprocessor /
kBlockSize * kNumWaves));
return cudaSuccess;
}
#endif
} // namespace fusion
} // namespace phi
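GetNumBlocks above caps the launch grid at roughly kNumWaves "waves" of resident blocks per device, and otherwise uses ceil(n / kBlockSize). A rough worked example under assumed device numbers (not from the commit): with 80 SMs and 2048 max threads per SM, the cap is 80 * 2048 / 128 * 16 = 20480 blocks. A hypothetical host-side mirror of the heuristic:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical host-side mirror of the grid-size heuristic in GetNumBlocks.
int NumBlocksRef(int64_t n, int sm_count, int max_threads_per_sm) {
  constexpr int kBlockSize = 128;
  constexpr int kNumWaves = 16;
  const int64_t needed = (n + kBlockSize - 1) / kBlockSize;  // ceil(n / block)
  const int64_t cap = static_cast<int64_t>(sm_count) * max_threads_per_sm /
                      kBlockSize * kNumWaves;                // ~kNumWaves waves
  return static_cast<int>(std::max<int64_t>(1, std::min(needed, cap)));
}

int main() {
  // e.g. an A100-like device: 108 SMs, 2048 threads per SM.
  std::printf("%d\n", NumBlocksRef(int64_t{1} << 24, 108, 2048));
}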
This diff has been collapsed.