Unverified commit ffff3da0, authored by FormlessUnit, committed by GitHub

[Paddle Inference] refactor linear_compress (#55490)

* Modify kernels to support quantized_matmul

---------
Co-authored-by: Nsuperxf <1208713646@qq.com>
Parent 2f69edc5
......@@ -1427,15 +1427,15 @@
data_transform :
skip_transform : out_size, size_tensor, scale_tensor
- op : llm_int8_matmul
args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0)
- op : llm_int8_linear
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, float threshold=6.0)
output : Tensor(out)
infer_meta :
func : LLMInt8MatmulInferMeta
param : [x, weight]
func : LLMInt8LinearInferMeta
kernel :
func : llm_int8_matmul
func : llm_int8_linear
data_type : x
optional: bias
- op : log
args : (Tensor x)
......@@ -2003,15 +2003,6 @@
func : qr
backward : qr_grad
- op : quant_for_compress
args : (Tensor x, int bits = 8, str layout = "weight_only")
output : Tensor(out), Tensor(scale)
infer_meta :
func : QuantForCompressInferMeta
kernel :
func : quant_for_compress
data_type: x
- op : real
args : (Tensor x)
output : Tensor (out)
......@@ -2758,14 +2749,24 @@
intermediate: warprnntgrad
backward : warprnnt_grad
- op : weight_only_matmul
args : (Tensor x, Tensor weight, Tensor weight_scale)
- op : weight_only_linear
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype)
output : Tensor(out)
infer_meta :
func : WeightOnlyMatmulInferMeta
func : WeightOnlyLinearInferMeta
kernel :
func : weight_only_matmul
func : weight_only_linear
data_type : x
optional: bias
- op : weight_quantize
args : (Tensor x, str algo = "weight_only_int8")
output : Tensor(out), Tensor(scale)
infer_meta :
func : WeightQuantizeInferMeta
kernel :
func : weight_quantize
data_type: x
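# Shape contract implied by the infer-meta functions in this PR (informal
# note): for an original weight of shape [k, n] with k % 64 == 0 and
# n % 16 == 0, weight_quantize returns out of shape [n, k] for
# weight_only_int8 / llm.int8 (or [n / 2, k] for packed weight_only_int4)
# plus a per-channel scale of shape [n]; weight_only_linear and
# llm_int8_linear then take x of shape [..., k] and produce out of shape
# [..., n].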
- op : weighted_sample_neighbors
args : (Tensor row, Tensor colptr, Tensor edge_weight, Tensor input_nodes, Tensor eids, int sample_size, bool return_eids)
......
......@@ -2534,6 +2534,53 @@ void LambInferMeta(const MetaTensor& param,
}
}
void LLMInt8LinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& weight_scale,
const float threshold,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1],
w_dims[1],
errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[1] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
x_dims[x_dims.size() - 1],
w_dims[1]));
PADDLE_ENFORCE_EQ(
w_dims[0] % 16,
0,
phi::errors::InvalidArgument(
"The first dimension of input must be divisible by 16, but got[%d]",
w_dims[0]));
PADDLE_ENFORCE_EQ(
w_dims[1] % 16,
0,
phi::errors::InvalidArgument(
"The second dimension of input must be divisible by 16, but got[%d]",
w_dims[1]));
PADDLE_ENFORCE_EQ(
weight_scale.dims()[0],
w_dims[0],
errors::InvalidArgument(
"Input(weight_scale) dim[0] and Input(Weight) dim[0] should be euqal."
"But received Input(weight_scale) dim[0](%s) != Input(Weight) "
"dim[0](%s)",
weight_scale.dims()[0],
w_dims[0]));
auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = w_dims[0];
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
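// A minimal host-side sketch (illustration only; infer_linear_out_shape is a
// hypothetical helper, not part of this PR) of the shape rule enforced above
// and in WeightOnlyLinearInferMeta below: the output keeps every leading
// dimension of x and replaces the last one with the number of output
// channels n (w_dims[0] here, weight_scale.dims()[0] in the weight-only case).
static std::vector<int64_t> infer_linear_out_shape(
    const std::vector<int64_t>& x_shape, int64_t n) {
  std::vector<int64_t> out_shape(x_shape);
  out_shape.back() = n;  // e.g. [batch, seq, k] -> [batch, seq, n]
  return out_shape;
}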
void LogspaceInferMeta(const MetaTensor& start,
const MetaTensor& stop,
const MetaTensor& number,
......@@ -3664,6 +3711,52 @@ void WarprnntInferMeta(const MetaTensor& input,
loss->set_dtype(input.dtype());
}
void WeightOnlyLinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& weight_scale,
const std::string& weight_dtype,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
auto n = weight_scale.dims()[0];
PADDLE_ENFORCE(
weight_dtype == "int8" || weight_dtype == "int4",
errors::InvalidArgument("quant_method must be 'int8' or 'int4'."));
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
weight_scale.dims().size(),
1UL,
errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor."));
PADDLE_ENFORCE_EQ(
w_dims[0] % 16,
0,
phi::errors::InvalidArgument(
"The first dimension of input must be divisible by 16, but got[%d]",
w_dims[0]));
PADDLE_ENFORCE_EQ(
w_dims[1] % 16,
0,
phi::errors::InvalidArgument(
"The second dimension of input must be divisible by 16, but got[%d]",
w_dims[1]));
PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1],
w_dims[1],
errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[1] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
x_dims[x_dims.size() - 1],
w_dims[1]));
auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = n;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
......@@ -3997,58 +4090,6 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dtype(DataType::INT32);
}
void LLMInt8MatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1],
w_dims[1],
errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[1] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
x_dims[x_dims.size() - 1],
w_dims[1]));
auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = w_dims[0];
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
void WeightOnlyMatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& weight_scale,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
auto n = weight_scale.dims()[0];
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
weight_scale.dims().size(),
1UL,
errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor."));
PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1],
w_dims[1],
errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[1] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
x_dims[x_dims.size() - 1],
w_dims[1]));
auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = n;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& src_mask,
......
......@@ -446,6 +446,13 @@ void LambInferMeta(const MetaTensor& param,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs);
void LLMInt8LinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& weight_scale,
const float threshold,
MetaTensor* out);
void LogspaceInferMeta(const MetaTensor& start,
const MetaTensor& stop,
const MetaTensor& number,
......@@ -676,6 +683,13 @@ void WarprnntInferMeta(const MetaTensor& input,
MetaTensor* loss,
MetaTensor* warpctcgrad);
void WeightOnlyLinearInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& weight_scale,
const std::string& weight_dtype,
MetaTensor* out);
void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& edge_weight,
......@@ -775,15 +789,6 @@ void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query,
bool causal,
MetaTensor* out);
void LLMInt8MatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out);
void WeightOnlyMatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& weight_scale,
MetaTensor* out);
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
......
......@@ -5039,6 +5039,48 @@ void UnStackInferMeta(const MetaTensor& x,
}
}
void WeightQuantizeInferMeta(const MetaTensor& x,
const std::string& algo,
MetaTensor* out,
MetaTensor* scale) {
auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
2UL,
phi::errors::InvalidArgument(
"The x tensor of quant op must be 2D, but got[%d]", x_dims.size()));
PADDLE_ENFORCE_EQ(
x_dims[0] % 64,
0,
phi::errors::InvalidArgument(
"The first dimension of input must be divisible by 64, but got[%d]",
x_dims[0]));
PADDLE_ENFORCE_EQ(
x_dims[1] % 16,
0,
phi::errors::InvalidArgument(
"The second dimension of input must be divisible by 16, but got[%d]",
x_dims[1]));
std::vector<int64_t> dim_scale({x_dims[1]});
std::vector<int64_t> dim_out;
if (algo == "weight_only_int8" || algo == "llm.int8") {
dim_out = std::vector<int64_t>({x_dims[1], x_dims[0]});
} else if (algo == "weight_only_int4") {
dim_out = std::vector<int64_t>({x_dims[1] / 2, x_dims[0]});
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The algo must be in ['weight_only_int8', 'weight_only_int4', "
"'llm.int8'], but got[%s]",
algo));
}
out->set_dims(phi::make_ddim(dim_out));
out->set_dtype(DataType::INT8);
scale->set_dims(phi::make_ddim(dim_scale));
scale->set_dtype(DataType::FLOAT32);
}
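// Informal sketch of the output shapes chosen above, for an input weight of
// shape [k, n] (k % 64 == 0, n % 16 == 0); weight_quantized_out_shape is a
// hypothetical helper used only for illustration.
static std::vector<int64_t> weight_quantized_out_shape(int64_t k,
                                                       int64_t n,
                                                       const std::string& algo) {
  if (algo == "weight_only_int4") {
    return {n / 2, k};  // two int4 values are packed into one int8 element
  }
  return {n, k};  // "weight_only_int8" and "llm.int8": stored transposed
}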
void ChannelShuffleInferMeta(const MetaTensor& x,
int groups,
const std::string& data_format,
......@@ -5099,46 +5141,6 @@ void CheckNumericsInferMeta(const MetaTensor& tensor,
values->set_dims(phi::make_ddim({3}));
}
void QuantForCompressInferMeta(const MetaTensor& x,
int bits,
const std::string& layout,
MetaTensor* out,
MetaTensor* scale) {
auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
2UL,
phi::errors::InvalidArgument(
"The x tensor of quant op must be 2D, but got[%d]", x_dims.size()));
PADDLE_ENFORCE_GE(
x_dims[0],
64,
phi::errors::OutOfRange("The first dimension of input is out of range "
"(expected at least 64, but got %ld).",
x_dims[0]));
PADDLE_ENFORCE_EQ(
x_dims[0] % 64,
0,
phi::errors::InvalidArgument(
"The first dimension of input must be divisible by 64, but got[%d]",
x_dims[0]));
std::vector<int64_t> dim_scale({x_dims[1]});
std::vector<int64_t> dim_out;
if (bits == 8) {
dim_out = std::vector<int64_t>({x_dims[1], x_dims[0]});
} else if (bits == 4) {
dim_out = std::vector<int64_t>({x_dims[1] / 2, x_dims[0]});
} else {
phi::errors::InvalidArgument("The bit must be 8 or 4, but got %d", bits);
}
out->set_dims(phi::make_ddim(dim_out));
out->set_dtype(DataType::INT8);
scale->set_dims(phi::make_ddim(dim_scale));
scale->set_dtype(DataType::FLOAT32);
}
void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out) {
out->share_meta(x);
out->set_strides(x.strides());
......
......@@ -438,6 +438,11 @@ void QrInferMeta(const MetaTensor& x,
MetaTensor* q,
MetaTensor* r);
void WeightQuantizeInferMeta(const MetaTensor& x,
const std::string& algo,
MetaTensor* out,
MetaTensor* scale);
void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out);
void ReduceInferMeta(const MetaTensor& x,
......@@ -730,12 +735,6 @@ void UnStackInferMeta(const MetaTensor& x,
int num,
std::vector<MetaTensor*> outs);
void QuantForCompressInferMeta(const MetaTensor& x,
int bits,
const std::string& layout,
MetaTensor* out,
MetaTensor* scale);
void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/quant_for_compress_kernel.h"
#include "paddle/phi/kernels/weight_quantize_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h"
#include "paddle/phi/kernels/impl/weight_quantize_kernel_impl.h"
namespace phi {
......@@ -27,7 +27,7 @@ void quant_compute(const DeviceContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out,
DenseTensor* scale,
const std::string& layout) {
const std::string& algo) {
const auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
......@@ -56,57 +56,49 @@ void quant_compute(const DeviceContext& dev_ctx,
int_processed_2.Resize(out->dims());
dev_ctx.template Alloc<D>(&int_processed_2);
D* int_processed_2_data = int_processed_2.data<D>();
per_channel_scale(scale_data, x_data, m, n);
per_channel_scale(scale_data, x_data, m, n, bits == 8 ? 127.0f : 7.0f);
per_channel_quant<T, bits>(x_int_data, x_data, scale_data, m, n);
if (layout == "weight_only") {
if (algo == "llm.int8") {
std::vector<int> axis = {1, 0};
funcs::Transpose<DeviceContext, int8_t, 2> trans;
trans(dev_ctx, x_int, out, axis);
} else {
permute_B_rows_for_mixed_gemm<bits>(
int_processed_data, x_int_data, std::vector<size_t>{m, n}, (int64_t)80);
int_processed_data, x_int_data, std::vector<size_t>{m, n});
subbyte_transpose_impl<bits>(
int_processed_2_data, int_processed_data, std::vector<size_t>{m, n});
interleave_column_major_tensor<bits>(
out_data, int_processed_2_data, std::vector<size_t>{m, n});
add_bias_and_interleave_inplace<bits>(out_data, num);
} else if (layout == "llm.int8") {
std::vector<int> axis = {1, 0};
funcs::Transpose<DeviceContext, int8_t, 2> trans;
trans(dev_ctx, x_int, out, axis);
} else {
phi::errors::InvalidArgument(
"The layout must be weight_only or llm.int8, but got %s", layout);
}
}
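// Host-side reference sketch (illustration only; reference_per_channel_quant
// is hypothetical and not part of this PR, and it uses <cmath>/<algorithm>)
// of what the per_channel_scale / per_channel_quant pair above computes,
// assuming a row-major [m, n] float input, symmetric per-channel quantization
// with bound 127 (int8) or 7 (int4), and ignoring the int4 packing step.
static void reference_per_channel_quant(const float* x,
                                        float* scale,
                                        int8_t* x_int,
                                        size_t m,
                                        size_t n,
                                        float bound) {
  for (size_t j = 0; j < n; ++j) {
    float max_abs = 0.f;
    for (size_t i = 0; i < m; ++i) {
      max_abs = std::max(max_abs, std::fabs(x[i * n + j]));
    }
    scale[j] = max_abs > 0.f ? max_abs / bound : 1.f;
    for (size_t i = 0; i < m; ++i) {
      long q = std::lround(x[i * n + j] / scale[j]);
      if (q > static_cast<long>(bound)) q = static_cast<long>(bound);
      if (q < -static_cast<long>(bound)) q = -static_cast<long>(bound);
      x_int[i * n + j] = static_cast<int8_t>(q);
    }
  }
}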
template <typename T, typename Context>
void QuantForCompressKernel(const Context& dev_ctx,
const DenseTensor& x,
int bits,
const std::string& layout,
DenseTensor* out,
DenseTensor* scale) {
if (bits == 8) {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale);
quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, layout);
} else if (bits == 4 && layout == "weight_only") {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale);
quant_compute<Context, T, int8_t, 4>(dev_ctx, x, out, scale, layout);
void WeightQuantizeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& algo,
DenseTensor* out,
DenseTensor* scale) {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale);
if (algo == "weight_only_int8" || algo == "llm.int8") {
quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, algo);
} else if (algo == "weight_only_int4") {
quant_compute<Context, T, int8_t, 4>(dev_ctx, x, out, scale, algo);
} else {
phi::errors::Unimplemented(
"The bits only support 8 or weight_only 4, but got[%s] [%d]",
layout,
bits);
"The algo must be in ['weight_only_int8', 'weight_only_int4', "
"'llm.int8'], but got[%s]",
algo);
}
// VLOG(0) << "x: " << x.dtype() << x;
// VLOG(0) << "out: " << out->dtype() << *out;
}
} // namespace phi
PD_REGISTER_KERNEL(quant_for_compress,
PD_REGISTER_KERNEL(weight_quantize,
CPU,
ALL_LAYOUT,
phi::QuantForCompressKernel,
phi::dtype::float16) {}
phi::WeightQuantizeKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/funcs/weight_only_gemv.h"
#include <assert.h>
#include <stdint.h>
#include <cmath>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace {
#ifdef PADDLE_WITH_CUDA
constexpr int kWarpSize = 32;
constexpr int kPerBlockWarpNum = 16;
/////////////////////////////////////////////////////////////////////
template <typename T>
__device__ inline void fast_cvt_4_packed_signed_i8s_to_2_half2s(
T halves[4], int8_t signed_chars[4]) {
assert(false);
}
// Specialization for fast cast from int8 -> FP16
template <>
__device__ inline void fast_cvt_4_packed_signed_i8s_to_2_half2s(
half halves[4], int8_t signed_chars[4]) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
uint32_t* h = reinterpret_cast<uint32_t*>(halves);
uint32_t i8s = *reinterpret_cast<uint32_t*>(signed_chars);
static constexpr uint32_t mask_for_elt_01 = 0x5150;
static constexpr uint32_t mask_for_elt_23 = 0x5352;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(h[0])
: "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(h[1])
: "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[0])
: "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[1])
: "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
#endif
}
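// Scalar reference (illustration only; i8_byte_to_half_reference is a
// hypothetical helper, not part of this PR) for the PTX above: prmt splices
// each byte b into an fp16 with high byte 0x64, whose value is 1024 + b, and
// sub.f16x2 subtracts 0x6480 (= 1152 = 1024 + 128). Assuming the packed
// bytes were offset by +128 during weight preprocessing (as
// add_bias_and_interleave_inplace does in the CPU weight_quantize path),
// this recovers the original signed weight value, b - 128.
static inline float i8_byte_to_half_reference(uint8_t b) {
  return (1024.0f + static_cast<float>(b)) - 1152.0f;  // == b - 128
}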
// Specialization for fast cast from int8 -> BF16
#ifdef PADDLE_CUDA_BF16
template <>
__device__ inline void fast_cvt_4_packed_signed_i8s_to_2_half2s(
__nv_bfloat16 halves[4], int8_t signed_chars[4]) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(halves);
uint32_t i8s = *reinterpret_cast<uint32_t*>(signed_chars);
static constexpr uint32_t fp32_base = 0x4B000000;
float fp32_intermediates[4];
// Construct FP32s, bfloat does not have enough mantissa for IADD trick
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
// Subtract out fp32_base + 128 to make the unsigned integer signed.
#pragma unroll
for (int ii = 0; ii < 4; ++ii) {
fp32_intermediates[ii] -= 8388736.f;
}
// Truncate the fp32 representation and pack up as bfloat16s.
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0],
fp32_intermediates_casted[2 * ii + 1],
0x7632);
}
#else
// Disable this on architectures older than Ampere since they lack hardware
// for bf16 mma. If one wishes to use HMMA on older hardware, they should
// convert directly to FP16 using the FP16 converters.
assert(false);
#endif
}
#endif
/* Gelu Activation */
__forceinline__ __device__ float copysignf_pos(float a, float b) {
float r;
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
return r;
}
__inline__ __device__ float tanh_opt(float x) {
#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000)
float r;
asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x));
return r;
#else
const float exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#endif
}
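// GeluActivation below applies the tanh approximation of GELU:
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// where 0.7978845608028654 = sqrt(2/pi).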
template <typename T, bool EnableFastGelu>
struct GeluActivation {
using return_type = T;
static __device__ __forceinline__ T apply(const T& val) {
if (!EnableFastGelu) return val;
const float cdf =
0.5f * (1.0f + tanh_opt((0.7978845608028654f *
(val + 0.044715f * val * val * val))));
return val * cdf;
}
};
template <typename T>
struct ConvertFloatFunc {
ConvertFloatFunc() {}
static __device__ __forceinline__ float apply(const T& val) {
assert(false);
return 0.0f;
}
};
template <>
struct ConvertFloatFunc<half> {
static __device__ __forceinline__ float apply(const half& val) {
return __half2float(val);
}
};
#ifdef PADDLE_CUDA_BF16
template <>
struct ConvertFloatFunc<__nv_bfloat16> {
static __device__ __forceinline__ float apply(const __nv_bfloat16& val) {
return __bfloat162float(val);
}
};
#endif
template <typename T>
struct ConvertDstFunc {
static __device__ __forceinline__ T apply(const float& val) { assert(false); }
};
template <>
struct ConvertDstFunc<half> {
static __device__ __forceinline__ half apply(const float& val) {
return __float2half_rn(val);
}
};
#ifdef PADDLE_CUDA_BF16
template <>
struct ConvertDstFunc<__nv_bfloat16> {
static __device__ __forceinline__ __nv_bfloat16 apply(const float& val) {
return __float2bfloat16_rn(val);
}
};
#endif
template <typename T>
struct HalfMul {
static __device__ __forceinline__ T apply(const T& x, const T& y) {
return __hmul(x, y);
}
};
#ifdef PADDLE_CUDA_BF16
template <>
struct HalfMul<__nv_bfloat16> {
static __device__ __forceinline__ __nv_bfloat16
apply(const __nv_bfloat16& x, const __nv_bfloat16& y) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmul(x, y);
#else
return __float2bfloat16_rn(__bfloat162float(x) * __bfloat162float(y));
#endif
}
};
#endif
/*
Int8 weight-only GEMV.
X: 1 x k
Weight (ColMajor): n x k
Each warp processes: (1 x k) matmul (1 x k).
*/
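// Note on the thread mapping below (descriptive note, based on the index
// arithmetic in the kernel): each warp ("tile") covers two adjacent output
// rows of the interleaved weight layout, with lanes whose (lane_id % 8) > 3
// accumulating the second row. That is why the warp reduction skips the
// xor-offset of 4, why both lane 0 and lane 4 write an output element, and
// why the launcher sizes the grid as n / kPerBlockWarpNum / 2.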
template <typename T, bool Bias, bool Gelu>
__global__ void int8_weight_only_gemv(const T* input,
const int8_t* weight,
const float* scale_list,
const T* bias,
T* output,
const int k,
const int n) {
constexpr int kWarpSize = 32;
constexpr int kVecSize = 16;
T vec_input[kVecSize];
int8_t vec_weight[kVecSize];
T vec_weight_f16[kVecSize];
const int warp_id = threadIdx.x / kWarpSize;
const int lane_id = threadIdx.x % kWarpSize;
const int tile_id = blockIdx.x * blockDim.x / kWarpSize + warp_id;
const int row_id = tile_id * 2 + ((lane_id % 8) > 3 ? 1 : 0);
weight += tile_id * k * 2;
float v = 0.f, scale = scale_list[row_id], v_bias;
if (Bias) {
v_bias = ConvertFloatFunc<T>::apply(bias[row_id]);
}
#pragma unroll
for (int i = lane_id * kVecSize; i < k * 2; i += kVecSize * kWarpSize) {
*(int4*)vec_weight = *(int4*)(weight + i); // NOLINT
*(float4*)vec_input = // NOLINT
*(float4*)(input + i / 128 * 64 + (i % 64)); // NOLINT
*(float4*)(vec_input + 8) = // NOLINT
*(float4*)(input + i / 128 * 64 + (i % 64) + 8); // NOLINT
#pragma unroll
for (int p = 0; p < kVecSize; p += 4) {
fast_cvt_4_packed_signed_i8s_to_2_half2s<T>(vec_weight_f16 + p,
vec_weight + p);
}
#pragma unroll
for (int p = 0; p < kVecSize; ++p) {
v += ConvertFloatFunc<T>::apply(
HalfMul<T>::apply(vec_input[p], vec_weight_f16[p / 8 + (p % 8) * 2]));
}
}
// Do WarpReduceSum.
v += __shfl_xor_sync(0xffffffff, v, 16);
v += __shfl_xor_sync(0xffffffff, v, 8);
v += __shfl_xor_sync(0xffffffff, v, 2);
v += __shfl_xor_sync(0xffffffff, v, 1);
if (lane_id == 0 || lane_id == 4) {
if (Bias) {
output[row_id] = ConvertDstFunc<T>::apply(
GeluActivation<float, Gelu>::apply(v * scale + v_bias));
} else {
output[row_id] = ConvertDstFunc<T>::apply(
GeluActivation<float, Gelu>::apply(v * scale));
}
}
}
#endif
template <typename T>
void int8_weight_only_gemv_launcher(const T* input,
const int8_t* weight,
const float* scale_list,
const T* bias,
T* output,
const int k,
const int n,
const bool gelu,
gpuStream_t stream) {
#ifdef PADDLE_WITH_CUDA
dim3 block(kWarpSize * kPerBlockWarpNum); // equal to 512;
dim3 grid(n / kPerBlockWarpNum /
2); // Note(zhengzekang): each warp processes 2 rows of the matrix.
if (bias) {
if (gelu) {
int8_weight_only_gemv<T, true, true><<<grid, block, 0, stream>>>(
input, weight, scale_list, bias, output, k, n);
} else {
int8_weight_only_gemv<T, true, false><<<grid, block, 0, stream>>>(
input, weight, scale_list, bias, output, k, n);
}
} else {
if (gelu) {
int8_weight_only_gemv<T, false, true><<<grid, block, 0, stream>>>(
input, weight, scale_list, bias, output, k, n);
} else {
int8_weight_only_gemv<T, false, false><<<grid, block, 0, stream>>>(
input, weight, scale_list, bias, output, k, n);
}
}
#endif
}
template <>
void int8_weight_only_gemv_launcher(const float* input,
const int8_t* weight,
const float* scale_list,
const float* bias,
float* output,
const int k,
const int n,
const bool gelu,
gpuStream_t stream) {
// Weight-only GEMV does not support float.
assert(false);
}
template <>
void int8_weight_only_gemv_launcher(const phi::dtype::bfloat16* input,
const int8_t* weight,
const float* scale_list,
const phi::dtype::bfloat16* bias,
phi::dtype::bfloat16* output,
const int k,
const int n,
const bool gelu,
gpuStream_t stream) {
// This environment does not support bf16.
assert(false);
}
} // namespace
template <typename T, typename Context>
void GemvWeightonlyInt8Wrapper(const Context& ctx,
const T* x,
const int8_t* weight,
const T* bias,
const float* weight_scale,
const int n,
const int k,
const std::string& act_method,
T* output) {
using DataType = typename PDDataTypeTraits<T>::DataType;
bool gelu = false;
if (act_method == "gelu") {
gelu = true;
} else if (act_method == "None") {
gelu = false;
} else {
PADDLE_THROW(
errors::InvalidArgument("Currently, Int8 weightonly GEMV act_method "
"only support `gelu`, `None`. "));
}
int8_weight_only_gemv_launcher<DataType>(
reinterpret_cast<const DataType*>(x),
weight,
weight_scale,
reinterpret_cast<const DataType*>(bias),
reinterpret_cast<DataType*>(output),
k,
n,
gelu,
ctx.stream());
}
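// Parameter convention for the wrapper above (descriptive note): weight is
// the quantized matrix laid out as [n, k] with one float scale per output
// row (weight_scale has length n), x is a single row of length k, and
// act_method must be "gelu" or "None"; GemvWeightonlyInt8Kernel below
// derives k = x.dims()[1] and n = weight.dims()[0] accordingly.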
template <typename T, typename Context>
void GemvWeightonlyInt8Kernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& weight_scale,
const std::string& act_method,
DenseTensor* out) {
const T* x_data = x.data<T>();
const int8_t* weight_data =
weight.data<int8_t>();  // Note: the weight buffer actually holds uint8_t
                        // values; it is passed around as int8_t.
const T* bias_data = bias ? bias.get().data<T>() : nullptr;
const float* weight_scale_data = weight_scale.data<float>();
T* out_data = dev_ctx.template Alloc<T>(out);
int k = x.dims()[1];
int n = weight.dims()[0];
GemvWeightonlyInt8Wrapper<T, Context>(dev_ctx,
x_data,
weight_data,
bias_data,
weight_scale_data,
n,
k,
act_method,
out_data);
}
template void GemvWeightonlyInt8Wrapper(const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const int8_t* weight,
const phi::dtype::float16* bias,
const float* weight_scale,
const int n,
const int k,
const std::string& act_method,
phi::dtype::float16* output);
template void GemvWeightonlyInt8Wrapper(const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const int8_t* weight,
const phi::dtype::bfloat16* bias,
const float* weight_scale,
const int n,
const int k,
const std::string& act_method,
phi::dtype::bfloat16* output);
template void GemvWeightonlyInt8Wrapper(const phi::GPUContext& ctx,
const float* x,
const int8_t* weight,
const float* bias,
const float* weight_scale,
const int n,
const int k,
const std::string& act_method,
float* output);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GemvWeightonlyInt8Wrapper(const Context& ctx,
const T* x,
const int8_t* weight,
const T* bias,
const float* weight_scale,
const int n,
const int k,
const std::string& act_method,
T* output);
} // namespace phi
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
......
......@@ -102,8 +102,6 @@ struct GemmFpAIntB {
/// Parameters structure
struct Arguments : UniversalArgumentsBase {
GemmUniversalMode mode = GemmUniversalMode::kGemm;
cutlass::gemm::GemmCoord problem_size;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::TensorRef ref_B;
......@@ -111,9 +109,6 @@ struct GemmFpAIntB {
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
// Control serial split-k
int batch_count;
typename EpilogueOutputOp::Params output_op;
// For gather+scatter operations
......@@ -121,9 +116,6 @@ struct GemmFpAIntB {
int const* gather_B_indices;
int const* scatter_D_indices;
// Included so we can use Gemm Universal
int batch_stride_D = 0;
//
// Methods
//
......@@ -144,10 +136,13 @@ struct GemmFpAIntB {
int const* gather_A_indices = nullptr,
int const* gather_B_indices = nullptr,
int const* scatter_D_indices = nullptr)
: UniversalArgumentsBase(mode,
problem_size,
/*serial_split_k_factor=*/1,
/*batch_stride_D=*/0),
: // TODO(wangbojun): GemmUniversalMode::kGemm and batch_stride_D are
// hard-coded here.
UniversalArgumentsBase(
GemmUniversalMode::kGemm,
problem_size,
/*serial_split_k_factor=*/serial_split_k_factor,
/*batch_stride_D=*/0),
ref_A(ref_A),
ref_B(ref_B),
ref_scale(ref_scale),
......@@ -505,10 +500,9 @@ struct GemmFpAIntB {
void operator()(Params const& params,
SharedStorage& shared_storage) { // NOLINT
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
// static constexpr bool compile_needed = platform::is_same<KernelArch,
// arch::Sm70>::value; KernelRunner<compile_needed>::run_kernel(params,
// shared_storage);
CUTLASS_NOT_IMPLEMENTED();
static constexpr bool compile_needed =
platform::is_same<KernelArch, arch::Sm70>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
// static constexpr bool compile_needed = platform::is_same<KernelArch,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......@@ -52,13 +68,14 @@ void generic_mixed_gemm_kernelLauncher(const T* A,
char* workspace,
size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr) {
int* occupancy) {
static_assert(cutlass::platform::is_same<T, half>::value ||
#ifdef PADDLE_CUDA_BF16
cutlass::platform::is_same<T, __nv_bfloat16>::value ||
#endif
cutlass::platform::is_same<T, float>::value,
"Specialized for bfloat16, half, float");
static_assert(
cutlass::platform::is_same<T, WeightType>::value ||
cutlass::platform::is_same<WeightType, uint8_t>::value ||
......@@ -71,13 +88,23 @@ void generic_mixed_gemm_kernelLauncher(const T* A,
cutlass::platform::is_same<T, half>::value,
cutlass::half_t,
T>::type;
using ElementType = ElementType_;
#ifdef PADDLE_CUDA_BF16
using ElementType = typename cutlass::platform::conditional<
cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
cutlass::bfloat16_t,
ElementType_>::type;
#endif
using CutlassWeightType_ = typename cutlass::platform::conditional<
cutlass::platform::is_same<WeightType, half>::value,
cutlass::half_t,
WeightType>::type;
using CutlassWeightType = CutlassWeightType_;
#ifdef PADDLE_CUDA_BF16
using CutlassWeightType = typename cutlass::platform::conditional<
cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
cutlass::bfloat16_t,
CutlassWeightType_>::type;
#endif
// We need separate config for each architecture since we will target
// different tensorcore instructions. For float, we do not target TCs.
......@@ -156,10 +183,17 @@ void generic_mixed_gemm_kernelLauncher(const T* A,
Gemm gemm;
if (gemm.get_workspace_size(args) > workspace_bytes) {
// TODO(wangbojun): the split-k setting in the gemm args should be reset
// here; for now, to run the bf16 mixed gemm, the split-k factor is set to 1.
VLOG(1) << "Requested split-k but workspace size insufficient. Falling "
"back to non-split-k implementation.";
VLOG(1) << "need workspace sizoe of: " << gemm.get_workspace_size(args)
<< ", but got " << workspace_bytes;
VLOG(1) << "args.batch_stride_D:" << args.batch_stride_D;
VLOG(1) << "args.batch_count:" << args.batch_count;
// If the requested split-k factor would require more workspace than is
// available, revert to the standard (non-split-k) gemm.
args.batch_count = 1;
}
......@@ -209,13 +243,13 @@ struct dispatch_stages {
size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr) {
// VLOG(3)<<__PRETTY_FUNCTION__;
std::string err_msg = "Cutlass fpA_intB gemm. Not instantiated for arch " +
std::to_string(arch::kMinComputeCapability) +
" with stages set to " + std::to_string(Stages);
throw std::runtime_error("[dispatch_stages::dispatch] " + err_msg);
}
};
template <typename T,
typename WeightType,
typename arch,
......@@ -242,6 +276,8 @@ struct dispatch_stages<T,
size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr) {
// VLOG(3)<<__PRETTY_FUNCTION__;
generic_mixed_gemm_kernelLauncher<T,
WeightType,
arch,
......@@ -331,7 +367,7 @@ void dispatch_gemm_config(const T* A,
char* workspace,
size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr) {
int* occupancy) {
switch (gemm_config.stages) {
case 2:
using DispatcherStages2 = dispatch_stages<T,
......@@ -420,7 +456,8 @@ void dispatch_gemm_to_cutlass(const T* A,
size_t workspace_bytes,
CutlassGemmConfig gemm_config,
cudaStream_t stream,
int* occupancy = nullptr) {
int* occupancy) {
// VLOG(3)<<__PRETTY_FUNCTION__;
// Note that SIMT configs are omitted here since they are not supported for
// fpA_intB. We also only instantiate configs here where threadblockShapeM ==
// warpShapeM since those usually perform the best for mixed type gemms.
......@@ -533,24 +570,24 @@ void dispatch_gemm_to_cutlass(const T* A,
break;
case CutlassTileConfig::Undefined:
throw std::runtime_error(
"[fpA_intB][dispatch_gemm_to_cutlass] gemm config "
"undefined.");
"[fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
"[fpA_intB][dispatch_gemm_to_cutlass] gemm config should "
"have already been set by heuristic.");
"[fpA_intB][dispatch_gemm_to_cutlass] gemm config should have "
"already been set by heuristic.");
break;
default:
throw std::runtime_error(
"[fpA_intB][dispatch_gemm_to_cutlass] Config is invalid "
"for mixed type GEMM.");
"[fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed "
"type GEMM.");
break;
}
}
template <typename T, typename WeightType>
CutlassFpAIntBGemmRunner<T, WeightType>::CutlassFpAIntBGemmRunner() {
// VLOG(3)<<__PRETTY_FUNCTION__;
int device{-1};
check_cuda_error(cudaGetDevice(&device));
sm_ = getSMVersion();
......@@ -559,7 +596,9 @@ CutlassFpAIntBGemmRunner<T, WeightType>::CutlassFpAIntBGemmRunner() {
}
template <typename T, typename WeightType>
CutlassFpAIntBGemmRunner<T, WeightType>::~CutlassFpAIntBGemmRunner() {}
CutlassFpAIntBGemmRunner<T, WeightType>::~CutlassFpAIntBGemmRunner() {
// VLOG(3)<<__PRETTY_FUNCTION__;
}
template <typename T, typename WeightType>
template <typename EpilogueTag>
......@@ -577,20 +616,38 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
const size_t workspace_bytes,
cudaStream_t stream,
int* occupancy) {
// if (sm_ >= 70 && sm_ < 75) {
// dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70,
// EpilogueTag>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr,
// workspace_bytes, gemm_config, stream, occupancy);
// }
// else if (sm_ >= 75 && sm_ < 80) {
// dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75,
// EpilogueTag>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr,
// workspace_bytes, gemm_config, stream, occupancy);
// }
// else
if (sm_ >= 80 && sm_ < 90) {
// VLOG(3)<<__PRETTY_FUNCTION__;
if (sm_ >= 70 && sm_ < 75) {
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70, EpilogueTag>(
A,
B,
weight_scales,
biases,
C,
m,
n,
k,
workspace_ptr,
workspace_bytes,
gemm_config,
stream,
occupancy);
} else if (sm_ >= 75 && sm_ < 80) {
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(
A,
B,
weight_scales,
biases,
C,
m,
n,
k,
workspace_ptr,
workspace_bytes,
gemm_config,
stream,
occupancy);
} else if (sm_ >= 80 && sm_ < 90) {
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(
A,
B,
......@@ -607,8 +664,8 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
occupancy);
} else {
throw std::runtime_error(
"[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported "
"for CUTLASS mixed type GEMM");
"[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for "
"CUTLASS mixed type GEMM");
}
}
......@@ -626,6 +683,7 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream) {
// VLOG(3)<<__PRETTY_FUNCTION__;
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
const bool is_weight_only_encoder = m >= 512 ? true : false;
std::vector<CutlassGemmConfig> candidate_configs =
......@@ -690,6 +748,7 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act(
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream) {
// VLOG(3)<<__PRETTY_FUNCTION__;
if (activation_type == "gelu") {
run_gemm<EpilogueOpBiasFtGelu>(A,
B,
......@@ -742,6 +801,7 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::gemm(const T* A,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream) {
// VLOG(3)<<__PRETTY_FUNCTION__;
run_gemm<EpilogueOpNoBias>(A,
B,
weight_scales,
......@@ -759,7 +819,8 @@ template <typename T, typename WeightType>
int CutlassFpAIntBGemmRunner<T, WeightType>::getWorkspaceSize(const int m,
const int n,
const int k) {
// sizes for each config, which would launch the maximum number of blocks
// VLOG(3)<<__PRETTY_FUNCTION__;
// These are the min tile sizes for each config, which would launch the
// maximum number of blocks.
const int max_grid_m = (m + 31) / 32;
const int max_grid_n = (n + 127) / 128;
// We need 4 bytes per block in the worst case. We launch split_k_limit in z
......@@ -811,4 +872,14 @@ int CutlassFpAIntBGemmRunner<float, WeightType>::getWorkspaceSize(const int m,
return 0;
}
template class CutlassFpAIntBGemmRunner<float, uint8_t>;
template class CutlassFpAIntBGemmRunner<half, uint8_t>;
#ifdef PADDLE_CUDA_BF16
template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t>;
#endif
template class CutlassFpAIntBGemmRunner<float, cutlass::uint4b_t>;
template class CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t>;
#ifdef PADDLE_CUDA_BF16
template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t>;
#endif
} // namespace phi
......@@ -14,6 +14,20 @@
* limitations under the License.
*/
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cublasLt.h>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/llm_int8_matmul_kernel.h"
#include "paddle/phi/kernels/llm_int8_linear_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020
#include "paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h"
#endif
......@@ -26,12 +28,11 @@ template <typename T, typename Context>
void llm_int8_compute(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& weight_scale,
const float threshold,
DenseTensor* out) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with cublaslt, ROCM platform isn't support it";
#else
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020
DenseTensor cublaslt_workspace;
cublaslt_workspace.Resize({{3000000}});
dev_ctx.template Alloc<int8_t>(&cublaslt_workspace);
......@@ -52,24 +53,35 @@ void llm_int8_compute(const Context& dev_ctx,
m,
k,
n);
if (bias) {
std::vector<const phi::DenseTensor*> ins = {out, &(bias.get())};
std::vector<phi::DenseTensor*> outs = {out};
phi::funcs::BroadcastKernel<T>(
dev_ctx, ins, &outs, phi::funcs::AddFunctor<T>());
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"llm_int8_linear op needs paddle with cuda and cuda version >= 11.2"));
#endif
}
template <typename T, typename Context>
void LLMInt8MatmulKernel(const Context& dev_ctx,
void LLMInt8LinearKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& weight_scale,
const float threshold,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
llm_int8_compute<T, Context>(
dev_ctx, x, weight, weight_scale, threshold, out);
dev_ctx, x, weight, bias, weight_scale, threshold, out);
}
} // namespace phi
PD_REGISTER_KERNEL(llm_int8_matmul,
PD_REGISTER_KERNEL(llm_int8_linear,
GPU,
ALL_LAYOUT,
phi::LLMInt8MatmulKernel,
phi::dtype::float16) {}
phi::LLMInt8LinearKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -11,11 +11,11 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/weight_only_matmul_kernel.h"
#include "paddle/phi/kernels/weight_only_linear_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/weight_only_gemv.h"
#if defined(PADDLE_WITH_CUTLASS)
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
#endif
......@@ -23,32 +23,33 @@
namespace phi {
template <typename T, typename Context>
void WeightOnlyMatmulKernel(const Context& dev_ctx,
void WeightOnlyLinearKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& weight_scale,
const std::string& weight_dtype,
DenseTensor* out) {
#if defined(PADDLE_WITH_CUTLASS)
dev_ctx.template Alloc<T>(out);
const T* x_data = x.data<T>();
const int8_t* weight_data = weight.data<int8_t>();
const T* bias_data = bias ? bias.get().data<T>() : nullptr;
const float* weight_scale_data = weight_scale.data<float>();
T* out_data = out->data<T>();
const auto x_dims = x.dims();
const auto w_dims = weight.dims();
int n = weight_scale.dims()[0];
int quant_bit = 0;
if (n % w_dims[0] == 0) {
quant_bit = w_dims[0] * 8 / n;
} else {
errors::InvalidArgument(
"w_dims[0] must be divisible by weight_scale.dims()[0]");
}
int k = w_dims[1];
int m = x.numel() / k;
switch (quant_bit) {
case 8: {
// m > 1: run gemm
if (m > 1 || weight_dtype == "int4") {
#if defined(PADDLE_WITH_CUTLASS)
if (weight_dtype == "int8") {
auto mixed_gemm_runner =
CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
uint8_t>();
int mixgemm_max_size = std::max(n, k);
int mixgemm_max_size = std::max(m, k);
DenseTensor mixgemm_workspace;
int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize(
m, mixgemm_max_size, mixgemm_max_size);
......@@ -57,25 +58,41 @@ void WeightOnlyMatmulKernel(const Context& dev_ctx,
dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
char* mixgemm_workspace_data =
reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
mixed_gemm_runner.gemm(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
x.data<T>()),
reinterpret_cast<const uint8_t*>(weight.data<int8_t>()),
reinterpret_cast<const float*>(weight_scale.data<float>()),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(
out->data<T>()),
m,
n,
k,
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
} break;
case 4: {
if (bias_data) {
mixed_gemm_runner.gemm_bias_act(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
x_data),
reinterpret_cast<const uint8_t*>(weight_data),
weight_scale_data,
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
bias_data),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(out_data),
m,
n,
k,
"none",
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
} else {
mixed_gemm_runner.gemm(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
x_data),
reinterpret_cast<const uint8_t*>(weight_data),
weight_scale_data,
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(out_data),
m,
n,
k,
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
}
} else {
auto mixed_gemm_runner =
CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
cutlass::uint4b_t>();
int mixgemm_max_size = std::max(n, k);
int mixgemm_max_size = std::max(m, k);
DenseTensor mixgemm_workspace;
int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize(
m, mixgemm_max_size, mixgemm_max_size);
......@@ -84,34 +101,60 @@ void WeightOnlyMatmulKernel(const Context& dev_ctx,
dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
char* mixgemm_workspace_data =
reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
mixed_gemm_runner.gemm(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
x.data<T>()),
reinterpret_cast<const cutlass::uint4b_t*>(weight.data<int8_t>()),
reinterpret_cast<const float*>(weight_scale.data<float>()),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(
out->data<T>()),
m,
n,
k,
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
} break;
default:
PADDLE_THROW(errors::Unimplemented(
"Quant_bits (%d) is not supported when gemm ", quant_bit));
break;
}
if (bias_data) {
mixed_gemm_runner.gemm_bias_act(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
x_data),
reinterpret_cast<const cutlass::uint4b_t*>(weight_data),
weight_scale_data,
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
bias_data),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(out_data),
m,
n,
k,
"none",
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
} else {
mixed_gemm_runner.gemm(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
x_data),
reinterpret_cast<const cutlass::uint4b_t*>(weight_data),
weight_scale_data,
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(out_data),
m,
n,
k,
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
}
}
#else
LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()";
PADDLE_THROW(phi::errors::Unimplemented(
"Please compile with cutlass to make cutlass available"));
#endif
} else { // m == 1: gemv
if (weight_dtype == "int8") {
GemvWeightonlyInt8Wrapper<T, Context>(dev_ctx,
x_data,
weight_data,
bias_data,
weight_scale_data,
n,
k,
"None",
out->data<T>());
} // TODO(lizhenyun) support weight_only_gemv for int4.
}
}
} // namespace phi
PD_REGISTER_KERNEL(weight_only_matmul,
PD_REGISTER_KERNEL(weight_only_linear,
GPU,
ALL_LAYOUT,
phi::WeightOnlyMatmulKernel,
phi::dtype::float16) {}
phi::WeightOnlyLinearKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -12,9 +28,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_
#define PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_
#include <iostream>
#pragma once
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
......@@ -29,13 +44,14 @@ inline T xabs(const T x) {
}
template <typename T>
void per_channel_scale(float* scale, const T* input, size_t m, size_t n) {
void per_channel_scale(
float* scale, const T* input, size_t m, size_t n, float bound) {
for (size_t i = 0; i < n; ++i) {
T max = input[i];
for (size_t j = 0; j < m; ++j) {
max = xabs(input[j * n + i]) > max ? xabs(input[j * n + i]) : max;
}
scale[i] = static_cast<float>(max) / 127.0;
scale[i] = static_cast<float>(max) / bound;
}
}
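For reference, the per-channel scale computed above is just the column-wise absolute maximum divided by the quantization bound. A minimal NumPy sketch of the same logic (illustration only; the function name is ours, and bound would be 127.0 for int8 weights, as in the line this hunk replaces; any other bound here is an assumption):

import numpy as np

def per_channel_scale_ref(weight: np.ndarray, bound: float = 127.0) -> np.ndarray:
    # weight has shape [m, n]; one scale per output channel (column),
    # mirroring the loop above: scale[i] = max_j |weight[j, i]| / bound.
    return np.abs(weight).max(axis=0).astype(np.float32) / bound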
......@@ -144,8 +160,7 @@ void add_bias_and_interleave_inplace(int8_t* tensor_ptr, size_t num_elts) {
template <int quant_bit>
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
const int8_t* quantized_tensor,
const std::vector<size_t>& shape,
const int64_t arch_version) {
const std::vector<size_t>& shape) {
// We only want to run this step for weight only quant.
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
......@@ -321,7 +336,7 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
const size_t num_vec_rows = num_rows / elts_in_int32;
const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32;
const size_t interleave = 2;
const size_t interleave = 128 * 8 / quant_bit / rows_per_tile;
for (size_t read_col = 0; read_col < num_cols; ++read_col) {
const size_t write_col = read_col / interleave;
for (size_t base_vec_row = 0; base_vec_row < num_vec_rows;
......@@ -345,4 +360,3 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
}
}
} // namespace phi
#endif // PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -16,9 +16,10 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void LLMInt8MatmulKernel(const Context& dev_ctx,
void LLMInt8LinearKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& weight_scale,
const float threshold,
DenseTensor* out);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -16,9 +16,11 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void WeightOnlyMatmulKernel(const Context& dev_ctx,
void WeightOnlyLinearKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& weight_scale,
const std::string& weight_dtype,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -16,10 +16,9 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void QuantForCompressKernel(const Context& dev_ctx,
const DenseTensor& x,
int bits,
const std::string& layout,
DenseTensor* out,
DenseTensor* scale);
void WeightQuantizeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& algo,
DenseTensor* out,
DenseTensor* scale);
} // namespace phi
......@@ -71,7 +71,6 @@ from .layer.common import AlphaDropout # noqa: F401
from .layer.common import Unfold # noqa: F401
from .layer.common import Fold # noqa: F401
from .layer.common import Unflatten # noqa: F401
from .layer.common import LinearCompress # noqa: F401
from .layer.pooling import AvgPool1D # noqa: F401
from .layer.pooling import AvgPool2D # noqa: F401
from .layer.pooling import AvgPool3D # noqa: F401
......
......@@ -65,8 +65,6 @@ from .common import interpolate # noqa: F401
from .common import upsample # noqa: F401
from .common import bilinear # noqa: F401
from .common import class_center_sample # noqa: F401
from .common import quant_for_compress # noqa: F401
from .common import linear_compress # noqa: F401
from .conv import conv1d # noqa: F401
from .conv import conv1d_transpose # noqa: F401
from .common import linear # noqa: F401
......
......@@ -1877,86 +1877,6 @@ def linear(x, weight, bias=None, name=None):
return res
def quant_for_compress(x, bits=8, layout="weight_only"):
return _C_ops.quant_for_compress(x, bits, layout)
def linear_compress(
x,
weight,
weight_scale,
bias=None,
bits=8,
algo="llm.int8",
name=None,
config=None,
):
if in_dynamic_mode():
if algo == "llm.int8":
y = _C_ops.llm_int8_matmul(
x, weight, weight_scale, config['threshold']
)
elif algo == "weight_only":
y = _C_ops.weight_only_matmul(x, weight, weight_scale)
else:
raise ValueError(
"Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format(
algo
)
)
if bias is not None:
y = paddle.add(y, bias)
return y
else:
helper = LayerHelper('linear_compress', **locals())
dtype = x.dtype
check_variable_and_dtype(x, 'x', ['float16'], 'linear_compress')
check_dtype(dtype, 'dtype', ['float16'], 'linear_compress')
if algo == "llm.int8":
type = "llm_int8_matmul"
inputs = {
'x': [x],
'weight': [weight],
'weight_scale': [weight_scale],
}
attrs = {'algo': algo, 'threshold': config['threshold']}
elif algo == "weight_only":
type = "weight_only_matmul"
inputs = {
'x': [x],
'weight': [weight],
'weight_scale': [weight_scale],
}
attrs = {}
else:
raise ValueError(
"Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format(
algo
)
)
tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=type,
inputs=inputs,
outputs={'Out': tmp},
attrs=attrs,
)
if bias is not None:
res = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='elementwise_add',
inputs={'X': [tmp], 'Y': [bias]},
outputs={'Out': [res]},
attrs={'axis': -1},
)
else:
res = tmp
return res
def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
r"""
Label smoothing is a mechanism to regularize the classifier layer and is called
......
......@@ -194,173 +194,6 @@ class Linear(Layer):
)
class LinearCompress(Layer):
r"""
Fully-connected linear transformation layer. For each input :math:`X` ,
the equation is:
.. math::
Out = XW + b
where :math:`W` is the weight and :math:`b` is the bias.
Linear layer takes only one multi-dimensional tensor as input with the
shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any
number of additional dimensions. It multiplies input tensor with the weight
(a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and produces
an output tensor of shape :math:`[batch\_size, *, out\_features]` .
If :math:`bias\_attr` is not False, the bias (a 1-D tensor of
shape :math:`[out\_features]` ) will be created and added to the output.
Parameters:
in_features (int): The number of input units.
out_features (int): The number of output units.
weight_attr (ParamAttr, optional): The attribute for the weight of this layer.
The default value is None. If the Initializer of the
param_attr is not set, the parameter is initialized with Xavier.
For detailed information, please refer to paddle.ParamAttr.
bias_attr (ParamAttr|bool, optional): The attribute for the bias of this layer.
If it is set to False, no bias will be added to the output.
If it is set to None or one kind of ParamAttr, a bias parameter will
be created according to ParamAttr. For detailed information, please refer
to paddle.ParamAttr. The default value is None and the bias will be
initialized to zero.
name (str, optional): Normally there is no need for user to set this parameter.
For detailed information, please refer to :ref:`api_guide_Name` .
bits (int, optional): The attribute to set num of bits in quant during weight_only,
it must be set as 8, default: 8.
algo (str, optional): The attribute to set the compression algorithm, it must be set as 'weight_only'
or 'llm.int8', default: weight_only.
config (dict, optional): The parameter config for the compression algorithm.
For llm.int8, it should be set as {'threshold': 6.0}, default: {'threshold': 6.0}.
Attribute:
**weight** (Parameter): the learnable weight of this layer.
**bias** (Parameter): the learnable bias of this layer.
Shape:
- input: Multi-dimensional tensor with shape :math:`[batch\_size, *, in\_features]` . Its data type is float16.
- output: Multi-dimensional tensor with shape :math:`[batch\_size, *, out\_features]` . The data type is the same as the input.
Examples:
.. code-block:: python
>>> import paddle
>>> paddle.seed(100)
>>> # Define the linear layer.
>>> paddle.set_default_dtype('float16')
>>> weight_attr = paddle.ParamAttr(
... name="weight",
... initializer=paddle.nn.initializer.Constant(value=0.5))
>>> bias_attr = paddle.ParamAttr(
... name="bias",
... initializer=paddle.nn.initializer.Constant(value=1.0))
>>> linear = paddle.nn.LinearCompress(128, 64, weight_attr=weight_attr, bias_attr=bias_attr, bits=8, algo='weight_only')
>>> x = paddle.randn((3, 128), dtype="float16")
>>> y = linear(x)
"""
def __init__(
self,
in_features,
out_features,
weight_attr=None,
bias_attr=None,
name=None,
bits=8,
algo="weight_only",
config={'threshold': 6.0},
):
super().__init__()
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self.weight = self.create_parameter(
shape=[in_features, out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
self.bias = self.create_parameter(
shape=[out_features],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True,
)
self.weight_scale = self.create_parameter(
shape=[out_features],
attr=None,
dtype=self._dtype,
is_bias=False,
)
self.is_weight_quanted = False
self.name = (name,)
self.bits = bits
self.layout = algo
self.algo = algo
self.config = config
def forward(self, input):
if in_dynamic_mode():
if not self.is_weight_quanted:
weight_tensor, weight_scale_tensor = F.quant_for_compress(
self.weight, self.bits, self.layout
)
weight_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(weight_tensor)
)
weight_shape = (
[self.weight.shape[1], self.weight.shape[0]]
if self.bits == 8
else [self.weight.shape[1] / 2, self.weight.shape[0]]
)
self.weight = self.create_parameter(
shape=weight_shape,
attr=weight_attr,
dtype="int8",
is_bias=False,
)
weight_scale_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(
weight_scale_tensor
)
)
self.weight_scale = self.create_parameter(
shape=self.weight_scale.shape,
attr=weight_scale_attr,
dtype="float32",
is_bias=False,
)
self.is_weight_quanted = True
out = F.linear_compress(
x=input,
weight=self.weight,
weight_scale=self.weight_scale,
bias=self.bias,
bits=self.bits,
algo=self.algo,
name=self.name,
config=self.config,
)
return out
def extra_repr(self):
name_str = f', name={self.name}' if self.name else ''
return 'in_features={}, out_features={}, dtype={}{}, algo={}'.format(
self.weight.shape[0],
self.weight.shape[1],
self._dtype,
name_str,
self.algo,
)
class Upsample(Layer):
"""
This op resizes a batch of images.
......
......@@ -22,8 +22,11 @@ from .functional_layers import transpose # noqa: F401
from .functional_layers import concat # noqa: F401
from .functional_layers import flatten # noqa: F401
from .functional_layers import matmul # noqa: F401
from .quantized_linear import weight_only_linear # noqa: F401
from .quantized_linear import llm_int8_linear # noqa: F401
from .quantized_linear import weight_quantize # noqa: F401
from .quant_layers import QuantStub # noqa: F401
from . import qat
from .stub import Stub
__all__ = ["Stub"]
__all__ = ["Stub", "weight_only_linear", "llm_int8_linear", "weight_quantize"]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode
def weight_quantize(x, algo="weight_only_int8"):
"""
Quantization function for the weights used by weight_only and llm.int8 linear layers.
Args:
x (Tensor): The input Tensor to be quantized, the data type is float16 or bfloat16.
algo (str|None): The quantization algorithm to apply to x, must be one of 'weight_only_int8',
'weight_only_int4' and 'llm.int8', default: 'weight_only_int8'.
Returns:
out (Tensor): The quantized Tensor, the data type is int8.
scale (Tensor): The per-channel scale Tensor, the data type is float32.
Examples:
.. code-block:: python
import paddle
import numpy as np
from paddle.nn.quant import weight_quantize
paddle.device.set_device("cpu")
x = np.random.randn(64, 32).astype('float16')
x = paddle.to_tensor(x, dtype=paddle.float16, place=paddle.CPUPlace())
out, scale = weight_quantize(x, algo='weight_only_int8')
print(out.shape) # [32, 64]
print(scale.shape) # [32]
"""
if in_dynamic_mode():
return _C_ops.weight_quantize(x, algo)
else:
type = "weight_quantize"
helper = LayerHelper(type, **locals())
out = helper.create_variable_for_type_inference('int8')
scale = helper.create_variable_for_type_inference('float')
helper.append_op(
type=type,
inputs={"x": x},
outputs={'out': out, "scale": scale},
attrs={"algo": algo},
)
return (out, scale)
def weight_only_linear(
x,
weight,
bias=None,
weight_scale=None,
weight_dtype="int8",
):
"""
Applies matrix multiplication of two tensors and then bias addition if provided.
This method requires CUDA version >= 11.2.
Args:
x (Tensor): The first input Tensor to be multiplied, the data type is float16 or bfloat16.
weight (Tensor): The second input Tensor to be multiplied. Its rank must be 2.
bias (Tensor|None): The input bias Tensor. If it is None, no bias addition would
be performed. Otherwise, the bias is added to the matrix multiplication result.
weight_scale (Tensor|None): The input scale Tensor provided for weight dequantization. Its rank must be 1.
weight_dtype (str): The dtype of the weight Tensor, must be one of 'int8' and 'int4'. Default: 'int8'.
Returns:
Tensor: the output Tensor, the data type is the same as that of x.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.nn.quant import weight_only_linear
x = paddle.cast(paddle.randn([1, 2, 64]), dtype='float16')
weight = paddle.cast(paddle.randint(0, 127, [32, 64]), dtype='int8')
scale = paddle.randn([32], dtype='float32')
bias = paddle.cast(paddle.randn([32]), dtype='float16')
if paddle.device.cuda.get_device_capability()[0] >= 8:
out = weight_only_linear(x, weight, bias=bias, weight_scale=scale, weight_dtype='int8')
print(out.shape) # [1, 2, 32]
"""
if in_dynamic_mode():
out = _C_ops.weight_only_linear(
x, weight, bias, weight_scale, weight_dtype
)
return out
else:
type = "weight_only_linear"
helper = LayerHelper(type, **locals())
dtype = x.dtype
inputs = {
'x': [x],
'weight': [weight],
'bias': [bias],
'weight_scale': [weight_scale],
}
attrs = {'weight_dtype': weight_dtype}
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=type,
inputs=inputs,
outputs={'out': out},
attrs=attrs,
)
return out
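As a hedged end-to-end sketch of the two APIs above (shapes are illustrative and follow the conventions in the docstrings: the quantized weight is stored as an [out_features, in_features] int8 tensor with a per-channel float32 scale), one could write:

import paddle
from paddle.nn.quant import weight_quantize, weight_only_linear

x = paddle.cast(paddle.randn([2, 16, 64]), dtype='float16')   # activations
w = paddle.cast(paddle.randn([64, 32]), dtype='float16')      # fp16 weight, [in_features, out_features]
qw, scale = weight_quantize(w, algo='weight_only_int8')       # int8 weight plus per-channel scale
if paddle.device.cuda.get_device_capability()[0] >= 8:
    out = weight_only_linear(x, qw, weight_scale=scale, weight_dtype='int8')
    print(out.shape)  # [2, 16, 32]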
def llm_int8_linear(
x,
weight,
bias=None,
weight_scale=None,
threshold=6.0,
):
"""
Applies matrix multiplication of two tensors and then bias addition if provided.
This method requires CUDA version >= 11.2.
Args:
x (Tensor): the first input Tensor to be multiplied, the data type is float16 or bfloat16.
weight (Tensor): the second input Tensor to be multiplied. Its rank must be 2.
bias (Tensor|None): the input bias Tensor. If it is None, no bias addition would
be performed. Otherwise, the bias is added to the matrix multiplication result.
weight_scale (Tensor|None): the input scale Tensor provided for weight dequantization. Its rank must be 1.
threshold (float): The minimum absolute value for an activation channel to be treated as an outlier; outlier channels are computed in the dtype of x rather than in int8.
Returns:
Tensor: the output Tensor, the data type is the same as that of x.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.nn.quant import llm_int8_linear
x = paddle.cast(paddle.randn([1, 2, 64]), dtype='float16')
weight = paddle.cast(paddle.randint(0, 127, [32, 64]), dtype='int8')
scale = paddle.randn([32], dtype='float32')
bias = paddle.cast(paddle.randn([32]), dtype='float16')
if paddle.device.cuda.get_device_capability()[0] >= 8:
out = llm_int8_linear(x, weight, bias=bias, weight_scale=scale, threshold=6.0)
print(out.shape) # [1, 2, 32]
"""
if in_dynamic_mode():
out = _C_ops.llm_int8_linear(x, weight, bias, weight_scale, threshold)
return out
else:
type = "llm_int8_linear"
helper = LayerHelper(type, **locals())
dtype = x.dtype
inputs = {
'x': [x],
'weight': [weight],
'bias': [bias],
'weight_scale': [weight_scale],
}
attrs = {'threshold': threshold}
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=type,
inputs=inputs,
outputs={'out': out},
attrs=attrs,
)
return out
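Likewise, a minimal sketch of the llm.int8 path, mirroring the unit tests below (values and shapes are illustrative only; the weight is quantized with algo='llm.int8' and the outlier threshold is forwarded to the kernel):

import paddle
import paddle.nn.quant as Q

x = paddle.cast(paddle.randn([1, 32, 64]), dtype='float16')
w = paddle.cast(paddle.randn([64, 256]), dtype='float16')
qw, scale = Q.weight_quantize(w, algo='llm.int8')
if paddle.device.cuda.get_device_capability()[0] >= 8:
    out = Q.llm_int8_linear(x, qw, weight_scale=scale, threshold=6.0)
    print(out.shape)  # [1, 32, 256]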
......@@ -88,7 +88,6 @@ endif()
list(REMOVE_ITEM TEST_OPS test_audio_logmel_feature test_audio_mel_feature)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress)
list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
list(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
......@@ -159,7 +158,6 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress)
list(REMOVE_ITEM TEST_OPS test_matmul_int8_op)
list(REMOVE_ITEM TEST_OPS test_variable_length_memory_efficient_attention)
endif()
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid.framework import default_main_program
from paddle.framework import set_default_dtype
np.random.seed(123)
paddle.seed(123)
default_main_program().random_seed = 42
paddle.disable_static()
class LinearTestCase(unittest.TestCase):
def config(self):
self.dtype = 'float16'
self.rtol = 1e-5
self.atol = 1e-2
self.bias = True
self.in_features = 64
self.out_features = 64
self.algo = "weight_only"
self.bits = 8
def setUp(self):
self.config()
input = np.random.random((2, 4, self.in_features))
self.input = paddle.to_tensor(input, dtype=self.dtype)
if self.bias:
bias_attr = fluid.ParamAttr(
learning_rate=0.8,
trainable=False,
regularizer=None,
initializer=paddle.nn.initializer.Constant(value=1.0),
)
else:
bias_attr = None
set_default_dtype(self.dtype)
self.linear = paddle.nn.Linear(
self.in_features, self.out_features, bias_attr=bias_attr
)
if self.algo == "llm.int8":
self.config = {"threshold": 6.0}
else:
self.config = None
self.linear_compress = paddle.nn.LinearCompress(
self.in_features,
self.out_features,
bias_attr=bias_attr,
bits=8,
algo=self.algo,
config=self.config,
)
self.linear_compress(self.input)
def get_linear_out(self):
out = self.linear(self.input)
return out.numpy()
def get_linear_compress_out(self):
out = self.linear_compress(self.input)
return out.numpy()
def test_linear_compress(self):
out_real = self.get_linear_compress_out()
out_expect = self.get_linear_out()
np.testing.assert_allclose(
out_real, out_expect, rtol=self.rtol, atol=self.atol
)
class LinearTestCase1(LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = True
self.in_features = 128
self.out_features = 64
class LinearTestCase2(LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.in_features = 64
self.out_features = 64
class LinearTestCase3(LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.in_features = 64
self.out_features = 64
self.algo = "llm.int8"
self.atol = 1e-1
class LinearTestCase4(LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = True
self.in_features = 128
self.out_features = 64
self.bits = 4
if __name__ == '__main__':
unittest.main()
......@@ -227,6 +227,8 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_lsq)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_matmul)
list(REMOVE_ITEM TEST_OPS test_weight_only_linear)
list(REMOVE_ITEM TEST_OPS test_llm_int8_linear)
list(REMOVE_ITEM TEST_OPS test_quant_aware)
list(REMOVE_ITEM TEST_OPS test_quant_post_quant_aware)
list(REMOVE_ITEM TEST_OPS test_quant_aware_user_defined)
......@@ -235,6 +237,11 @@ if(WIN32)
endif()
if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_weight_only_linear)
list(REMOVE_ITEM TEST_OPS test_llm_int8_linear)
endif()
if(LINUX AND WITH_MKLDNN)
#### Image classification dataset: ImageNet (small)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from test_weight_only_linear import convert_uint16_to_float, get_cuda_version
import paddle
import paddle.nn.quant as Q
from paddle import fluid
from paddle.fluid import core
from paddle.fluid.framework import default_main_program
from paddle.framework import set_default_dtype
np.random.seed(123)
paddle.seed(123)
default_main_program().random_seed = 42
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase(unittest.TestCase):
def config(self):
self.dtype = 'float16'
self.rtol = 1e-5
self.atol = 1e-1
self.bias = True
self.batch = 1
self.token = 32
self.in_features = 64
self.out_features = 256
self.threshold = 6.0
self.static = False
def setUp(self):
self.config()
x = np.random.random((self.batch, self.token, self.in_features))
self.x = paddle.to_tensor(x, dtype=self.dtype)
if self.bias:
bias_attr = fluid.ParamAttr(
trainable=False,
regularizer=None,
initializer=paddle.nn.initializer.Constant(value=1.0),
)
else:
bias_attr = None
set_default_dtype(self.dtype)
self.linear = paddle.nn.Linear(
self.in_features, self.out_features, bias_attr=bias_attr
)
self.bias = self.linear.bias
self.weight = self.linear.weight
self.weight_scale = None
self.weight, self.weight_scale = Q.weight_quantize(
self.weight, algo="llm.int8"
)
def get_linear_out(self):
out = self.linear(self.x)
return out.numpy()
def get_llm_int8_linear_out(self):
out = Q.llm_int8_linear(
self.x,
self.weight,
bias=self.bias,
weight_scale=self.weight_scale,
threshold=self.threshold,
)
return out.numpy()
def get_llm_int8_linear_out_static(self):
paddle.enable_static()
main = fluid.Program()
start = fluid.Program()
with fluid.program_guard(main, start):
x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype)
weight = paddle.static.data(
"weight", self.weight.shape, dtype=self.weight.dtype
)
bias = paddle.static.data(
"bias", self.bias.shape, dtype=self.bias.dtype
)
x_np = self.x.numpy()
weight_np = self.weight.numpy()
bias_np = self.bias.numpy()
if self.weight_scale is not None:
weight_scale = paddle.static.data(
"weight_scale",
self.weight_scale.shape,
dtype=self.weight_scale.dtype,
)
weight_scale_np = self.weight_scale.numpy()
else:
weight_scale = None
weight_scale_np = None
out = Q.llm_int8_linear(
x,
weight,
bias,
weight_scale,
self.threshold,
)
feed_dict = {
'x': x_np,
'weight': weight_np,
'bias': bias_np,
"weight_scale": weight_scale_np,
}
exe = fluid.Executor(paddle.CUDAPlace(0))
exe.run(start)
(out,) = exe.run(main, feed=feed_dict, fetch_list=[out])
paddle.disable_static()
return out
def test_llm_int8_linear(self):
out_expect = self.get_linear_out()
if self.static:
out_real = self.get_llm_int8_linear_out_static()
else:
out_real = self.get_llm_int8_linear_out()
if self.dtype == "bfloat16":
out_real = convert_uint16_to_float(out_real)
out_expect = convert_uint16_to_float(out_expect)
np.testing.assert_allclose(
out_real, out_expect, rtol=self.rtol, atol=self.atol
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase1(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int8"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase2(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.weight_dtype = "int8"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase3(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16",
)
class LLMInt8LinearTestCase4(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int4"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase5(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.weight_dtype = "int4"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16",
)
class LLMInt8LinearTestCase6(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int4"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase7(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int8"
self.batch = 1
self.token = 1
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase8(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int8"
self.bias = False
self.batch = 1
self.token = 1
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase9(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"
self.batch = 1
self.token = 1
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase10(LLMInt8LinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"
self.bias = False
self.batch = 1
self.token = 1
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCaseStatic(LLMInt8LinearTestCase):
def config(self):
super().config()
self.static = True
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import struct
import unittest
import numpy as np
import paddle
import paddle.nn.quant as Q
from paddle import fluid
from paddle.fluid import core
from paddle.fluid.framework import default_main_program
from paddle.framework import set_default_dtype
np.random.seed(123)
paddle.seed(123)
default_main_program().random_seed = 42
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
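For example, an nvcc release string of 11.2 parses to 11 * 1000 + 2 * 10 = 11020, which is why the skip conditions below compare against 11020; a failed parse returns -1, which also triggers the skip.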
def convert_uint16_to_float(in_list):
in_list = np.asarray(in_list)
out = np.vectorize(
lambda x: struct.unpack(
'<f', struct.pack('<I', np.uint32(x) << np.uint32(16))
)[0],
otypes=[np.float32],
)(in_list.flat)
return np.reshape(out, in_list.shape)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase(unittest.TestCase):
def config(self):
self.dtype = 'float16'
self.rtol = 1e-5
self.atol = 1e-2
self.bias = True
self.batch = 1
self.token = 32
self.in_features = 64
self.out_features = 256
self.weight_dtype = "int8"
self.static = False
def setUp(self):
self.config()
if self.dtype == "bfloat16" or self.weight_dtype == "int4":
self.atol = 1e-1
x = np.random.random((self.batch, self.token, self.in_features))
self.x = paddle.to_tensor(x, dtype=self.dtype)
if self.bias:
bias_attr = fluid.ParamAttr(
trainable=False,
regularizer=None,
initializer=paddle.nn.initializer.Constant(value=1.0),
)
else:
bias_attr = None
set_default_dtype(self.dtype)
self.linear = paddle.nn.Linear(
self.in_features, self.out_features, bias_attr=bias_attr
)
self.bias = self.linear.bias
self.weight = self.linear.weight
self.weight_scale = None
self.weight, self.weight_scale = Q.weight_quantize(
self.weight,
algo="weight_only_int8"
if self.weight_dtype == "int8"
else "weight_only_int4",
)
def get_linear_out(self):
out = self.linear(self.x)
return out.numpy()
def get_weight_only_linear_out(self):
out = Q.weight_only_linear(
self.x,
self.weight,
bias=self.bias,
weight_scale=self.weight_scale,
weight_dtype=self.weight_dtype,
)
return out.numpy()
def get_weight_only_linear_out_static(self):
paddle.enable_static()
main = fluid.Program()
start = fluid.Program()
with fluid.program_guard(main, start):
x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype)
weight = paddle.static.data(
"weight", self.weight.shape, dtype=self.weight.dtype
)
bias = paddle.static.data(
"bias", self.bias.shape, dtype=self.bias.dtype
)
x_np = self.x.numpy()
weight_np = self.weight.numpy()
bias_np = self.bias.numpy()
if self.weight_scale is not None:
weight_scale = paddle.static.data(
"weight_scale",
self.weight_scale.shape,
dtype=self.weight_scale.dtype,
)
weight_scale_np = self.weight_scale.numpy()
else:
weight_scale = None
weight_scale_np = None
out = Q.weight_only_linear(
x,
weight,
bias,
weight_scale,
self.weight_dtype,
)
feed_dict = {
'x': x_np,
'weight': weight_np,
'bias': bias_np,
"weight_scale": weight_scale_np,
}
exe = fluid.Executor(paddle.CUDAPlace(0))
exe.run(start)
(out,) = exe.run(main, feed=feed_dict, fetch_list=[out])
paddle.disable_static()
return out
def test_weight_only_linear(self):
out_expect = self.get_linear_out()
if self.static:
out_real = self.get_weight_only_linear_out_static()
else:
out_real = self.get_weight_only_linear_out()
if self.dtype == "bfloat16":
out_real = convert_uint16_to_float(out_real)
out_expect = convert_uint16_to_float(out_expect)
np.testing.assert_allclose(
out_real, out_expect, rtol=self.rtol, atol=self.atol
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase1(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int8"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase2(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.weight_dtype = "int8"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase3(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16",
)
class WeightOnlyLinearTestCase4(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int4"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase5(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.weight_dtype = "int4"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16",
)
class WeightOnlyLinearTestCase6(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int4"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase7(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int8"
self.batch = 1
self.token = 1
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase8(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int8"
self.bias = False
self.batch = 1
self.token = 1
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase9(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"
self.batch = 1
self.token = 1
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase10(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"
self.bias = False
self.batch = 1
self.token = 1
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCaseStatic(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.static = True
if __name__ == '__main__':
unittest.main()