PaddlePaddle / Paddle
Unverified commit f4290a92
Authored on Jul 11, 2023 by FormlessUnit; committed via GitHub on Jul 11, 2023
Linear compress (#55128)
* rename weight_only/llm.int8
Parent: ab46b14c
Showing 14 changed files with 414 additions and 93 deletions (+414 −93)
paddle/phi/api/yaml/ops.yaml                                (+6 −7)
paddle/phi/infermeta/multiary.cc                            (+13 −7)
paddle/phi/infermeta/multiary.h                             (+3 −2)
paddle/phi/infermeta/unary.cc                               (+6 −11)
paddle/phi/kernels/cpu/quant_for_compress_kernel.cc         (+15 −9)
paddle/phi/kernels/gpu/llm_int8_matmul_kernel.cu            (+5 −5)
paddle/phi/kernels/gpu/weight_only_matmul_kernel.cu         (+117 −0)
paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h       (+0 −0)
paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h    (+217 −43)
paddle/phi/kernels/llm_int8_matmul_kernel.h                 (+1 −1)
paddle/phi/kernels/weight_only_matmul_kernel.h              (+1 −1)
python/paddle/nn/functional/common.py                       (+12 −4)
python/paddle/nn/layer/common.py                            (+6 −3)
test/legacy_test/test_linear_compress.py                    (+12 −0)
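For orientation: the user-facing surface touched here is paddle.nn.functional.linear_compress and the paddle.nn.LinearCompress layer; everything below is the rename of the underlying ops (llm_int8_mat_mul -> llm_int8_matmul, weight_only_mat_mul -> weight_only_matmul) plus int4 support in the weight_only quantization path. A minimal usage sketch modeled on test/legacy_test/test_linear_compress.py; constructor defaults and the contents of config are assumptions, not taken from the hunks below:

import paddle

# Hypothetical sketch (argument values assumed): int8 weight-only compression
# of a Linear layer on a CUDA build of Paddle at this commit, mirroring
# test_linear_compress.py.
paddle.set_default_dtype('float16')
linear = paddle.nn.LinearCompress(
    64, 64, bits=8, algo="weight_only", config={'threshold': 6.0}
)
x = paddle.rand([3, 64], dtype='float16')
y = linear(x)  # dynamic mode dispatches to _C_ops.weight_only_matmul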
paddle/phi/api/yaml/ops.yaml

@@ -1367,14 +1367,14 @@
   data_transform :
     skip_transform : out_size, size_tensor, scale_tensor
 
-- op : llm_int8_mat_mul
+- op : llm_int8_matmul
   args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0)
   output : Tensor(out)
   infer_meta :
-    func : LLMInt8MatMulInferMeta
+    func : LLMInt8MatmulInferMeta
     param : [x, weight]
   kernel :
-    func : llm_int8_mat_mul
+    func : llm_int8_matmul
     data_type : x
 
 - op : log

@@ -2602,14 +2602,13 @@
   intermediate : warprnntgrad
   backward : warprnnt_grad
 
-- op : weight_only_mat_mul
+- op : weight_only_matmul
   args : (Tensor x, Tensor weight, Tensor weight_scale)
   output : Tensor(out)
   infer_meta :
-    func : WeightOnlyMatMulInferMeta
-    param : [x, weight]
+    func : WeightOnlyMatmulInferMeta
   kernel :
-    func : weight_only_mat_mul
+    func : weight_only_matmul
     data_type : x
 
 - op : weighted_sample_neighbors
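The two YAML entries above drive code generation for the C++ API and the _C_ops Python bindings, so this rename is what makes the calls in python/paddle/nn/functional/common.py (further down) resolve. A hedged sketch of the dispatch the args lists imply; illustrative only, not an excerpt from the repository:

from paddle import _C_ops

def compressed_matmul(x, weight, weight_scale, algo, threshold=6.0):
    # threshold only applies to the llm.int8 path; 6.0 matches the YAML default.
    if algo == "llm.int8":
        return _C_ops.llm_int8_matmul(x, weight, weight_scale, threshold)
    return _C_ops.weight_only_matmul(x, weight, weight_scale)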
paddle/phi/infermeta/multiary.cc

@@ -3572,7 +3572,7 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
   out_count->set_dtype(DataType::INT32);
 }
 
-void LLMInt8MatMulInferMeta(const MetaTensor& x,
+void LLMInt8MatmulInferMeta(const MetaTensor& x,
                             const MetaTensor& weight,
                             MetaTensor* out) {
   auto x_dims = x.dims();

@@ -3595,25 +3595,31 @@ void LLMInt8MatMulInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
-void WeightOnlyMatMulInferMeta(const MetaTensor& x,
+void WeightOnlyMatmulInferMeta(const MetaTensor& x,
                                const MetaTensor& weight,
+                               const MetaTensor& weight_scale,
                                MetaTensor* out) {
   auto x_dims = x.dims();
   auto w_dims = weight.dims();
+  auto n = weight_scale.dims()[0];
   PADDLE_ENFORCE_EQ(
       w_dims.size(),
       2UL,
       errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
+  PADDLE_ENFORCE_EQ(
+      weight_scale.dims().size(),
+      1UL,
+      errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor."));
   PADDLE_ENFORCE_EQ(
       x_dims[x_dims.size() - 1],
-      w_dims[0],
+      w_dims[1],
       errors::InvalidArgument(
-          "Input(X) dim[-1] and Input(Weight) dim[0] should be euqal."
-          "But received Input(X) dim[-1](%s) != Input(Weight) dim[0](%s)",
+          "Input(X) dim[-1] and Input(Weight) dim[1] should be euqal."
+          "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
           x_dims[x_dims.size() - 1],
-          w_dims[0]));
+          w_dims[1]));
   auto out_dims = x_dims;
-  out_dims[out_dims.size() - 1] = w_dims[1];
+  out_dims[out_dims.size() - 1] = n;
   out->set_dims(out_dims);
   out->set_dtype(x.dtype());
 }
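WeightOnlyMatmulInferMeta now takes weight_scale and reads the output width from it (n = weight_scale.dims()[0]) rather than from the packed weight, which lets the same meta function cover both the int8 ([n, k]) and int4 ([n/2, k]) weight layouts. A small Python rendering of the shape rule above, with illustrative shapes:

def weight_only_matmul_out_shape(x_shape, w_shape, scale_shape):
    # Mirrors the checks in WeightOnlyMatmulInferMeta.
    assert len(w_shape) == 2, "The input(weight) must be a 2D Tensor."
    assert len(scale_shape) == 1, "The input(weight_scale) must be a 1D Tensor."
    assert x_shape[-1] == w_shape[1], "Input(X) dim[-1] must equal Input(Weight) dim[1]."
    n = scale_shape[0]
    return list(x_shape[:-1]) + [n]

print(weight_only_matmul_out_shape([2, 64], [64, 64], [64]))  # int8 weight -> [2, 64]
print(weight_only_matmul_out_shape([2, 64], [32, 64], [64]))  # int4 weight -> [2, 64]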
paddle/phi/infermeta/multiary.h

@@ -690,12 +690,13 @@ void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query,
                                               bool causal,
                                               MetaTensor* out);
 
-void LLMInt8MatMulInferMeta(const MetaTensor& x,
+void LLMInt8MatmulInferMeta(const MetaTensor& x,
                             const MetaTensor& weight,
                             MetaTensor* out);
 
-void WeightOnlyMatMulInferMeta(const MetaTensor& x,
+void WeightOnlyMatmulInferMeta(const MetaTensor& x,
                                const MetaTensor& weight,
+                               const MetaTensor& weight_scale,
                                MetaTensor* out);
 
 void FusedRopeInferMeta(const MetaTensor& q,
paddle/phi/infermeta/unary.cc

@@ -5086,22 +5086,17 @@ void QuantForCompressInferMeta(const MetaTensor& x,
           x_dims[0]));
   std::vector<int64_t> dim_scale({x_dims[1]});
   std::vector<int64_t> dim_out;
-  if (layout == "weight_only") {
-    dim_out = std::vector<int64_t>({x_dims[0], x_dims[1]});
-  } else if (layout == "llm.int8") {
+  if (bits == 8) {
     dim_out = std::vector<int64_t>({x_dims[1], x_dims[0]});
+  } else if (bits == 4) {
+    dim_out = std::vector<int64_t>({x_dims[1] / 2, x_dims[0]});
   } else {
     phi::errors::InvalidArgument(
-        "The layout must be weight_only or llm.int8, but got %s", layout);
+        "The bit must be 8 or 4, but got %d", bits);
   }
   out->set_dims(phi::make_ddim(dim_out));
-  // TODO(lizhenyun) support weight_only int4
-  if (bits == 8) {
-    out->set_dtype(DataType::INT8);
-  } else {
-    phi::errors::Fatal("The bits only support 8, but got[%d]", bits);
-  }
+  out->set_dtype(DataType::INT8);
   scale->set_dims(phi::make_ddim(dim_scale));
   scale->set_dtype(DataType::FLOAT32);
 }
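The reworked QuantForCompressInferMeta keys the output shape on bits instead of layout: the quantized weight comes out transposed, and the int4 path packs two values per int8 byte so its first dimension is halved. A worked example of the rule above, in Python:

def quant_for_compress_out_shapes(x_shape, bits):
    # x_shape is the original [k, n] weight; mirrors the hunk above.
    k, n = x_shape
    if bits == 8:
        dim_out = [n, k]
    elif bits == 4:
        dim_out = [n // 2, k]
    else:
        raise ValueError(f"The bit must be 8 or 4, but got {bits}")
    return dim_out, [n]  # (int8 out dims, float32 scale dims)

print(quant_for_compress_out_shapes([64, 32], 8))  # ([32, 64], [32])
print(quant_for_compress_out_shapes([64, 32], 4))  # ([16, 64], [32])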
paddle/phi/kernels/cpu/quant_for_compress_kernel.cc

@@ -22,7 +22,7 @@
 namespace phi {
 
-template <typename DeviceContext, typename T, typename D>
+template <typename DeviceContext, typename T, typename D, int bits>
 void quant_compute(const DeviceContext& dev_ctx,
                    const DenseTensor& x,
                    DenseTensor* out,

@@ -59,15 +59,15 @@ void quant_compute(const DeviceContext& dev_ctx,
   per_channel_scale(scale_data, x_data, m, n);
 
-  per_channel_quant(x_int_data, x_data, scale_data, m, n);
+  per_channel_quant<T, bits>(x_int_data, x_data, scale_data, m, n);
   if (layout == "weight_only") {
-    permute_B_rows_for_mixed_gemm(
+    permute_B_rows_for_mixed_gemm<bits>(
         int_processed_data, x_int_data, std::vector<size_t>{m, n}, (int64_t)80);
-    row_major_to_column_major(
+    subbyte_transpose_impl<bits>(
         int_processed_2_data, int_processed_data, std::vector<size_t>{m, n});
-    interleave_column_major_tensor(
+    interleave_column_major_tensor<bits>(
         out_data, int_processed_2_data, std::vector<size_t>{m, n});
-    add_bias_and_interleave_int8s_inplace(out_data, num);
+    add_bias_and_interleave_inplace<bits>(out_data, num);
   } else if (layout == "llm.int8") {
     std::vector<int> axis = {1, 0};
     funcs::Transpose<DeviceContext, int8_t, 2> trans;

@@ -88,9 +88,16 @@ void QuantForCompressKernel(const Context& dev_ctx,
   if (bits == 8) {
     dev_ctx.template Alloc<int8_t>(out);
     dev_ctx.template Alloc<float>(scale);
-    quant_compute<Context, T, int8_t>(dev_ctx, x, out, scale, layout);
+    quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, layout);
+  } else if (bits == 4 && layout == "weight_only") {
+    dev_ctx.template Alloc<int8_t>(out);
+    dev_ctx.template Alloc<float>(scale);
+    quant_compute<Context, T, int8_t, 4>(dev_ctx, x, out, scale, layout);
   } else {
-    phi::errors::Unimplemented("The bits only support 8, but got[%d]", bits);
+    phi::errors::Unimplemented(
+        "The bits only support 8 or weight_only 4, but got[%s] [%d]",
+        layout,
+        bits);
   }
   // VLOG(0) << "x: " << x.dtype() << x;
   // VLOG(0) << "out: " << out->dtype() << *out;

@@ -102,5 +109,4 @@ PD_REGISTER_KERNEL(quant_for_compress,
                    CPU,
                    ALL_LAYOUT,
                    phi::QuantForCompressKernel,
-                   float,
                    phi::dtype::float16) {}
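quant_compute now carries the bit width as a template parameter through the whole CPU pipeline: per-channel scale, per-channel quantization, then the GPU-layout passes (permute_B_rows_for_mixed_gemm, subbyte_transpose_impl, interleave_column_major_tensor, add_bias_and_interleave_inplace). A NumPy sketch of just the first two steps; the scale formula shown here (abs-max over 127) is an assumption, since per_channel_scale's body is outside this diff:

import numpy as np

def per_channel_quant_int8(w):
    # w: [k, n] float weight; one scale per output column, symmetric int8.
    w = w.astype(np.float32)
    scale = np.abs(w).max(axis=0) / 127.0            # assumed scale rule
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

q, s = per_channel_quant_int8(np.random.randn(64, 32))
print(q.shape, s.shape)  # (64, 32) (32,)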
paddle/phi/kernels/gpu/llm_int8_mat_mul_kernel.cu → paddle/phi/kernels/gpu/llm_int8_matmul_kernel.cu

@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/llm_int8_mat_mul_kernel.h"
+#include "paddle/phi/kernels/llm_int8_matmul_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #ifndef PADDLE_WITH_HIP
-#include "paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h"
+#include "paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h"
 #endif
 
 namespace phi {

@@ -56,7 +56,7 @@ void llm_int8_compute(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void LLMInt8MatMulKernel(const Context& dev_ctx,
+void LLMInt8MatmulKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& weight,
                          const DenseTensor& weight_scale,

@@ -68,8 +68,8 @@ void LLMInt8MatMulKernel(const Context& dev_ctx,
 }
 }  // namespace phi
 
-PD_REGISTER_KERNEL(llm_int8_mat_mul,
+PD_REGISTER_KERNEL(llm_int8_matmul,
                    GPU,
                    ALL_LAYOUT,
-                   phi::LLMInt8MatMulKernel,
+                   phi::LLMInt8MatmulKernel,
                    phi::dtype::float16) {}
paddle/phi/kernels/gpu/weight_only_mat_mul_kernel.cu → paddle/phi/kernels/gpu/weight_only_matmul_kernel.cu

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/weight_only_mat_mul_kernel.h"
+#include "paddle/phi/kernels/weight_only_matmul_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/datatype_traits.h"
 #include "paddle/phi/core/kernel_registry.h"

@@ -23,7 +23,7 @@
 namespace phi {
 
 template <typename T, typename Context>
-void WeightOnlyMatMulKernel(const Context& dev_ctx,
+void WeightOnlyMatmulKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& weight,
                             const DenseTensor& weight_scale,

@@ -32,17 +32,26 @@ void WeightOnlyMatMulKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   const auto x_dims = x.dims();
   const auto w_dims = weight.dims();
-  int k = w_dims[0];
-  int n = w_dims[1];
+  int n = weight_scale.dims()[0];
+  int quant_bit = 0;
+  if (n % w_dims[0] == 0) {
+    quant_bit = w_dims[0] * 8 / n;
+  } else {
+    errors::InvalidArgument(
+        "w_dims[0] must be divisible by weight_scale.dims()[0]");
+  }
+  int k = w_dims[1];
   int m = x.numel() / k;
+  switch (quant_bit) {
+    case 8: {
       auto mixed_gemm_runner =
           CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
                                    uint8_t>();
       int mixgemm_max_size = std::max(n, k);
       DenseTensor mixgemm_workspace;
       int64_t mixgemm_workspace_size_bytes =
           mixed_gemm_runner.getWorkspaceSize(
               m, mixgemm_max_size, mixgemm_max_size);
       mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
       dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);

@@ -53,21 +62,56 @@ void WeightOnlyMatMulKernel(const Context& dev_ctx,
               x.data<T>()),
           reinterpret_cast<const uint8_t*>(weight.data<int8_t>()),
           reinterpret_cast<const float*>(weight_scale.data<float>()),
           reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(
               out->data<T>()),
           m,
           n,
           k,
           mixgemm_workspace_data,
           mixgemm_workspace_size_bytes,
           dev_ctx.stream());
+    } break;
+    case 4: {
+      auto mixed_gemm_runner =
+          CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
+                                   cutlass::uint4b_t>();
+      int mixgemm_max_size = std::max(n, k);
+      DenseTensor mixgemm_workspace;
+      int64_t mixgemm_workspace_size_bytes =
+          mixed_gemm_runner.getWorkspaceSize(
+              m, mixgemm_max_size, mixgemm_max_size);
+      mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
+      dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
+      char* mixgemm_workspace_data =
+          reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
+      mixed_gemm_runner.gemm(
+          reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
+              x.data<T>()),
+          reinterpret_cast<const cutlass::uint4b_t*>(weight.data<int8_t>()),
+          reinterpret_cast<const float*>(weight_scale.data<float>()),
+          reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(
+              out->data<T>()),
+          m,
+          n,
+          k,
+          mixgemm_workspace_data,
+          mixgemm_workspace_size_bytes,
+          dev_ctx.stream());
+    } break;
+    default:
+      PADDLE_THROW(errors::Unimplemented(
+          "Quant_bits (%d) is not supported when gemm ", quant_bit));
+      break;
+  }
 #else
   LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()";
 #endif
 }
 }  // namespace phi
 
-PD_REGISTER_KERNEL(weight_only_mat_mul,
+PD_REGISTER_KERNEL(weight_only_matmul,
                    GPU,
                    ALL_LAYOUT,
-                   phi::WeightOnlyMatMulKernel,
+                   phi::WeightOnlyMatmulKernel,
                    phi::dtype::float16) {}
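Instead of adding a bits attribute to the op, WeightOnlyMatmulKernel infers the bit width from tensors it already receives: the packed weight has w_dims[0] rows while weight_scale has one entry per logical output column, so their ratio gives quant_bit and selects the uint8_t or cutlass::uint4b_t GEMM runner. The arithmetic, in Python:

def infer_quant_bit(w_rows, n_scales):
    # Mirrors the new logic above: n = weight_scale.dims()[0], quant_bit = w_dims[0] * 8 / n.
    if n_scales % w_rows != 0:
        raise ValueError("w_dims[0] must be divisible by weight_scale.dims()[0]")
    return w_rows * 8 // n_scales

print(infer_quant_bit(64, 64))  # 8 -> int8 path (uint8_t runner)
print(infer_quant_bit(32, 64))  # 4 -> int4 path (cutlass::uint4b_t runner)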
paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h → paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h

File moved without changes.
paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h

@@ -39,51 +39,109 @@ void per_channel_scale(float* scale, const T* input, size_t m, size_t n) {
   }
 }
 
-template <typename T, typename D>
-void per_channel_quant(
-    D* output, const T* input, const float* scale, size_t m, size_t n) {
-  for (size_t i = 0; i < m; i++) {
-    for (size_t j = 0; j < n; j++) {
-      output[i * n + j] = static_cast<D>(
-          round(static_cast<float>(input[i * n + j]) / scale[j]));
-    }
-  }
-}
-
-void row_major_to_column_major(int8_t* col_major_tensor,
-                               const int8_t* row_major_tensor,
-                               const std::vector<size_t>& shape) {
-  size_t m = shape[0];
-  size_t n = shape[1];
-  for (size_t i = 0; i < m * n; i++) {
-    size_t im = i / n;
-    size_t in = i % n;
-    col_major_tensor[in * m + im] = row_major_tensor[im * n + in];
-  }
-}
-
-void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor_ptr,
-                                           size_t num_elts) {
-  int8_t* int8_tensor = reinterpret_cast<int8_t*>(int8_tensor_ptr);
-  for (size_t ii = 0; ii < num_elts; ++ii) {
-    int8_tensor[ii] =
-        static_cast<int8_t>(static_cast<int>(int8_tensor[ii]) + 128);
-  }
-  // Step 2 will transform the layout of a 32-bit register in CUDA in order to
-  // match the int4 layout. This has no performance benefit and is purely so
-  // that int4 and int8 have the same layout. Pictorially, this does the
-  // following: bit 32 0
-  // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
-  //
-  // And it will rearrange the output 32 bit register to be the following:
-  // bit 32 0
-  // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
-  for (size_t base = 0; base < num_elts; base += 4) {
-    std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
-  }
-}
-
+template <typename T, int quant_bit = 8>
+void per_channel_quant(int8_t* output,
+                       const T* input,
+                       const float* scale,
+                       size_t num_rows,
+                       size_t num_cols) {
+  size_t bytes_per_out_col = num_cols * quant_bit / 8;
+  for (size_t ii = 0; ii < num_rows; ++ii) {
+    int8_t* current_quantized_weight_row = output + ii * bytes_per_out_col;
+    const T* current_weight_row = input + ii * num_cols;
+    for (size_t jj = 0; jj < bytes_per_out_col; ++jj) {
+      if (quant_bit == 8) {
+        const float col_scale = scale[jj];
+        const float weight_elt = static_cast<float>(current_weight_row[jj]);
+        const float scaled_weight = round(weight_elt / col_scale);
+        const int8_t clipped_weight = static_cast<int8_t>(
+            std::max(-127.f, std::min(127.f, scaled_weight)));
+        current_quantized_weight_row[jj] = clipped_weight;
+      } else if (quant_bit == 4) {
+        // We will pack two int4 elements per iteration of the inner loop.
+        int8_t packed_int4s = 0;
+        for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
+          const size_t input_idx = 2 * jj + packed_idx;
+          if (input_idx < num_cols) {
+            const float col_scale = scale[input_idx];
+            const float weight_elt =
+                static_cast<float>(current_weight_row[input_idx]);
+            const float scaled_weight = round(weight_elt / col_scale);
+            int int_weight = static_cast<int>(scaled_weight);
+            const int8_t clipped_weight = std::max(-7, std::min(7, int_weight));
+            // Kill the sign extension bits (hence 0x0F mask) then shift to
+            // upper bits if packing the second int4 and or the bits into the
+            // final result.
+            packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
+          }
+        }
+        current_quantized_weight_row[jj] = packed_int4s;
+      } else {
+        phi::errors::Unimplemented("Unsupported quantization bits: %d",
+                                   quant_bit);
+      }
+    }
+  }
+}
+
+template <int quant_bit = 8>
+void add_bias_and_interleave_inplace(int8_t* tensor_ptr, size_t num_elts) {
+  const size_t num_bytes = num_elts * quant_bit / 8;
+  for (size_t ii = 0; ii < num_bytes; ++ii) {
+    if (quant_bit == 8) {
+      tensor_ptr[ii] =
+          static_cast<int8_t>(static_cast<int>(tensor_ptr[ii]) + 128);
+    } else {
+      int8_t transformed_packed_int4s = 0;
+      int8_t transformed_first_elt = (int8_t(tensor_ptr[ii] << 4) >> 4) + 8;
+      // The double shift here is to ensure sign extension
+      int8_t transformed_second_elt = (tensor_ptr[ii] >> 4) + 8;
+      if (!(transformed_first_elt >= 0 && transformed_first_elt <= 15)) {
+        phi::errors::InvalidArgument(
+            "Illegal result for int4 transform (first elt)");
+      }
+      if (!(transformed_second_elt >= 0 && transformed_second_elt <= 15)) {
+        phi::errors::InvalidArgument(
+            "Illegal result for int4 transform (second elt)");
+      }
+      // We don't need to mask in these ops since everything should be in the
+      // range 0-15
+      transformed_packed_int4s |= transformed_first_elt;
+      transformed_packed_int4s |= (transformed_second_elt << 4);
+      tensor_ptr[ii] = transformed_packed_int4s;
+    }
+  }
+  if (quant_bit == 8) {
+    for (size_t base = 0; base < num_elts; base += 4) {
+      std::swap(tensor_ptr[base + 1], tensor_ptr[base + 2]);
+    }
+  } else {
+    const size_t num_registers = num_bytes / 4;
+    uint32_t* register_ptr = reinterpret_cast<uint32_t*>(tensor_ptr);
+    for (size_t ii = 0; ii < num_registers; ++ii) {
+      const uint32_t current_register = register_ptr[ii];
+      uint32_t transformed_register = 0;
+      for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
+        const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
+        const int src_shift = 4 * src_idx;
+        const int dest_shift = 4 * dest_idx;
+        const uint32_t src_bits = (current_register >> src_shift) & 0xF;
+        transformed_register |= (src_bits << dest_shift);
+      }
+      register_ptr[ii] = transformed_register;
+    }
+  }
+}
+
+template <int quant_bit>
 void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
                                    const int8_t* quantized_tensor,
                                    const std::vector<size_t>& shape,

@@ -92,9 +150,8 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
   const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
   const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
 
-  const int BITS_PER_ELT = 8;
+  const int BITS_PER_ELT = quant_bit;
   const int K = 16 / BITS_PER_ELT;
-  // const int ELTS_PER_BYTE = 8 / BITS_PER_ELT;
   const int ELTS_PER_REG = 32 / BITS_PER_ELT;
 
   const uint32_t* input_byte_ptr =

@@ -102,7 +159,6 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
   uint32_t* output_byte_ptr =
       reinterpret_cast<uint32_t*>(permuted_quantized_tensor);
 
-  // int MMA_SHAPE_N = 8;
   int B_ROWS_PER_MMA = 8 * K;
   const int elts_in_int32 = 32 / BITS_PER_ELT;

@@ -118,15 +174,134 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
         const int read_row = base_row + tile_read_row;
         const int read_col = write_col;
 
-        const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
+        const int64_t read_offset =
+            static_cast<int64_t>(read_row) * num_vec_cols + read_col;
         const int64_t write_offset =
-            int64_t(write_row) * num_vec_cols + write_col;
+            static_cast<int64_t>(write_row) * num_vec_cols + write_col;
         output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
       }
     }
   }
 }
 
+template <int quant_bit>
+void subbyte_transpose_impl(int8_t* transposed_quantized_tensor,
+                            const int8_t* quantized_tensor,
+                            const std::vector<size_t>& shape) {
+  const int bits_per_elt = quant_bit;
+  // FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be
+  // 2-D or 3-D");
+  // const size_t num_experts = 1;
+  const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
+  const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
+
+  const size_t col_bytes = num_cols * bits_per_elt / 8;
+  const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
+  // const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes;
+
+  const uint8_t* input_byte_ptr =
+      reinterpret_cast<const uint8_t*>(quantized_tensor);
+  uint8_t* output_byte_ptr =
+      reinterpret_cast<uint8_t*>(transposed_quantized_tensor);
+
+  static constexpr int ELTS_PER_BYTE = 8 / quant_bit;
+
+  static constexpr int M_TILE_L1 = 64;
+  static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
+  uint8_t cache_buf[M_TILE_L1][N_TILE_L1];
+
+  static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);
+
+  // const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1;
+  // const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1;
+
+  for (size_t row_tile_start = 0; row_tile_start < num_rows;
+       row_tile_start += M_TILE_L1) {
+    for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes;
+         col_tile_start_byte += N_TILE_L1) {
+      const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
+      const int col_limit =
+          std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
+
+      for (int ii = 0; ii < M_TILE_L1; ++ii) {
+        const int row = row_tile_start + ii;
+        for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
+          const int col = col_tile_start_byte + jj;
+          const size_t logical_src_offset = row * col_bytes + col;
+          if (row < row_limit && col < col_limit) {
+            for (int v = 0; v < VECTOR_WIDTH; ++v) {
+              cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
+            }
+          }
+        }
+      }
+
+      if (quant_bit == 8) {
+        for (int ii = 0; ii < M_TILE_L1; ++ii) {
+          for (int jj = ii + 1; jj < N_TILE_L1; ++jj) {
+            std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
+          }
+        }
+      } else if (quant_bit == 4) {
+        for (int ii = 0; ii < M_TILE_L1; ++ii) {
+          // Using M_TILE_L1 here is deliberate since we assume that the cache
+          // tile is square in the number of elements (not necessarily the
+          // number of bytes).
+          for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
+            const int ii_byte = ii / ELTS_PER_BYTE;
+            const int ii_bit_offset = ii % ELTS_PER_BYTE;
+            const int jj_byte = jj / ELTS_PER_BYTE;
+            const int jj_bit_offset = jj % ELTS_PER_BYTE;
+
+            uint8_t src_elt =
+                0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
+            uint8_t tgt_elt =
+                0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
+
+            cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
+            cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));
+
+            cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
+            cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
+          }
+        }
+      } else {
+        phi::errors::Unimplemented("Unsupported quantization bits: %d",
+                                   quant_bit);
+      }
+
+      const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
+      const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
+
+      const int row_limit_trans =
+          std::min(row_tile_start_trans + M_TILE_L1, num_cols);
+      const int col_limit_trans =
+          std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
+
+      for (int ii = 0; ii < M_TILE_L1; ++ii) {
+        const int row = row_tile_start_trans + ii;
+        for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
+          const int col = col_tile_start_byte_trans + jj;
+          const size_t logical_tgt_offset = row * col_bytes_trans + col;
+          if (row < row_limit_trans && col < col_limit_trans) {
+            for (int v = 0; v < VECTOR_WIDTH; ++v) {
+              output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+template <int quant_bit>
 void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
                                     const int8_t* quantized_tensor,
                                     const std::vector<size_t>& shape) {

@@ -134,7 +309,7 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
   const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
   const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
 
-  const size_t BITS_PER_ELT = 8;
+  const size_t BITS_PER_ELT = quant_bit;
   const size_t elts_in_int32 = 32 / BITS_PER_ELT;
 
   const size_t rows_per_tile = 64;

@@ -169,6 +344,5 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
       }
     }
   }
-}
 }  // namespace phi
 #endif  // PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_
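The int4 branch of per_channel_quant packs two quantized values into each output byte: clip to [-7, 7], mask away the sign-extension bits with 0x0F, and OR the second element into the upper nibble; add_bias_and_interleave_inplace later shifts each nibble by +8 so the stored values are unsigned. A small Python rendering of the packing step:

def pack_two_int4(lo, hi):
    # lo goes to the low nibble (packed_idx 0), hi to the high nibble.
    def clip4(v):
        return max(-7, min(7, int(round(v))))
    return (clip4(lo) & 0x0F) | ((clip4(hi) & 0x0F) << 4)

assert pack_two_int4(3, -2) == 0xE3  # 0x3 in the low nibble, -2 -> 0xE in the high nibble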
paddle/phi/kernels/llm_int8_mat_mul_kernel.h → paddle/phi/kernels/llm_int8_matmul_kernel.h

@@ -16,7 +16,7 @@ limitations under the License. */
 namespace phi {
 
 template <typename T, typename Context>
-void LLMInt8MatMulKernel(const Context& dev_ctx,
+void LLMInt8MatmulKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& weight,
                          const DenseTensor& weight_scale,
paddle/phi/kernels/weight_only_mat_mul_kernel.h → paddle/phi/kernels/weight_only_matmul_kernel.h

@@ -16,7 +16,7 @@ limitations under the License. */
 namespace phi {
 
 template <typename T, typename Context>
-void WeightOnlyMatMulKernel(const Context& dev_ctx,
+void WeightOnlyMatmulKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& weight,
                             const DenseTensor& weight_scale,
python/paddle/nn/functional/common.py

@@ -1892,11 +1892,11 @@ def linear_compress(
 ):
     if in_dynamic_mode():
         if algo == "llm.int8":
-            y = _C_ops.llm_int8_mat_mul(
+            y = _C_ops.llm_int8_matmul(
                 x, weight, weight_scale, config['threshold']
             )
         elif algo == "weight_only":
-            y = _C_ops.weight_only_mat_mul(x, weight, weight_scale)
+            y = _C_ops.weight_only_matmul(x, weight, weight_scale)
         else:
             raise ValueError(
                 "Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format(

@@ -1915,11 +1915,19 @@ def linear_compress(
         if algo == "llm.int8":
             type = "llm_int8_matmul"
-            inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]}
+            inputs = {
+                'x': [x],
+                'weight': [weight],
+                'weight_scale': [weight_scale],
+            }
             attrs = {'algo': algo, 'threshold': config['threshold']}
         elif algo == "weight_only":
             type = "weight_only_matmul"
-            inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]}
+            inputs = {
+                'x': [x],
+                'weight': [weight],
+                'weight_scale': [weight_scale],
+            }
             attrs = {}
         else:
             raise ValueError(
python/paddle/nn/layer/common.py

@@ -301,10 +301,13 @@ class LinearCompress(Layer):
             weight_attr = paddle.framework.ParamAttr(
                 initializer=paddle.nn.initializer.Assign(weight_tensor)
             )
+            weight_shape = (
+                [self.weight.shape[1], self.weight.shape[0]]
+                if self.bits == 8
+                else [self.weight.shape[1] / 2, self.weight.shape[0]]
+            )
             self.weight = self.create_parameter(
-                shape=self.weight.shape
-                if self.layout == 0
-                else [self.weight.shape[1], self.weight.shape[0]],
+                shape=weight_shape,
                 attr=weight_attr,
                 dtype="int8",
                 is_bias=False,
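LinearCompress now sizes its int8 parameter from bits rather than from a layout flag: the quantized weight is stored transposed, and for int4 the first dimension is halved because two values share one byte, matching the InferMeta rule earlier in this diff. For example:

def compressed_weight_shape(in_features, out_features, bits):
    # self.weight starts as [in_features, out_features]; see weight_shape above.
    if bits == 8:
        return [out_features, in_features]
    return [out_features // 2, in_features]

print(compressed_weight_shape(64, 128, 8))  # [128, 64]
print(compressed_weight_shape(64, 128, 4))  # [64, 64]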
test/legacy_test/test_linear_compress.py

@@ -36,6 +36,7 @@ class LinearTestCase(unittest.TestCase):
         self.in_features = 64
         self.out_features = 64
         self.algo = "weight_only"
+        self.bits = 8
 
     def setUp(self):
         self.config()

@@ -62,6 +63,7 @@ class LinearTestCase(unittest.TestCase):
             self.in_features,
             self.out_features,
             bias_attr=bias_attr,
+            bits=8,
             algo=self.algo,
             config=self.config,
         )

@@ -112,5 +114,15 @@ class LinearTestCase3(LinearTestCase):
         self.atol = 1e-1
 
 
+class LinearTestCase4(LinearTestCase):
+    def config(self):
+        super().config()
+        self.dtype = 'float16'
+        self.bias = True
+        self.in_features = 128
+        self.out_features = 64
+        self.bits = 4
+
+
 if __name__ == '__main__':
     unittest.main()