Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
27cc0df5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
27cc0df5
编写于
7月 13, 2023
作者:
R
RichardWooSJTU
提交者:
GitHub
7月 13, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add matmul_int8 op (#55228)
* add matmul int8
上级
2194e4c1
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
542 addition
and
134 deletion
+542
-134
paddle/fluid/operators/fused/attn_gemm_int8.h
paddle/fluid/operators/fused/attn_gemm_int8.h
+34
-34
paddle/fluid/operators/fused/quant_dequant_kernel.h
paddle/fluid/operators/fused/quant_dequant_kernel.h
+40
-40
paddle/phi/api/yaml/legacy_ops.yaml
paddle/phi/api/yaml/legacy_ops.yaml
+8
-0
paddle/phi/infermeta/binary.cc
paddle/phi/infermeta/binary.cc
+70
-0
paddle/phi/infermeta/binary.h
paddle/phi/infermeta/binary.h
+6
-0
paddle/phi/kernels/funcs/cublaslt.h
paddle/phi/kernels/funcs/cublaslt.h
+10
-19
paddle/phi/kernels/funcs/gemm_int8_helper.h
paddle/phi/kernels/funcs/gemm_int8_helper.h
+114
-0
paddle/phi/kernels/funcs/quant_dequant.h
paddle/phi/kernels/funcs/quant_dequant.h
+40
-40
paddle/phi/kernels/gpu/matmul_kernel.cu
paddle/phi/kernels/gpu/matmul_kernel.cu
+3
-0
paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
+2
-1
paddle/phi/kernels/impl/matmul_kernel_impl.h
paddle/phi/kernels/impl/matmul_kernel_impl.h
+137
-0
test/legacy_test/CMakeLists.txt
test/legacy_test/CMakeLists.txt
+1
-0
test/legacy_test/test_matmul_int8_op.py
test/legacy_test/test_matmul_int8_op.py
+77
-0
未找到文件。
paddle/fluid/operators/fused/attn_gemm_int8.h
浏览文件 @
27cc0df5
...
...
@@ -57,7 +57,7 @@ class AttnMatmulINT8 {
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
quantize_kernel_launcher
<
T
>
(
input
->
data
<
T
>
(),
LaunchQuantKernel
<
T
>
(
input
->
data
<
T
>
(),
input_tmp
->
data
<
int8_t
>
(),
quant_in_scale
,
m_
,
...
...
@@ -72,7 +72,7 @@ class AttnMatmulINT8 {
output_tmp
->
data
<
int32_t
>
(),
dev_ctx_
.
stream
());
dequantize_kernel_launcher
<
T
>
(
output_tmp
->
data
<
int32_t
>
(),
LaunchDequantKernel
<
T
>
(
output_tmp
->
data
<
int32_t
>
(),
output
->
data
<
T
>
(),
m_
,
n_
,
...
...
@@ -126,7 +126,7 @@ class AttnMatmulINT8 {
output_tmp
->
data
<
int32_t
>
(),
dev_ctx_
.
stream
());
dequantize_kernel_launcher
<
T
>
(
output_tmp
->
data
<
int32_t
>
(),
LaunchDequantKernel
<
T
>
(
output_tmp
->
data
<
int32_t
>
(),
output
->
data
<
T
>
(),
m_
,
n_
,
...
...
@@ -162,7 +162,7 @@ class AttnMatmulINT8 {
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
quantize_kernel_launcher
<
T
>
(
input
->
data
<
T
>
(),
LaunchQuantKernel
<
T
>
(
input
->
data
<
T
>
(),
input_tmp
->
data
<
int8_t
>
(),
quant_in_scale
,
m_
,
...
...
paddle/fluid/operators/fused/quant_dequant_kernel.h
浏览文件 @
27cc0df5
...
...
@@ -47,7 +47,7 @@ __forceinline__ __device__ int8_t quant_helper(const T input,
}
template
<
typename
T
>
__global__
void
quantize_k
ernel
(
const
T
*
input
,
__global__
void
QuantK
ernel
(
const
T
*
input
,
char4
*
output
,
const
float
scale
,
const
int
m
,
...
...
@@ -74,7 +74,7 @@ __global__ void quantize_kernel(const T* input,
}
template
<
typename
T
>
void
quantize_kernel_launcher
(
const
T
*
input
,
void
LaunchQuantKernel
(
const
T
*
input
,
int8_t
*
output
,
const
float
scale
,
const
int
m
,
...
...
@@ -87,7 +87,7 @@ void quantize_kernel_launcher(const T* input,
dim3
grid
((
n
>>
2
+
31
)
/
32
,
(
m
+
31
)
/
32
);
dim3
block
(
32
,
32
);
quantize_k
ernel
<<<
grid
,
block
,
0
,
stream
>>>
(
input
,
QuantK
ernel
<<<
grid
,
block
,
0
,
stream
>>>
(
input
,
(
char4
*
)
output
,
// NOLINT
scale
,
m
,
...
...
@@ -98,7 +98,7 @@ void quantize_kernel_launcher(const T* input,
}
template
<
typename
T
,
int
VecSize
>
__global__
void
dequantize_k
ernel
(
T
*
output
,
__global__
void
DequantK
ernel
(
T
*
output
,
const
int32_t
*
input
,
const
int
m
,
// batch size
const
int
n
,
// hidden
...
...
@@ -128,7 +128,7 @@ __global__ void dequantize_kernel(T* output,
}
template
<
typename
T
>
void
dequantize_kernel_launcher
(
const
int32_t
*
input
,
void
LaunchDequantKernel
(
const
int32_t
*
input
,
T
*
output
,
const
int
m
,
// m
const
int
n
,
// n
...
...
@@ -136,7 +136,7 @@ void dequantize_kernel_launcher(const int32_t* input,
GpuLaunchConfig
*
gpu_config
,
const
float
quant_in_scale
,
const
float
*
dequant_out_scale_data
)
{
dequantize_k
ernel
<
T
,
DequantKernelVecSize
>
DequantK
ernel
<
T
,
DequantKernelVecSize
>
<<<
gpu_config
->
block_per_grid
,
gpu_config
->
thread_per_block
,
0
,
stream
>>>
(
output
,
input
,
m
,
n
,
quant_in_scale
,
dequant_out_scale_data
);
}
...
...
paddle/phi/api/yaml/legacy_ops.yaml
浏览文件 @
27cc0df5
...
...
@@ -523,6 +523,14 @@
func
:
matmul
backward
:
matmul_grad
-
op
:
matmul_int8
args
:
(Tensor x, Tensor y, bool transpose_x =
false
, bool transpose_y =
false
)
output
:
Tensor
infer_meta
:
func
:
MatmulInt8InferMeta
kernel
:
func
:
matmul_int8
-
op
:
matrix_rank
args
:
(Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
output
:
Tensor(out)
...
...
paddle/phi/infermeta/binary.cc
浏览文件 @
27cc0df5
...
...
@@ -2096,6 +2096,76 @@ void MatmulInferMeta(const MetaTensor& x,
out
->
set_layout
(
x
.
layout
());
}
void
MatmulInt8InferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
bool
trans_x
,
bool
trans_y
,
MetaTensor
*
out
)
{
std
::
vector
<
int64_t
>
dims_x
=
phi
::
vectorize
(
x
.
dims
());
std
::
vector
<
int64_t
>
dims_y
=
phi
::
vectorize
(
y
.
dims
());
auto
ndims_x
=
dims_x
.
size
();
auto
ndims_y
=
dims_y
.
size
();
PADDLE_ENFORCE_GT
(
ndims_x
,
0UL
,
phi
::
errors
::
InvalidArgument
(
"The Input(x) dims size must be greater than 0,"
" but reviced dims size is 0. "
));
PADDLE_ENFORCE_GT
(
ndims_y
,
0UL
,
phi
::
errors
::
InvalidArgument
(
"The Input(y) dims size must be greater than 0,"
" but reviced dims size is 0. "
));
bool
x_broadcasted
=
false
,
y_broadcasted
=
false
;
if
(
ndims_x
==
1
)
{
dims_x
.
insert
(
dims_x
.
begin
(),
1
);
ndims_x
=
2
;
x_broadcasted
=
true
;
}
if
(
ndims_y
==
1
)
{
dims_y
.
push_back
(
1
);
ndims_y
=
2
;
y_broadcasted
=
true
;
}
size_t
M
,
N
;
if
(
trans_x
)
{
M
=
dims_x
[
ndims_x
-
1
];
}
else
{
M
=
dims_x
[
ndims_x
-
2
];
}
if
(
trans_y
)
{
N
=
dims_y
[
ndims_y
-
2
];
}
else
{
N
=
dims_y
[
ndims_y
-
1
];
}
std
::
vector
<
int64_t
>
new_dims
;
if
(
ndims_x
>
ndims_y
)
{
new_dims
.
assign
(
dims_x
.
begin
(),
dims_x
.
end
()
-
2
);
}
else
if
(
ndims_x
<
ndims_y
)
{
new_dims
.
assign
(
dims_y
.
begin
(),
dims_y
.
end
()
-
2
);
}
else
{
new_dims
.
reserve
(
ndims_x
);
for
(
size_t
i
=
0
;
i
<
ndims_x
-
2
;
++
i
)
{
new_dims
.
push_back
(
std
::
max
(
dims_x
[
i
],
dims_y
[
i
]));
}
}
if
(
!
x_broadcasted
)
{
new_dims
.
push_back
(
M
);
}
if
(
!
y_broadcasted
)
{
new_dims
.
push_back
(
N
);
}
auto
ddim_out
=
phi
::
make_ddim
(
new_dims
);
out
->
set_dims
(
ddim_out
);
out
->
set_dtype
(
phi
::
DataType
::
INT32
);
out
->
set_layout
(
x
.
layout
());
}
void
MatmulWithFlattenInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
int
x_num_col_dims
,
...
...
paddle/phi/infermeta/binary.h
浏览文件 @
27cc0df5
...
...
@@ -333,6 +333,12 @@ void MatmulInferMeta(const MetaTensor& x,
bool
trans_y
,
MetaTensor
*
out
);
void
MatmulInt8InferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
bool
trans_x
,
bool
trans_y
,
MetaTensor
*
out
);
void
MatmulWithFlattenInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
int
x_num_col_dims
,
...
...
paddle/phi/kernels/funcs/cublaslt.h
浏览文件 @
27cc0df5
...
...
@@ -39,25 +39,16 @@ const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{};
class
CublasLtHelper
{
public:
CublasLtHelper
(
int
m
,
int
k
,
int
n
)
:
alpha_
(
1
),
beta_
(
0
),
m_
(
m
),
k_
(
k
),
n_
(
n
)
{
CublasLtHelper
(
int
m
,
int
k
,
int
n
,
cublasLtHandle_t
handle
)
:
handle_
(
handle
),
alpha_
(
1
),
beta_
(
0
),
m_
(
m
),
k_
(
k
),
n_
(
n
)
{
cublasStatus_t
status
;
// handle and matmul desc
status
=
dyl
::
cublasLtCreate
(
&
handle_
);
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t
cudaComputeType
=
CUDA_R_32I
;
#else
cublasComputeType_t
cudaComputeType
=
CUBLAS_COMPUTE_32I
;
#endif
PADDLE_ENFORCE_EQ
(
status
,
CUBLAS_STATUS_SUCCESS
,
phi
::
errors
::
External
(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"
));
// matmul desc
#if CUBLAS_VER_MAJOR < 11
status
=
dyl
::
cublasLtMatmulDescCreate
(
&
matmul_desc_
,
cudaComputeType
);
#else
...
...
@@ -179,7 +170,7 @@ class CublasLtHelper {
}
~
CublasLtHelper
()
{}
void
GEMM
(
int8_t
*
A_dev
,
void
GEMM
(
const
int8_t
*
A_dev
,
const
int8_t
*
B_dev
,
int32_t
*
C_dev
,
cudaStream_t
stream
,
...
...
@@ -226,14 +217,14 @@ class CublasLtHelper {
cublasLtMatmulAlgo_t
algo_
;
int32_t
alpha_
;
int32_t
beta_
;
int32_t
alpha_
=
1
;
int32_t
beta_
=
0
;
int
m_
;
int
k_
;
int
n_
;
int
m_
=
0
;
int
k_
=
0
;
int
n_
=
0
;
size_t
workspace_size_
;
size_t
workspace_size_
=
0
;
};
}
// namespace phi
paddle/phi/kernels/funcs/gemm_int8_helper.h
0 → 100644
浏览文件 @
27cc0df5
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Paddle/paddle/phi/kernels/funcs/cublaslt.h"
namespace
phi
{
template
<
typename
T
>
class
Int8GEMMHelper
{
public:
Int8GEMMHelper
(
const
phi
::
GPUContext
&
dev_ctx
,
int
m
,
int
k
,
int
n
,
phi
::
DenseTensor
&
workspace
,
// NOLINT
phi
::
DenseTensor
&
input_workspace
,
// NOLINT
phi
::
DenseTensor
&
out_workspace
,
// NOLINT
int
quant_round_type
,
float
quant_max_bound
,
float
quant_min_bound
)
:
dev_ctx_
(
dev_ctx
),
m_
(
m
),
k_
(
k
),
n_
(
n
),
quant_round_type_
(
quant_round_type
),
quant_min_bound_
(
quant_min_bound
),
quant_max_bound_
(
quant_max_bound
),
workspace_
(
workspace
),
input_workspace_
(
input_workspace
),
out_workspace_
(
out_workspace
)
{
cublaslt_helper
=
std
::
make_unique
<
CublasLtHelper
<
int32_t
>>
(
m
,
k
,
n
,
dev_ctx
.
cublaslt_handle
());
}
void
Compute
(
const
phi
::
DenseTensor
*
input
,
const
phi
::
DenseTensor
*
weight
,
// int8, Need be transposed
const
phi
::
DenseTensor
*
dequant_out_scales
,
const
float
quant_in_scale
,
phi
::
DenseTensor
*
output
,
bool
quant_in
=
false
,
bool
dequant_out
=
false
)
{
phi
::
DenseTensor
input_tmp
,
out_tmp
;
if
(
quant_in
)
{
input_tmp
=
input_workspace_
;
LaunchQuantKernel
<
T
>
(
input
->
data
<
T
>
(),
input_tmp
.
data
<
int8_t
>
(),
quant_in_scale
,
m_
,
k_
,
quant_round_type_
,
quant_max_bound_
,
quant_min_bound_
,
dev_ctx_
.
stream
());
}
else
{
input_tmp
=
*
input
;
}
if
(
dequant_out
)
{
out_tmp
=
out_workspace_
;
}
else
{
out_tmp
=
*
output
;
}
cublaslt_helper
->
GEMM
(
input_tmp
.
data
<
int8_t
>
(),
weight
->
data
<
int8_t
>
(),
out_tmp
.
data
<
int32_t
>
(),
dev_ctx_
.
stream
(),
(
void
*
)
workspace_
.
data
<
int8_t
>
(),
workspace_
.
numel
());
if
(
dequant_out
)
{
auto
gpu_config
=
std
::
make_unique
<
GpuLaunchConfig
>
(
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx_
,
m_
*
n_
,
DequantKernelVecSize
));
LaunchDequantKernel
<
T
>
(
out_tmp
.
data
<
int32_t
>
(),
output
->
data
<
T
>
(),
m_
,
n_
,
dev_ctx_
.
stream
(),
gpu_config
.
get
(),
quant_in_scale
,
dequant_out_scales
->
data
<
float
>
());
}
}
private:
const
phi
::
GPUContext
&
dev_ctx_
;
int
m_
;
int
k_
;
int
n_
;
int
quant_round_type_
;
float
quant_max_bound_
;
float
quant_min_bound_
;
phi
::
DenseTensor
&
workspace_
;
// char
phi
::
DenseTensor
&
input_workspace_
;
// int8_t
phi
::
DenseTensor
&
out_workspace_
;
// int32_t
std
::
unique_ptr
<
CublasLtHelper
<
int32_t
>>
cublaslt_helper
;
};
}
// namespace phi
paddle/phi/kernels/funcs/quant_dequant.h
浏览文件 @
27cc0df5
...
...
@@ -61,7 +61,7 @@ __forceinline__ __device__ int8_t quant_helper(const T input,
}
template
<
typename
T
>
__global__
void
quantize_k
ernel
(
const
T
*
input
,
__global__
void
QuantK
ernel
(
const
T
*
input
,
char4
*
output
,
const
float
scale
,
const
int
m
,
...
...
@@ -88,7 +88,7 @@ __global__ void quantize_kernel(const T* input,
}
template
<
typename
T
>
void
quantize_kernel_launcher
(
const
T
*
input
,
void
LaunchQuantKernel
(
const
T
*
input
,
int8_t
*
output
,
const
float
scale
,
const
int
m
,
...
...
@@ -101,7 +101,7 @@ void quantize_kernel_launcher(const T* input,
dim3
grid
((
n
>>
2
+
31
)
/
32
,
(
m
+
31
)
/
32
);
dim3
block
(
32
,
32
);
quantize_k
ernel
<<<
grid
,
block
,
0
,
stream
>>>
(
input
,
QuantK
ernel
<<<
grid
,
block
,
0
,
stream
>>>
(
input
,
(
char4
*
)
output
,
// NOLINT
scale
,
m
,
...
...
@@ -112,7 +112,7 @@ void quantize_kernel_launcher(const T* input,
}
template
<
typename
T
,
int
VecSize
>
__global__
void
dequantize_k
ernel
(
T
*
output
,
__global__
void
DequantK
ernel
(
T
*
output
,
const
int32_t
*
input
,
const
int
m
,
// batch size
const
int
n
,
// hidden
...
...
@@ -142,7 +142,7 @@ __global__ void dequantize_kernel(T* output,
}
template
<
typename
T
>
void
dequantize_kernel_launcher
(
const
int32_t
*
input
,
void
LaunchDequantKernel
(
const
int32_t
*
input
,
T
*
output
,
const
int
m
,
// m
const
int
n
,
// n
...
...
@@ -150,7 +150,7 @@ void dequantize_kernel_launcher(const int32_t* input,
GpuLaunchConfig
*
gpu_config
,
const
float
quant_in_scale
,
const
float
*
dequant_out_scale_data
)
{
dequantize_k
ernel
<
T
,
DequantKernelVecSize
>
DequantK
ernel
<
T
,
DequantKernelVecSize
>
<<<
gpu_config
->
block_per_grid
,
gpu_config
->
thread_per_block
,
0
,
stream
>>>
(
output
,
input
,
m
,
n
,
quant_in_scale
,
dequant_out_scale_data
);
}
...
...
paddle/phi/kernels/gpu/matmul_kernel.cu
浏览文件 @
27cc0df5
...
...
@@ -30,6 +30,9 @@ PD_REGISTER_KERNEL(matmul,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
PD_REGISTER_KERNEL
(
matmul_int8
,
GPU
,
ALL_LAYOUT
,
phi
::
MatmulInt8Kernel
,
int8_t
)
{}
PD_REGISTER_KERNEL
(
matmul_with_flatten
,
GPU
,
ALL_LAYOUT
,
...
...
paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
浏览文件 @
27cc0df5
...
...
@@ -667,7 +667,8 @@ void LLMGemm(const phi::GPUContext& dev_ctx,
dev_ctx
.
Alloc
<
int32_t
>
(
&
int_out
);
{
auto
helper
=
std
::
make_unique
<
CublasLtHelper
>
(
m
,
k
,
n
);
auto
helper
=
std
::
make_unique
<
CublasLtHelper
>
(
m
,
k
,
n
,
dev_ctx
.
cublaslt_handle
());
helper
->
GEMM
(
quant_input
.
data
<
int8_t
>
(),
weight
->
data
<
int8_t
>
(),
int_out
.
data
<
int32_t
>
(),
...
...
paddle/phi/kernels/impl/matmul_kernel_impl.h
浏览文件 @
27cc0df5
...
...
@@ -16,11 +16,15 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/autotune/cache_base.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/phi/kernels/funcs/cublaslt.h"
#endif
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#endif
...
...
@@ -948,6 +952,15 @@ struct MatMulDispatcher<phi::GPUContext, T> {
#endif
}
};
static
phi
::
Allocator
::
AllocationPtr
GetWorkspace
(
const
phi
::
GPUContext
&
ctx
,
size_t
workspace_size
)
{
return
phi
::
memory_utils
::
Alloc
(
ctx
.
GetPlace
(),
workspace_size
,
phi
::
Stream
(
reinterpret_cast
<
phi
::
StreamId
>
(
ctx
.
stream
())));
}
#endif // PADDLE_WITH_CUDA
template
<
typename
Context
,
typename
T
>
...
...
@@ -964,6 +977,107 @@ void MatMulFunction(const Context& ctx,
ctx
,
x
,
y
,
x_dims
,
y_dims
,
out
,
trans_x
,
trans_y
,
flag
);
}
template
<
typename
Context
>
void
MatMulInt8Function
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
std
::
vector
<
std
::
int64_t
>&
x_dims
,
const
std
::
vector
<
std
::
int64_t
>&
y_dims
,
DenseTensor
*
out
,
bool
trans_x
,
bool
trans_y
)
{
PADDLE_ENFORCE_EQ
(
x
.
dtype
(),
DataType
::
INT8
,
phi
::
errors
::
InvalidArgument
(
"The type of input(x) used in int8 matmul must be (%s) does not "
"match the "
"type of data (%s) currently contained in the container."
,
phi
::
CppTypeToDataType
<
int8_t
>::
Type
(),
x
.
dtype
()));
PADDLE_ENFORCE_EQ
(
y
.
dtype
(),
DataType
::
INT8
,
phi
::
errors
::
InvalidArgument
(
"The type of input(y) used in int8 matmul must be (%s) does not "
"match the "
"type of data (%s) currently contained in the container."
,
phi
::
CppTypeToDataType
<
int8_t
>::
Type
(),
x
.
dtype
()));
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020
const
int
x_ndim
=
x_dims
.
size
();
const
int
y_ndim
=
y_dims
.
size
();
PADDLE_ENFORCE_EQ
(
x_ndim
,
2
,
phi
::
errors
::
InvalidArgument
(
"[INT8 GEMM] The number of dims of input(x) "
"must be equal to 2 but received %d"
,
x_ndim
));
PADDLE_ENFORCE_EQ
(
y_ndim
,
2
,
phi
::
errors
::
InvalidArgument
(
"[INT8 GEMM] The number of dims of input(x) "
"must be equal to 2 but received %d"
,
y_ndim
));
PADDLE_ENFORCE_EQ
(
trans_x
,
false
,
phi
::
errors
::
InvalidArgument
(
"[INT8 GEMM] Input(x) must be not "
"transposed to acheive better performance"
));
PADDLE_ENFORCE_EQ
(
trans_y
,
true
,
phi
::
errors
::
InvalidArgument
(
"[INT8 GEMM] Input(y) must be transposed to "
"acheive better performance"
));
const
int
M
=
trans_x
?
x_dims
[
x_ndim
-
1
]
:
x_dims
[
x_ndim
-
2
];
const
int
K
=
trans_x
?
x_dims
[
x_ndim
-
2
]
:
x_dims
[
x_ndim
-
1
];
if
(
trans_y
)
{
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
1
],
K
,
phi
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d"
,
y_ndim
-
1
,
K
,
y_ndim
-
1
,
y_dims
[
y_ndim
-
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
2
],
K
,
phi
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d"
,
y_ndim
-
2
,
K
,
y_ndim
-
2
,
y_dims
[
y_ndim
-
2
]));
}
const
int
N
=
trans_y
?
y_dims
[
y_ndim
-
2
]
:
y_dims
[
y_ndim
-
1
];
size_t
workspace_size
=
static_cast
<
size_t
>
(
4
)
*
1024
*
1024
;
phi
::
Allocator
::
AllocationPtr
workspace
=
GetWorkspace
(
ctx
,
workspace_size
);
// TODO(wufeisheng): cublaslt_helper is a temp scheme for Int8 GEMM,
// and releted functions need to be integrated into
// phi::funcs::MatmulWithCublasLt
auto
cublaslt_helper
=
CublasLtHelper
(
M
,
K
,
N
,
ctx
.
cublaslt_handle
());
ctx
.
template
Alloc
<
int32_t
>(
out
);
cublaslt_helper
.
GEMM
(
x
.
data
<
int8_t
>
(),
y
.
data
<
int8_t
>
(),
out
->
data
<
int32_t
>
(),
ctx
.
stream
(),
workspace
->
ptr
());
#else
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"MatmulInt8 op needs paddle with cuda and cuda version >= 11.2"
));
#endif
}
template
<
typename
T
,
typename
Context
>
void
MatmulKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
...
...
@@ -987,6 +1101,29 @@ void MatmulKernel(const Context& ctx,
ctx
,
x
,
y
,
x_dims
,
y_dims
,
out
,
transpose_x
,
transpose_y
);
}
template
<
typename
T
,
typename
Context
>
void
MatmulInt8Kernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
bool
transpose_x
,
bool
transpose_y
,
DenseTensor
*
out
)
{
PADDLE_ENFORCE_NE
(
phi
::
product
(
x
.
dims
()),
0
,
phi
::
errors
::
InvalidArgument
(
"The Input(X) dims size must not be equal 0,"
" but reviced dims size is 0. "
));
PADDLE_ENFORCE_NE
(
phi
::
product
(
y
.
dims
()),
0
,
phi
::
errors
::
InvalidArgument
(
"The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "
));
const
std
::
vector
<
std
::
int64_t
>
x_dims
=
vectorize
(
x
.
dims
());
const
std
::
vector
<
std
::
int64_t
>
y_dims
=
vectorize
(
y
.
dims
());
MatMulInt8Function
<
Context
>
(
ctx
,
x
,
y
,
x_dims
,
y_dims
,
out
,
transpose_x
,
transpose_y
);
}
template
<
typename
T
,
typename
Context
>
void
MatmulWithFlattenKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
...
...
test/legacy_test/CMakeLists.txt
浏览文件 @
27cc0df5
...
...
@@ -157,6 +157,7 @@ if(WIN32)
list
(
REMOVE_ITEM TEST_OPS test_fused_ec_moe_op
)
list
(
REMOVE_ITEM TEST_OPS test_rms_norm_op
)
list
(
REMOVE_ITEM TEST_OPS test_linear_compress
)
list
(
REMOVE_ITEM TEST_OPS test_matmul_int8_op
)
endif
()
list
(
REMOVE_ITEM TEST_OPS test_checkpoint_saver
)
...
...
test/legacy_test/test_matmul_int8_op.py
0 → 100644
浏览文件 @
27cc0df5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
test_sparse_attention_op
import
get_cuda_version
import
paddle
from
paddle.fluid
import
core
paddle
.
disable_static
()
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
get_cuda_version
()
<
11020
or
paddle
.
device
.
cuda
.
get_device_capability
()[
0
]
<
8
,
"MatmulInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8"
,
)
class
TestMatmulInt8
(
unittest
.
TestCase
):
"""
Test matmul int8
Only NT (Non-Transposed-A and Transposed-B) is supported
"""
def
config
(
self
):
self
.
dtype
=
'int8'
self
.
rtol
=
1e-5
self
.
atol
=
1e-2
self
.
bias
=
True
self
.
m
=
8
self
.
k
=
64
self
.
n
=
64
def
setUp
(
self
):
self
.
config
()
self
.
input_a_np
=
np
.
random
.
randint
(
-
127
,
127
,
[
self
.
m
,
self
.
k
]).
astype
(
'int32'
)
self
.
input_b_np
=
np
.
random
.
randint
(
-
127
,
127
,
[
self
.
k
,
self
.
n
]).
astype
(
'int32'
)
self
.
input_a
=
paddle
.
to_tensor
(
self
.
input_a_np
,
dtype
=
self
.
dtype
)
self
.
input_b
=
paddle
.
to_tensor
(
self
.
input_b_np
.
transpose
((
1
,
0
)),
dtype
=
self
.
dtype
)
def
get_reference_out
(
self
):
out
=
np
.
dot
(
self
.
input_a_np
,
self
.
input_b_np
)
return
out
def
get_op_out
(
self
):
out
=
paddle
.
_C_ops
.
matmul_int8
(
self
.
input_a
,
self
.
input_b
,
False
,
True
)
return
out
.
numpy
()
def
test_matmul_int8
(
self
):
out_real
=
self
.
get_op_out
()
out_expect
=
self
.
get_reference_out
()
np
.
testing
.
assert_allclose
(
out_real
,
out_expect
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录