未验证 提交 27cc0df5 编写于 作者: R RichardWooSJTU 提交者: GitHub

Add matmul_int8 op (#55228)

* add matmul int8
上级 2194e4c1
...@@ -57,7 +57,7 @@ class AttnMatmulINT8 { ...@@ -57,7 +57,7 @@ class AttnMatmulINT8 {
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) { const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(), LaunchQuantKernel<T>(input->data<T>(),
input_tmp->data<int8_t>(), input_tmp->data<int8_t>(),
quant_in_scale, quant_in_scale,
m_, m_,
...@@ -72,7 +72,7 @@ class AttnMatmulINT8 { ...@@ -72,7 +72,7 @@ class AttnMatmulINT8 {
output_tmp->data<int32_t>(), output_tmp->data<int32_t>(),
dev_ctx_.stream()); dev_ctx_.stream());
dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(), LaunchDequantKernel<T>(output_tmp->data<int32_t>(),
output->data<T>(), output->data<T>(),
m_, m_,
n_, n_,
...@@ -126,7 +126,7 @@ class AttnMatmulINT8 { ...@@ -126,7 +126,7 @@ class AttnMatmulINT8 {
output_tmp->data<int32_t>(), output_tmp->data<int32_t>(),
dev_ctx_.stream()); dev_ctx_.stream());
dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(), LaunchDequantKernel<T>(output_tmp->data<int32_t>(),
output->data<T>(), output->data<T>(),
m_, m_,
n_, n_,
...@@ -162,7 +162,7 @@ class AttnMatmulINT8 { ...@@ -162,7 +162,7 @@ class AttnMatmulINT8 {
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) { const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(), LaunchQuantKernel<T>(input->data<T>(),
input_tmp->data<int8_t>(), input_tmp->data<int8_t>(),
quant_in_scale, quant_in_scale,
m_, m_,
......
...@@ -47,7 +47,7 @@ __forceinline__ __device__ int8_t quant_helper(const T input, ...@@ -47,7 +47,7 @@ __forceinline__ __device__ int8_t quant_helper(const T input,
} }
template <typename T> template <typename T>
__global__ void quantize_kernel(const T* input, __global__ void QuantKernel(const T* input,
char4* output, char4* output,
const float scale, const float scale,
const int m, const int m,
...@@ -74,7 +74,7 @@ __global__ void quantize_kernel(const T* input, ...@@ -74,7 +74,7 @@ __global__ void quantize_kernel(const T* input,
} }
template <typename T> template <typename T>
void quantize_kernel_launcher(const T* input, void LaunchQuantKernel(const T* input,
int8_t* output, int8_t* output,
const float scale, const float scale,
const int m, const int m,
...@@ -87,7 +87,7 @@ void quantize_kernel_launcher(const T* input, ...@@ -87,7 +87,7 @@ void quantize_kernel_launcher(const T* input,
dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32);
dim3 block(32, 32); dim3 block(32, 32);
quantize_kernel<<<grid, block, 0, stream>>>(input, QuantKernel<<<grid, block, 0, stream>>>(input,
(char4*)output, // NOLINT (char4*)output, // NOLINT
scale, scale,
m, m,
...@@ -98,7 +98,7 @@ void quantize_kernel_launcher(const T* input, ...@@ -98,7 +98,7 @@ void quantize_kernel_launcher(const T* input,
} }
template <typename T, int VecSize> template <typename T, int VecSize>
__global__ void dequantize_kernel(T* output, __global__ void DequantKernel(T* output,
const int32_t* input, const int32_t* input,
const int m, // batch size const int m, // batch size
const int n, // hidden const int n, // hidden
...@@ -128,7 +128,7 @@ __global__ void dequantize_kernel(T* output, ...@@ -128,7 +128,7 @@ __global__ void dequantize_kernel(T* output,
} }
template <typename T> template <typename T>
void dequantize_kernel_launcher(const int32_t* input, void LaunchDequantKernel(const int32_t* input,
T* output, T* output,
const int m, // m const int m, // m
const int n, // n const int n, // n
...@@ -136,7 +136,7 @@ void dequantize_kernel_launcher(const int32_t* input, ...@@ -136,7 +136,7 @@ void dequantize_kernel_launcher(const int32_t* input,
GpuLaunchConfig* gpu_config, GpuLaunchConfig* gpu_config,
const float quant_in_scale, const float quant_in_scale,
const float* dequant_out_scale_data) { const float* dequant_out_scale_data) {
dequantize_kernel<T, DequantKernelVecSize> DequantKernel<T, DequantKernelVecSize>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>( <<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
output, input, m, n, quant_in_scale, dequant_out_scale_data); output, input, m, n, quant_in_scale, dequant_out_scale_data);
} }
......
...@@ -523,6 +523,14 @@ ...@@ -523,6 +523,14 @@
func : matmul func : matmul
backward : matmul_grad backward : matmul_grad
- op : matmul_int8
args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
output : Tensor
infer_meta :
func : MatmulInt8InferMeta
kernel :
func : matmul_int8
- op : matrix_rank - op : matrix_rank
args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false) args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
output : Tensor(out) output : Tensor(out)
......
...@@ -2096,6 +2096,76 @@ void MatmulInferMeta(const MetaTensor& x, ...@@ -2096,6 +2096,76 @@ void MatmulInferMeta(const MetaTensor& x,
out->set_layout(x.layout()); out->set_layout(x.layout());
} }
void MatmulInt8InferMeta(const MetaTensor& x,
const MetaTensor& y,
bool trans_x,
bool trans_y,
MetaTensor* out) {
std::vector<int64_t> dims_x = phi::vectorize(x.dims());
std::vector<int64_t> dims_y = phi::vectorize(y.dims());
auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size();
PADDLE_ENFORCE_GT(ndims_x,
0UL,
phi::errors::InvalidArgument(
"The Input(x) dims size must be greater than 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_GT(ndims_y,
0UL,
phi::errors::InvalidArgument(
"The Input(y) dims size must be greater than 0,"
" but reviced dims size is 0. "));
bool x_broadcasted = false, y_broadcasted = false;
if (ndims_x == 1) {
dims_x.insert(dims_x.begin(), 1);
ndims_x = 2;
x_broadcasted = true;
}
if (ndims_y == 1) {
dims_y.push_back(1);
ndims_y = 2;
y_broadcasted = true;
}
size_t M, N;
if (trans_x) {
M = dims_x[ndims_x - 1];
} else {
M = dims_x[ndims_x - 2];
}
if (trans_y) {
N = dims_y[ndims_y - 2];
} else {
N = dims_y[ndims_y - 1];
}
std::vector<int64_t> new_dims;
if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
auto ddim_out = phi::make_ddim(new_dims);
out->set_dims(ddim_out);
out->set_dtype(phi::DataType::INT32);
out->set_layout(x.layout());
}
void MatmulWithFlattenInferMeta(const MetaTensor& x, void MatmulWithFlattenInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
int x_num_col_dims, int x_num_col_dims,
......
...@@ -333,6 +333,12 @@ void MatmulInferMeta(const MetaTensor& x, ...@@ -333,6 +333,12 @@ void MatmulInferMeta(const MetaTensor& x,
bool trans_y, bool trans_y,
MetaTensor* out); MetaTensor* out);
void MatmulInt8InferMeta(const MetaTensor& x,
const MetaTensor& y,
bool trans_x,
bool trans_y,
MetaTensor* out);
void MatmulWithFlattenInferMeta(const MetaTensor& x, void MatmulWithFlattenInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
int x_num_col_dims, int x_num_col_dims,
......
...@@ -39,25 +39,16 @@ const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{}; ...@@ -39,25 +39,16 @@ const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{};
class CublasLtHelper { class CublasLtHelper {
public: public:
CublasLtHelper(int m, int k, int n) CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle)
: alpha_(1), beta_(0), m_(m), k_(k), n_(n) { : handle_(handle), alpha_(1), beta_(0), m_(m), k_(k), n_(n) {
cublasStatus_t status; cublasStatus_t status;
// handle and matmul desc
status = dyl::cublasLtCreate(&handle_);
#if CUBLAS_VER_MAJOR < 11 #if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I; cudaDataType_t cudaComputeType = CUDA_R_32I;
#else #else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif #endif
PADDLE_ENFORCE_EQ( // matmul desc
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
#if CUBLAS_VER_MAJOR < 11 #if CUBLAS_VER_MAJOR < 11
status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType); status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType);
#else #else
...@@ -179,7 +170,7 @@ class CublasLtHelper { ...@@ -179,7 +170,7 @@ class CublasLtHelper {
} }
~CublasLtHelper() {} ~CublasLtHelper() {}
void GEMM(int8_t* A_dev, void GEMM(const int8_t* A_dev,
const int8_t* B_dev, const int8_t* B_dev,
int32_t* C_dev, int32_t* C_dev,
cudaStream_t stream, cudaStream_t stream,
...@@ -226,14 +217,14 @@ class CublasLtHelper { ...@@ -226,14 +217,14 @@ class CublasLtHelper {
cublasLtMatmulAlgo_t algo_; cublasLtMatmulAlgo_t algo_;
int32_t alpha_; int32_t alpha_ = 1;
int32_t beta_; int32_t beta_ = 0;
int m_; int m_ = 0;
int k_; int k_ = 0;
int n_; int n_ = 0;
size_t workspace_size_; size_t workspace_size_ = 0;
}; };
} // namespace phi } // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Paddle/paddle/phi/kernels/funcs/cublaslt.h"
namespace phi {
template <typename T>
class Int8GEMMHelper {
public:
Int8GEMMHelper(const phi::GPUContext &dev_ctx,
int m,
int k,
int n,
phi::DenseTensor &workspace, // NOLINT
phi::DenseTensor &input_workspace, // NOLINT
phi::DenseTensor &out_workspace, // NOLINT
int quant_round_type,
float quant_max_bound,
float quant_min_bound)
: dev_ctx_(dev_ctx),
m_(m),
k_(k),
n_(n),
quant_round_type_(quant_round_type),
quant_min_bound_(quant_min_bound),
quant_max_bound_(quant_max_bound),
workspace_(workspace),
input_workspace_(input_workspace),
out_workspace_(out_workspace) {
cublaslt_helper = std::make_unique<CublasLtHelper<int32_t>>(
m, k, n, dev_ctx.cublaslt_handle());
}
void Compute(const phi::DenseTensor *input,
const phi::DenseTensor *weight, // int8, Need be transposed
const phi::DenseTensor *dequant_out_scales,
const float quant_in_scale,
phi::DenseTensor *output,
bool quant_in = false,
bool dequant_out = false) {
phi::DenseTensor input_tmp, out_tmp;
if (quant_in) {
input_tmp = input_workspace_;
LaunchQuantKernel<T>(input->data<T>(),
input_tmp.data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_,
dev_ctx_.stream());
} else {
input_tmp = *input;
}
if (dequant_out) {
out_tmp = out_workspace_;
} else {
out_tmp = *output;
}
cublaslt_helper->GEMM(input_tmp.data<int8_t>(),
weight->data<int8_t>(),
out_tmp.data<int32_t>(),
dev_ctx_.stream(),
(void *)workspace_.data<int8_t>(),
workspace_.numel());
if (dequant_out) {
auto gpu_config = std::make_unique<GpuLaunchConfig>(
phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx_, m_ * n_, DequantKernelVecSize));
LaunchDequantKernel<T>(out_tmp.data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config.get(),
quant_in_scale,
dequant_out_scales->data<float>());
}
}
private:
const phi::GPUContext &dev_ctx_;
int m_;
int k_;
int n_;
int quant_round_type_;
float quant_max_bound_;
float quant_min_bound_;
phi::DenseTensor &workspace_; // char
phi::DenseTensor &input_workspace_; // int8_t
phi::DenseTensor &out_workspace_; // int32_t
std::unique_ptr<CublasLtHelper<int32_t>> cublaslt_helper;
};
} // namespace phi
...@@ -61,7 +61,7 @@ __forceinline__ __device__ int8_t quant_helper(const T input, ...@@ -61,7 +61,7 @@ __forceinline__ __device__ int8_t quant_helper(const T input,
} }
template <typename T> template <typename T>
__global__ void quantize_kernel(const T* input, __global__ void QuantKernel(const T* input,
char4* output, char4* output,
const float scale, const float scale,
const int m, const int m,
...@@ -88,7 +88,7 @@ __global__ void quantize_kernel(const T* input, ...@@ -88,7 +88,7 @@ __global__ void quantize_kernel(const T* input,
} }
template <typename T> template <typename T>
void quantize_kernel_launcher(const T* input, void LaunchQuantKernel(const T* input,
int8_t* output, int8_t* output,
const float scale, const float scale,
const int m, const int m,
...@@ -101,7 +101,7 @@ void quantize_kernel_launcher(const T* input, ...@@ -101,7 +101,7 @@ void quantize_kernel_launcher(const T* input,
dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32);
dim3 block(32, 32); dim3 block(32, 32);
quantize_kernel<<<grid, block, 0, stream>>>(input, QuantKernel<<<grid, block, 0, stream>>>(input,
(char4*)output, // NOLINT (char4*)output, // NOLINT
scale, scale,
m, m,
...@@ -112,7 +112,7 @@ void quantize_kernel_launcher(const T* input, ...@@ -112,7 +112,7 @@ void quantize_kernel_launcher(const T* input,
} }
template <typename T, int VecSize> template <typename T, int VecSize>
__global__ void dequantize_kernel(T* output, __global__ void DequantKernel(T* output,
const int32_t* input, const int32_t* input,
const int m, // batch size const int m, // batch size
const int n, // hidden const int n, // hidden
...@@ -142,7 +142,7 @@ __global__ void dequantize_kernel(T* output, ...@@ -142,7 +142,7 @@ __global__ void dequantize_kernel(T* output,
} }
template <typename T> template <typename T>
void dequantize_kernel_launcher(const int32_t* input, void LaunchDequantKernel(const int32_t* input,
T* output, T* output,
const int m, // m const int m, // m
const int n, // n const int n, // n
...@@ -150,7 +150,7 @@ void dequantize_kernel_launcher(const int32_t* input, ...@@ -150,7 +150,7 @@ void dequantize_kernel_launcher(const int32_t* input,
GpuLaunchConfig* gpu_config, GpuLaunchConfig* gpu_config,
const float quant_in_scale, const float quant_in_scale,
const float* dequant_out_scale_data) { const float* dequant_out_scale_data) {
dequantize_kernel<T, DequantKernelVecSize> DequantKernel<T, DequantKernelVecSize>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>( <<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
output, input, m, n, quant_in_scale, dequant_out_scale_data); output, input, m, n, quant_in_scale, dequant_out_scale_data);
} }
......
...@@ -30,6 +30,9 @@ PD_REGISTER_KERNEL(matmul, ...@@ -30,6 +30,9 @@ PD_REGISTER_KERNEL(matmul,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(
matmul_int8, GPU, ALL_LAYOUT, phi::MatmulInt8Kernel, int8_t) {}
PD_REGISTER_KERNEL(matmul_with_flatten, PD_REGISTER_KERNEL(matmul_with_flatten,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
......
...@@ -667,7 +667,8 @@ void LLMGemm(const phi::GPUContext& dev_ctx, ...@@ -667,7 +667,8 @@ void LLMGemm(const phi::GPUContext& dev_ctx,
dev_ctx.Alloc<int32_t>(&int_out); dev_ctx.Alloc<int32_t>(&int_out);
{ {
auto helper = std::make_unique<CublasLtHelper>(m, k, n); auto helper =
std::make_unique<CublasLtHelper>(m, k, n, dev_ctx.cublaslt_handle());
helper->GEMM(quant_input.data<int8_t>(), helper->GEMM(quant_input.data<int8_t>(),
weight->data<int8_t>(), weight->data<int8_t>(),
int_out.data<int32_t>(), int_out.data<int32_t>(),
......
...@@ -16,11 +16,15 @@ limitations under the License. */ ...@@ -16,11 +16,15 @@ limitations under the License. */
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/autotune/cache_base.h" #include "paddle/phi/kernels/autotune/cache_base.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/phi/kernels/funcs/cublaslt.h"
#endif
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include "paddle/phi/kernels/autotune/auto_tune_base.h" #include "paddle/phi/kernels/autotune/auto_tune_base.h"
#endif #endif
...@@ -948,6 +952,15 @@ struct MatMulDispatcher<phi::GPUContext, T> { ...@@ -948,6 +952,15 @@ struct MatMulDispatcher<phi::GPUContext, T> {
#endif #endif
} }
}; };
static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx,
size_t workspace_size) {
return phi::memory_utils::Alloc(
ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
}
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
template <typename Context, typename T> template <typename Context, typename T>
...@@ -964,6 +977,107 @@ void MatMulFunction(const Context& ctx, ...@@ -964,6 +977,107 @@ void MatMulFunction(const Context& ctx,
ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag);
} }
template <typename Context>
void MatMulInt8Function(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims,
DenseTensor* out,
bool trans_x,
bool trans_y) {
PADDLE_ENFORCE_EQ(
x.dtype(),
DataType::INT8,
phi::errors::InvalidArgument(
"The type of input(x) used in int8 matmul must be (%s) does not "
"match the "
"type of data (%s) currently contained in the container.",
phi::CppTypeToDataType<int8_t>::Type(),
x.dtype()));
PADDLE_ENFORCE_EQ(
y.dtype(),
DataType::INT8,
phi::errors::InvalidArgument(
"The type of input(y) used in int8 matmul must be (%s) does not "
"match the "
"type of data (%s) currently contained in the container.",
phi::CppTypeToDataType<int8_t>::Type(),
x.dtype()));
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020
const int x_ndim = x_dims.size();
const int y_ndim = y_dims.size();
PADDLE_ENFORCE_EQ(
x_ndim,
2,
phi::errors::InvalidArgument("[INT8 GEMM] The number of dims of input(x) "
"must be equal to 2 but received %d",
x_ndim));
PADDLE_ENFORCE_EQ(
y_ndim,
2,
phi::errors::InvalidArgument("[INT8 GEMM] The number of dims of input(x) "
"must be equal to 2 but received %d",
y_ndim));
PADDLE_ENFORCE_EQ(
trans_x,
false,
phi::errors::InvalidArgument("[INT8 GEMM] Input(x) must be not "
"transposed to acheive better performance"));
PADDLE_ENFORCE_EQ(
trans_y,
true,
phi::errors::InvalidArgument("[INT8 GEMM] Input(y) must be transposed to "
"acheive better performance"));
const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
if (trans_y) {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 1],
K,
phi::errors::InvalidArgument("Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1,
K,
y_ndim - 1,
y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 2],
K,
phi::errors::InvalidArgument("Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2,
K,
y_ndim - 2,
y_dims[y_ndim - 2]));
}
const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);
// TODO(wufeisheng): cublaslt_helper is a temp scheme for Int8 GEMM,
// and releted functions need to be integrated into
// phi::funcs::MatmulWithCublasLt
auto cublaslt_helper = CublasLtHelper(M, K, N, ctx.cublaslt_handle());
ctx.template Alloc<int32_t>(out);
cublaslt_helper.GEMM(x.data<int8_t>(),
y.data<int8_t>(),
out->data<int32_t>(),
ctx.stream(),
workspace->ptr());
#else
PADDLE_THROW(phi::errors::Unimplemented(
"MatmulInt8 op needs paddle with cuda and cuda version >= 11.2"));
#endif
}
template <typename T, typename Context> template <typename T, typename Context>
void MatmulKernel(const Context& ctx, void MatmulKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -987,6 +1101,29 @@ void MatmulKernel(const Context& ctx, ...@@ -987,6 +1101,29 @@ void MatmulKernel(const Context& ctx,
ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
} }
template <typename T, typename Context>
void MatmulInt8Kernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y,
DenseTensor* out) {
PADDLE_ENFORCE_NE(
phi::product(x.dims()),
0,
phi::errors::InvalidArgument("The Input(X) dims size must not be equal 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_NE(
phi::product(y.dims()),
0,
phi::errors::InvalidArgument("The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "));
const std::vector<std::int64_t> x_dims = vectorize(x.dims());
const std::vector<std::int64_t> y_dims = vectorize(y.dims());
MatMulInt8Function<Context>(
ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
}
template <typename T, typename Context> template <typename T, typename Context>
void MatmulWithFlattenKernel(const Context& dev_ctx, void MatmulWithFlattenKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
......
...@@ -157,6 +157,7 @@ if(WIN32) ...@@ -157,6 +157,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op) list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress) list(REMOVE_ITEM TEST_OPS test_linear_compress)
list(REMOVE_ITEM TEST_OPS test_matmul_int8_op)
endif() endif()
list(REMOVE_ITEM TEST_OPS test_checkpoint_saver) list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from test_sparse_attention_op import get_cuda_version
import paddle
from paddle.fluid import core
paddle.disable_static()
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"MatmulInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class TestMatmulInt8(unittest.TestCase):
"""
Test matmul int8
Only NT (Non-Transposed-A and Transposed-B) is supported
"""
def config(self):
self.dtype = 'int8'
self.rtol = 1e-5
self.atol = 1e-2
self.bias = True
self.m = 8
self.k = 64
self.n = 64
def setUp(self):
self.config()
self.input_a_np = np.random.randint(-127, 127, [self.m, self.k]).astype(
'int32'
)
self.input_b_np = np.random.randint(-127, 127, [self.k, self.n]).astype(
'int32'
)
self.input_a = paddle.to_tensor(self.input_a_np, dtype=self.dtype)
self.input_b = paddle.to_tensor(
self.input_b_np.transpose((1, 0)), dtype=self.dtype
)
def get_reference_out(self):
out = np.dot(self.input_a_np, self.input_b_np)
return out
def get_op_out(self):
out = paddle._C_ops.matmul_int8(self.input_a, self.input_b, False, True)
return out.numpy()
def test_matmul_int8(self):
out_real = self.get_op_out()
out_expect = self.get_reference_out()
np.testing.assert_allclose(
out_real, out_expect, rtol=self.rtol, atol=self.atol
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册