Commit 236ad4fc (unverified)
PaddlePaddle / Paddle

Add Quant Row&Column ParallelLinear (#44869)

Authored by Chang Xu on Aug 12, 2022; committed via GitHub on Aug 12, 2022.
Parent commit: 4eec94dd

Showing 6 changed files with 750 additions and 37 deletions (+750, -37).
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py     +8    -3
python/paddle/fluid/contrib/slim/quantization/imperative/utils.py   +51   -22
python/paddle/fluid/tests/unittests/CMakeLists.txt                  +3    -0
python/paddle/fluid/tests/unittests/hybrid_parallel_qat.py          +326  -0
python/paddle/fluid/tests/unittests/test_parallel_dygraph_qat.py    +141  -0
python/paddle/nn/quant/quant_layers.py                              +221  -12
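The commit teaches Paddle's imperative quantization-aware training (QAT) to handle the tensor-parallel linear layers from fleet.meta_parallel: 'ColumnParallelLinear' and 'RowParallelLinear' join the default quantizable_layer_type list, two new wrapper layers (QuantizedColumnParallelLinear, QuantizedRowParallelLinear) are added to paddle.nn.quant.quant_layers, and the fake-quant layers gain a reduce_type option so quantization scales can be max-all-reduced across model-parallel ranks. A minimal usage sketch distilled from the new hybrid_parallel_qat.py test follows; TinyMPNet and the tensor shapes are illustrative, not part of the commit, and the script is assumed to run under a two-GPU distributed launch.

import paddle
import paddle.nn as nn
import paddle.distributed.fleet as fleet
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware


class TinyMPNet(nn.Layer):
    # Two tensor-parallel linears, the layer types this commit makes quantizable.
    def __init__(self, hidden_size=10, inner_size=8):
        super(TinyMPNet, self).__init__()
        self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
            hidden_size, inner_size, gather_output=False, has_bias=True)
        self.linear2 = fleet.meta_parallel.RowParallelLinear(
            inner_size, hidden_size, input_is_parallel=True, has_bias=True)

    def forward(self, x):
        return self.linear2(self.linear1(x))


if __name__ == "__main__":
    # Initialize a model-parallel group of size 2, as the new test does.
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)

    # 'ColumnParallelLinear' and 'RowParallelLinear' are now in the default
    # quantizable_layer_type, so quantize() swaps them for the new
    # QuantizedColumnParallelLinear / QuantizedRowParallelLinear wrappers.
    qat = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max')
    model = qat.quantize(TinyMPNet())
    model = fleet.distributed_model(model)
    out = model(paddle.randn([4, 2, 10]))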
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py (+8, -3)

@@ -48,7 +48,10 @@ class ImperativeQuantAware(object):
     """

     def __init__(self,
-                 quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'],
+                 quantizable_layer_type=[
+                     'Conv2D', 'Linear', 'Conv2DTranspose',
+                     'ColumnParallelLinear', 'RowParallelLinear'
+                 ],
                  weight_quantize_type='abs_max',
                  activation_quantize_type='moving_average_abs_max',
                  weight_bits=8,
...
@@ -431,12 +434,14 @@ class ImperativeQuantizeOutputs(object):
         parent_layer, sub_name = \
             utils.find_parent_layer_and_sub_name(model, cur_name)

+        reduce_type = None
+
         if isinstance(cur_layer, tuple(utils.fake_quant_output_layers)):
             cur_quant_layer = quant_layers.FakeQuantMAOutputScaleLayer(
-                cur_layer, self._moving_rate)
+                cur_layer, self._moving_rate, reduce_type=reduce_type)
         else:
             cur_quant_layer = quant_layers.MAOutputScaleLayer(
-                cur_layer, self._moving_rate)
+                cur_layer, self._moving_rate, reduce_type=reduce_type)

         setattr(parent_layer, sub_name, cur_quant_layer)
...
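Because quantizable_layer_type is a plain constructor argument, the new behaviour can also be narrowed per call rather than taken from the enlarged default. A hedged illustration (not taken from the commit) that quantizes only the two parallel-linear types:

from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

# Restrict QAT to the two layer types added by this commit; Conv2D, Linear
# and Conv2DTranspose are left untouched.
qat = ImperativeQuantAware(
    quantizable_layer_type=['ColumnParallelLinear', 'RowParallelLinear'],
    weight_quantize_type='abs_max',
    activation_quantize_type='moving_average_abs_max')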
python/paddle/fluid/contrib/slim/quantization/imperative/utils.py (+51, -22)

@@ -16,36 +16,63 @@ import math
 import numpy as np
 import paddle
+from paddle.distributed import fleet
 import paddle.nn.quant.quant_layers as quant_layers

 from ..utils import _get_op_input_var_names, _get_op_output_var_names, _get_output_name_index, _get_input_name_index

 layer_name_map = {
     'Conv2DTranspose': paddle.nn.Conv2DTranspose,
     'Conv2D': paddle.nn.Conv2D,
     'Linear': paddle.nn.Linear,
     'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
     'AdaptiveMaxPool2D': paddle.nn.AdaptiveMaxPool2D,
     'AvgPool2D': paddle.nn.AvgPool2D,
     'MaxPool2D': paddle.nn.MaxPool2D,
     'Hardswish': paddle.nn.Hardswish,
     'LeakyReLU': paddle.nn.LeakyReLU,
     'PReLU': paddle.nn.PReLU,
     'ReLU': paddle.nn.ReLU,
     'ReLU6': paddle.nn.ReLU6,
     'Sigmoid': paddle.nn.Sigmoid,
     'Softmax': paddle.nn.Softmax,
     'Swish': paddle.nn.Swish,
     'Tanh': paddle.nn.Tanh,
     'Hardswish': paddle.nn.Hardswish,
     'BatchNorm': paddle.nn.BatchNorm,
     'GroupNorm': paddle.nn.GroupNorm,
     'LayerNorm': paddle.nn.LayerNorm,
+    'ColumnParallelLinear':
+    fleet.meta_parallel.parallel_layers.mp_layers.ColumnParallelLinear,
+    'RowParallelLinear':
+    fleet.meta_parallel.parallel_layers.mp_layers.RowParallelLinear
 }

 # Apply fake quant for the inputs of these layers
 fake_quant_input_layers = [
-    paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose
+    paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose,
+    fleet.meta_parallel.RowParallelLinear,
+    fleet.meta_parallel.ColumnParallelLinear
 ]

 # Apply fake quant for the output of these layers
...
@@ -65,7 +92,9 @@ fake_quant_leaf_layers = [
 fake_quant_wrap_layers = [
-    quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear,
-    quant_layers.QuantizedConv2DTranspose
+    quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear,
+    quant_layers.QuantizedConv2DTranspose,
+    quant_layers.QuantizedColumnParallelLinear,
+    quant_layers.QuantizedRowParallelLinear
 ]

 # The weight format of these layers is Cin * Cout * H * W
...
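layer_name_map is the registry that resolves the string names accepted by quantizable_layer_type into layer classes, while fake_quant_input_layers and fake_quant_wrap_layers decide which layers get fake-quant inputs and which wrapper class replaces them. A small sanity-check sketch, assuming fleet.meta_parallel re-exports the same classes the map references:

import paddle.distributed.fleet as fleet
import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid.contrib.slim.quantization.imperative import utils

# The new string names resolve to the fleet tensor-parallel layer classes.
assert utils.layer_name_map['ColumnParallelLinear'] is \
    fleet.meta_parallel.ColumnParallelLinear
assert utils.layer_name_map['RowParallelLinear'] is \
    fleet.meta_parallel.RowParallelLinear

# Their inputs now receive fake quantization, and their quantized wrappers
# are registered for layer replacement.
assert fleet.meta_parallel.RowParallelLinear in utils.fake_quant_input_layers
assert quant_layers.QuantizedRowParallelLinear in utils.fake_quant_wrap_layers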
python/paddle/fluid/tests/unittests/CMakeLists.txt (+3, -0)

@@ -82,6 +82,7 @@ list(APPEND DIST_TEST_OPS test_collective_alltoall_single)
 list(APPEND DIST_TEST_OPS test_eager_dist_api)
 list(APPEND DIST_TEST_OPS test_collective_batch_isend_irecv)
 list(APPEND DIST_TEST_OPS test_collective_reduce_scatter)
+list(APPEND DIST_TEST_OPS test_parallel_dygraph_qat)
 set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
 #remove distribute unittests.
...
@@ -352,6 +353,7 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM))
   list(REMOVE_ITEM TEST_OPS test_eager_dist_api)
   list(REMOVE_ITEM TEST_OPS test_collective_batch_isend_irecv)
   list(REMOVE_ITEM TEST_OPS test_collective_reduce_scatter)
+  list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_qat)
 elseif(WITH_GPU)
   if(${CUDNN_VERSION} VERSION_LESS 7100)
...
@@ -1607,6 +1609,7 @@ if(WITH_DISTRIBUTE
   set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 100)
   set_tests_properties(test_collective_batch_isend_irecv PROPERTIES TIMEOUT 100)
   set_tests_properties(test_collective_reduce_scatter PROPERTIES TIMEOUT 100)
+  set_tests_properties(test_parallel_dygraph_qat PROPERTIES TIMEOUT 120)
   if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
     set_tests_properties(test_parallel_dygraph_sparse_embedding
                          PROPERTIES TIMEOUT 200)
...
python/paddle/fluid/tests/unittests/hybrid_parallel_qat.py (new file, 0 → 100644, +326)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import os
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
from paddle.io import DataLoader, Dataset
import unittest
import paddle.nn as nn
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc


def set_random_seed(seed, dp_id, rank_id):
    """Set random seed for reproducability."""
    random.seed(seed)
    np.random.seed(seed + dp_id)
    paddle.seed(seed + rank_id)


vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4


def get_attr(layer, name):
    if getattr(layer, name, None) is not None:
        return getattr(layer, name, None)
    else:
        return get_attr(layer._layer, name)


def get_gpus(selected_gpus):
    selected_gpus = [x.strip() for x in selected_gpus.split(',')]
    return selected_gpus


def get_cluster_from_args(selected_gpus):
    cluster_node_ips = '127.0.0.1'
    node_ip = '127.0.0.1'

    node_ips = [x.strip() for x in cluster_node_ips.split(',')]

    node_ips.index(node_ip)

    free_ports = None

    free_ports = find_free_ports(len(selected_gpus))
    if free_ports is not None:
        free_ports = list(free_ports)

    trainer_endpoints = []
    for ip in node_ips:
        trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports])
    return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)


def parallel_matmul(lm_output, logit_weights, parallel_output):
    hcg = fleet.get_hybrid_communicate_group()
    model_parallel_group = hcg.get_model_parallel_group()
    world_size = hcg.get_model_parallel_world_size()
    rank = hcg.get_model_parallel_rank()

    if world_size > 1:
        input_parallel = paddle.distributed.collective._c_identity(
            lm_output, group=model_parallel_group)

        logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True)

        if parallel_output:
            return logits

        return paddle.distributed.collective._c_concat(
            logits, group=model_parallel_group)
    else:
        logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
        return logits


class PACT(nn.Layer):

    def __init__(self, init_value=20):
        super(PACT, self).__init__()
        alpha_attr = paddle.ParamAttr(
            name=self.full_name() + ".pact",
            initializer=paddle.nn.initializer.Constant(value=init_value))

        self.alpha = self.create_parameter(shape=[1],
                                           attr=alpha_attr,
                                           dtype='float32')

    def forward(self, x):
        out_left = paddle.nn.functional.relu(x - self.alpha)
        out_right = paddle.nn.functional.relu(-self.alpha - x)
        x = x - out_left + out_right
        return x


class SimpleMPNet(nn.Layer):

    def __init__(self, vocab_size, hidden_size, inner_size, output_size,
                 np_fc1, np_fc2, mp_id):
        super(SimpleMPNet, self).__init__()

        if mp_id == 0:
            init_fc1_data = np_fc1[:, :(inner_size // 2)]
            init_fc2_data = np_fc2[:(inner_size // 2), :]
        else:
            init_fc1_data = np_fc1[:, (inner_size // 2):]
            init_fc2_data = np_fc2[(inner_size // 2):, :]

        self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
            hidden_size,
            inner_size,
            weight_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Assign(init_fc1_data)),
            gather_output=False,
            has_bias=True)

        self.linear2 = fleet.meta_parallel.RowParallelLinear(
            inner_size,
            hidden_size,
            weight_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Assign(init_fc2_data)),
            input_is_parallel=True,
            has_bias=True)

        self.linear3 = paddle.nn.Linear(
            hidden_size,
            output_size,
            weight_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)),
            bias_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)))

        self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
            vocab_size,
            hidden_size,
            weight_attr=paddle.nn.initializer.Constant(value=1.))

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = parallel_matmul(x, get_attr(self.embedding, "weight"), False)
        return x


class SimpleDPNet(nn.Layer):

    def __init__(self, vocab_size, hidden_size, inner_size, output_size,
                 np_fc1, np_fc2):
        super(SimpleDPNet, self).__init__()
        self.linear1 = paddle.nn.Linear(
            hidden_size,
            inner_size,
            weight_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Assign(np_fc1)),
            bias_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)))

        self.linear2 = paddle.nn.Linear(
            inner_size,
            hidden_size,
            weight_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Assign(np_fc2)),
            bias_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)))

        self.linear3 = paddle.nn.Linear(
            hidden_size,
            output_size,
            weight_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)),
            bias_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)))

        self.embedding = paddle.nn.Embedding(
            vocab_size,
            hidden_size,
            weight_attr=paddle.nn.initializer.Constant(value=1.))

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = paddle.matmul(x, get_attr(self.embedding, "weight"),
                          transpose_y=True)
        return x


class TestDistMPTraning(unittest.TestCase):

    def setUp(self):
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 2
        self.data_parallel_size = 1
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1
        }
        fleet.init(is_collective=True, strategy=strategy)
        self.onnx_format = False
        self.check_export_model_accuracy = True
        self.diff_threshold = 0.01
        self.fuse_conv_bn = False

    def train_batch(self, batch, model, optimizer, is_mp):
        output = model(batch)
        loss = output.mean()
        loss.backward()  # do backward
        optimizer.step()  # update parameters
        optimizer.clear_grad()
        return loss

    def build_optimizer(self, model):
        optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                         parameters=model.parameters())
        return optimizer

    def build_model_optimizer(self,
                              weight_quantize_type,
                              activation_quantize_type,
                              use_pact=False):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        mp_id = hcg.get_model_parallel_rank()
        dp_id = hcg.get_data_parallel_rank()
        rank_id = dist.get_rank()
        imperative_qat = ImperativeQuantAware(
            weight_quantize_type=weight_quantize_type,
            activation_quantize_type=activation_quantize_type,
            fuse_conv_bn=self.fuse_conv_bn,
            act_preprocess_layer=PACT if use_pact else None)
        set_random_seed(1024, dp_id, rank_id)

        np_fc1 = np.ones((hidden_size, inner_size))
        np_fc2 = np.ones(
            (inner_size, hidden_size))  #np.random.random_sample((inner_size, hidden_size))

        model_a = SimpleMPNet(vocab_size, hidden_size, inner_size, output_size,
                              np_fc1, np_fc2, mp_id)
        model_a = imperative_qat.quantize(model_a)
        optimizer_a = self.build_optimizer(model_a)
        model_a = fleet.distributed_model(model_a)
        optimizer_a = fleet.distributed_optimizer(optimizer_a)

        model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size,
                              np_fc1, np_fc2)
        model_b = imperative_qat.quantize(model_b)
        optimizer_b = self.build_optimizer(model_b)

        return model_a, optimizer_a, model_b, optimizer_b

    def train(self, model_a, optimizer_a, model_b, optimizer_b):
        for epoch in range(5):
            np_data = np.random.randint(0, vocab_size, (
                batch_size,
                seq_length,
            ))
            batch = paddle.to_tensor(np_data)
            loss_a = self.train_batch(batch, model_a, optimizer_a, True)
            loss_b = self.train_batch(batch, model_b, optimizer_b, False)

            np.testing.assert_allclose(loss_a.numpy(),
                                       loss_b.numpy(),
                                       rtol=1e-6)

    def test_mp_model_1(self):
        if not fluid.core.is_compiled_with_cuda(
        ) or fluid.core.get_cuda_device_count() == 0:
            return
        selected_gpus = get_gpus('0,1')
        cluster = None
        pod = None

        model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
            weight_quantize_type='abs_max',
            activation_quantize_type='moving_average_abs_max')
        self.train(model_a, optimizer_a, model_b, optimizer_b)

    def test_mp_model_2(self):
        if not fluid.core.is_compiled_with_cuda(
        ) or fluid.core.get_cuda_device_count() == 0:
            return
        selected_gpus = get_gpus('0,1')
        cluster = None
        pod = None

        model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
            weight_quantize_type='channel_wise_abs_max',
            activation_quantize_type='moving_average_abs_max',
            use_pact=True)
        self.train(model_a, optimizer_a, model_b, optimizer_b)


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_parallel_dygraph_qat.py (new file, 0 → 100644, +141)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import time
import paddle
import paddle.fluid as fluid
import copy
import os
import subprocess

from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc


def get_cluster_from_args(selected_gpus):
    cluster_node_ips = '127.0.0.1'
    node_ip = '127.0.0.1'

    node_ips = [x.strip() for x in cluster_node_ips.split(',')]

    node_ips.index(node_ip)

    free_ports = None

    free_ports = find_free_ports(len(selected_gpus))
    if free_ports is not None:
        free_ports = list(free_ports)

    trainer_endpoints = []
    for ip in node_ips:
        trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports])
    return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)


def get_gpus(selected_gpus):
    selected_gpus = [x.strip() for x in selected_gpus.split(',')]
    return selected_gpus


def start_local_trainers(cluster,
                         pod,
                         training_script,
                         training_script_args,
                         eager_mode=True,
                         log_dir=None):
    current_env = copy.copy(os.environ.copy())
    #paddle broadcast ncclUniqueId use socket, and
    #proxy maybe make trainers unreachable, so delete them.
    #if we set them to "", grpc will log error message "bad uri"
    #so just delete them.
    current_env.pop("http_proxy", None)
    current_env.pop("https_proxy", None)

    procs = []
    for t in pod.trainers:
        proc_env = {
            "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]),
            "PADDLE_TRAINER_ID": "%d" % t.rank,
            "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
            "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
            "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
        }

        if not eager_mode:
            proc_env["FLAGS_enable_eager_mode"] = "%d" % 0

        current_env.update(proc_env)

        print("trainer proc env:{}".format(current_env))

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            cmd = "python -m coverage run --branch -p " + training_script
        else:
            cmd = "python -u " + training_script

        print("start trainer proc:{} env:{}".format(cmd, proc_env))

        fn = None

        proc = subprocess.Popen(cmd.split(" "), env=current_env)

        tp = TrainerProc()
        tp.proc = proc
        tp.rank = t.rank
        tp.log_fn = fn
        tp.cmd = cmd

        procs.append(tp)

    return procs


class TestMultipleGpus(unittest.TestCase):

    def run_2gpu(self, target_file_name, eager_mode=True):
        if not fluid.core.is_compiled_with_cuda(
        ) or fluid.core.get_cuda_device_count() == 0:
            return

        selected_gpus = get_gpus('0,1')
        cluster = None
        pod = None

        cluster, pod = get_cluster_from_args(selected_gpus)

        procs = start_local_trainers(cluster,
                                     pod,
                                     eager_mode=eager_mode,
                                     training_script=target_file_name,
                                     training_script_args=[])

        while True:
            alive = watch_local_trainers(procs, cluster.trainers_endpoints())

            if not alive:
                print("Local procs complete, POD info:{}".format(pod))
                break
            time.sleep(3)


class TestDataParallelQAT(TestMultipleGpus):

    def test_multiple_gpus_qat(self):
        self.run_2gpu('hybrid_parallel_qat.py')


if __name__ == "__main__":
    unittest.main()
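TestMultipleGpus.run_2gpu is a generic two-GPU driver: it finds free ports, builds the per-trainer environment variables, spawns one Python subprocess per GPU, and polls until they exit. As a hedged illustration, another hybrid-parallel script could reuse it by subclassing (the script name below is a placeholder, not part of the commit):

import unittest

from test_parallel_dygraph_qat import TestMultipleGpus


class TestMyHybridScript(TestMultipleGpus):

    def test_my_script_on_two_gpus(self):
        # 'my_hybrid_script.py' is hypothetical; like hybrid_parallel_qat.py,
        # it must call fleet.init() itself once launched by the driver.
        self.run_2gpu('my_hybrid_script.py')


if __name__ == "__main__":
    unittest.main()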
python/paddle/nn/quant/quant_layers.py (+221, -12)

...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import paddle
 from paddle.framework import core
 from paddle.fluid import dygraph_utils
 from paddle.utils import unique_name
...
@@ -37,6 +38,8 @@ __all__ = [
     'MAOutputScaleLayer',
     'FakeQuantMAOutputScaleLayer',
     'QuantStub',
+    'QuantizedRowParallelLinear',
+    'QuantizedColumnParallelLinear',
 ]

 _logger = get_logger(__name__,
...
@@ -58,10 +61,12 @@ class FakeQuantAbsMax(Layer):
                  name=None,
                  quant_bits=8,
                  dtype='float32',
-                 quant_on_weight=False):
+                 quant_on_weight=False,
+                 reduce_type=None):
         super(FakeQuantAbsMax, self).__init__()
         self._quant_bits = quant_bits
         self._name = name
+        self._reduce_type = reduce_type
         scale_prefix = "{}.scale".format(
             name) if name else 'quant_dequant.scale'
         self._scale_name = unique_name.generate(scale_prefix)
...
@@ -86,6 +91,10 @@ class FakeQuantAbsMax(Layer):
                 dtype=input.dtype,
                 persistable=False)
             out_scale = self._scale
+            if self._reduce_type == "max":
+                paddle.distributed.all_reduce(
+                    out_scale, op=paddle.distributed.ReduceOp.MAX)
+
             if not out_scale:
                 out_scale = _varbase_creator(
                     type=core.VarDesc.VarType.LOD_TENSOR,
...
@@ -139,11 +148,12 @@ class FakeQuantMovingAverageAbsMax(Layer):
                  name=None,
                  moving_rate=0.9,
                  quant_bits=8,
-                 dtype='float32'):
+                 dtype='float32',
+                 reduce_type=None):
         super(FakeQuantMovingAverageAbsMax, self).__init__()
         self._moving_rate = moving_rate
         self._quant_bits = quant_bits
+        self._reduce_type = reduce_type
         scale_prefix = "{}.scale".format(
             name) if name else 'quant_dequant.scale'
         scale_attr = ParamAttr(name=unique_name.generate(scale_prefix),
...
@@ -184,12 +194,17 @@ class FakeQuantMovingAverageAbsMax(Layer):
                 shape=input.shape,
                 dtype=input.dtype,
                 persistable=False)
+            if self._reduce_type == "max":
+                paddle.distributed.all_reduce(
+                    self._scale, op=paddle.distributed.ReduceOp.MAX)
+
             state = self._state if self.training else None
             accum = self._accum if self.training else None

             out, _, _, _ = _C_ops.fake_quantize_dequantize_moving_average_abs_max(
                 input, self._scale, accum, state, quant_out, self._scale,
                 state, accum, *attrs)
             return out

         check_variable_and_dtype(input, 'input', ['float32'],
...
@@ -231,7 +246,8 @@ class FakeQuantChannelWiseAbsMax(Layer):
                  quant_bits=8,
                  quant_axis=0,
                  dtype='float32',
-                 quant_on_weight=False):
+                 quant_on_weight=False,
+                 reduce_type=None):
         assert quant_on_weight == True, "Channel_wise only can be used on weight quantization."
         super(FakeQuantChannelWiseAbsMax, self).__init__()
         self._quant_bits = quant_bits
...
@@ -239,6 +255,7 @@ class FakeQuantChannelWiseAbsMax(Layer):
         self._dtype = dtype
         self._name = name
         self._channel_num = channel_num
+        self._reduce_type = reduce_type
         scale_prefix = "{}.scale".format(
             name) if name else 'quant_dequant.scale'
         self._scale_name = unique_name.generate(scale_prefix)
...
@@ -265,6 +282,9 @@ class FakeQuantChannelWiseAbsMax(Layer):
                 persistable=False)
             out_scale = self._scale
+            if self._reduce_type == "max":
+                paddle.distributed.all_reduce(
+                    out_scale, op=paddle.distributed.ReduceOp.MAX)
             if out_scale is None:
                 out_scale = _varbase_creator(
                     type=core.VarDesc.VarType.LOD_TENSOR,
...
@@ -309,7 +329,11 @@ class FakeQuantChannelWiseAbsMax(Layer):
 class MovingAverageAbsMaxScale(Layer):

-    def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
+    def __init__(self,
+                 name=None,
+                 moving_rate=0.9,
+                 dtype='float32',
+                 reduce_type=None):
         r"""
         MovingAverageMaxScale layer is used to calculating the output quantization
         scale of Layer. Its computational formula is described as below:
...
@@ -319,7 +343,7 @@ class MovingAverageAbsMaxScale(Layer):
         """
         super(MovingAverageAbsMaxScale, self).__init__()
         self._moving_rate = moving_rate
+        self._reduce_type = reduce_type
         scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
         scale_name = unique_name.generate(scale_prefix)
         scale_attr = ParamAttr(name=scale_name,
...
@@ -352,13 +376,18 @@ class MovingAverageAbsMaxScale(Layer):
         if in_dynamic_mode():
             attrs = ('moving_rate', self._moving_rate, 'is_test',
                      not self.training)
-            state = self._state if self.training else None
-            accum = self._accum if self.training else None
             quant_out = _varbase_creator(type=input.type,
                                          name="{}.tmp".format(input.name),
                                          shape=input.shape,
                                          dtype=input.dtype,
                                          persistable=False)
+            if self._reduce_type == "max":
+                paddle.distributed.all_reduce(
+                    self._scale, op=paddle.distributed.ReduceOp.MAX)
+
+            state = self._state if self.training else None
+            accum = self._accum if self.training else None

             out, _, _, _ = _C_ops.moving_average_abs_max_scale(
                 input, accum, state, quant_out, self._scale, state, accum,
...
@@ -659,13 +688,190 @@ class QuantizedLinear(Layer):
         return out


+class QuantizedColumnParallelLinear(Layer):
+
+    def __init__(self,
+                 layer,
+                 weight_bits=8,
+                 activation_bits=8,
+                 moving_rate=0.9,
+                 weight_quantize_type='abs_max',
+                 activation_quantize_type='abs_max',
+                 weight_pre_layer=None,
+                 act_pre_layer=None,
+                 weight_quant_layer=None,
+                 act_quant_layer=None):
+        super(QuantizedColumnParallelLinear, self).__init__()
+        '''
+        '''
+        assert weight_quant_layer is None, "When quantizing ColumnParallelLinear, weight_quant_layer should be None."
+        assert act_quant_layer is None, "When quantizing ColumnParallelLinear, act_quant_layer should be None."
+
+        self.weight = getattr(layer, 'weight')
+        self.bias = getattr(layer, 'bias')
+        self.name = getattr(layer, '_name')
+        # For FakeQuant
+        self._linear_quant_axis = 1
+
+        self.is_mp = getattr(layer, 'is_mp')
+        self.model_parallel_group = getattr(layer, 'model_parallel_group')
+        self.gather_output = getattr(layer, 'gather_output')
+
+        self._fake_quant_weight = _get_fake_quant_type(
+            weight_quantize_type,
+            name=self.weight.name,
+            moving_rate=moving_rate,
+            quant_bits=weight_bits,
+            dtype=self._dtype,
+            quant_on_weight=True,
+            channel_num=self.weight.shape[self._linear_quant_axis],
+            quant_axis=self._linear_quant_axis,
+            reduce_type='max'
+            if paddle.distributed.get_world_size() > 1 else None)
+
+        self._fake_quant_input = _get_fake_quant_type(
+            activation_quantize_type,
+            name=layer.full_name(),
+            moving_rate=moving_rate,
+            quant_bits=activation_bits,
+            dtype=self._dtype,
+            quant_on_weight=False,
+            reduce_type=None)
+
+        self._act_preprocess = act_pre_layer(
+        ) if act_pre_layer is not None else None
+        self._weight_preprocess = weight_pre_layer(
+        ) if weight_pre_layer is not None else None
+
+    def forward(self, input):
+        if self.is_mp:
+            input_parallel = paddle.distributed.collective._c_identity(
+                input, group=self.model_parallel_group)
+        else:
+            input_parallel = input
+
+        if self._act_preprocess is not None:
+            input_parallel = self._act_preprocess(input_parallel)
+        quant_input = self._fake_quant_input(input_parallel)
+
+        weight = self.weight
+        if self._weight_preprocess is not None:
+            weight = self._weight_preprocess(self.weight)
+        quant_weight = self._fake_quant_weight(weight)
+
+        output_parallel = F.linear(x=quant_input,
+                                   weight=quant_weight,
+                                   bias=self.bias,
+                                   name=self.name)
+
+        if self.gather_output and self.is_mp:
+            output = paddle.distributed.collective._c_concat(
+                output_parallel, group=self.model_parallel_group)
+        else:
+            output = output_parallel
+        return output
+
+
+class QuantizedRowParallelLinear(Layer):
+
+    def __init__(self,
+                 layer,
+                 weight_bits=8,
+                 activation_bits=8,
+                 moving_rate=0.9,
+                 weight_quantize_type='abs_max',
+                 activation_quantize_type='abs_max',
+                 weight_pre_layer=None,
+                 act_pre_layer=None,
+                 weight_quant_layer=None,
+                 act_quant_layer=None):
+        super(QuantizedRowParallelLinear, self).__init__()
+        assert weight_quant_layer is None, "When quantizing RowParallelLinear, weight_quant_layer cannot defined by yourself."
+        assert act_quant_layer is None, "When quantizing RowParallelLinear, act_quant_layer cannot defined by yourself."
+
+        # For Linear
+        self.weight = getattr(layer, 'weight')
+        self.bias = getattr(layer, 'bias')
+        self.name = getattr(layer, '_name')
+        # For FakeQuant
+        self._linear_quant_axis = 1
+
+        self.input_is_parallel = getattr(layer, 'input_is_parallel')
+        self.is_mp = getattr(layer, 'is_mp')
+        self.model_parallel_group = getattr(layer, 'model_parallel_group')
+
+        self._fake_quant_weight = _get_fake_quant_type(
+            weight_quantize_type,
+            name=self.weight.name,
+            moving_rate=moving_rate,
+            quant_bits=weight_bits,
+            dtype=self._dtype,
+            quant_on_weight=True,
+            channel_num=self.weight.shape[self._linear_quant_axis],
+            quant_axis=self._linear_quant_axis,
+            reduce_type='max'
+            if paddle.distributed.get_world_size() > 1 else None)
+
+        self._fake_quant_input = _get_fake_quant_type(
+            activation_quantize_type,
+            name=layer.full_name(),
+            moving_rate=moving_rate,
+            quant_bits=activation_bits,
+            dtype=self._dtype,
+            quant_on_weight=False,
+            reduce_type='max'
+            if paddle.distributed.get_world_size() > 1 else None)
+
+        self._act_preprocess = act_pre_layer(
+        ) if act_pre_layer is not None else None
+        self._weight_preprocess = weight_pre_layer(
+        ) if weight_pre_layer is not None else None
+
+    def forward(self, input):
+        if self.input_is_parallel or (not self.is_mp):
+            input_parallel = input
+        else:
+            # split last dim
+            input_parallel = paddle.distributed.collective._c_split(
+                input, group=self.model_parallel_group)
+
+        if self._act_preprocess is not None:
+            input_parallel = self._act_preprocess(input_parallel)
+        quant_input = self._fake_quant_input(input_parallel)
+
+        weight = self.weight
+        if self._weight_preprocess is not None:
+            weight = self._weight_preprocess(self.weight)
+        quant_weight = self._fake_quant_weight(weight)
+
+        output_parallel = F.linear(x=quant_input,
+                                   weight=quant_weight,
+                                   name=self.name)
+
+        if self.is_mp:
+            output_ = paddle.distributed.collective._mp_allreduce(
+                output_parallel,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True)
+        else:
+            output_ = output_parallel
+        output = output_ + self.bias if self.bias is not None else output_
+        return output
+
+
 class MAOutputScaleLayer(Layer):
     """
     Add MovingAverageMaxScale layer to the behind of the input layer.
     Calculate the scale (moving average abs max) for the output of the input layer.
     """

-    def __init__(self, layer=None, moving_rate=0.9, name=None, dtype='float32'):
+    def __init__(self,
+                 layer=None,
+                 moving_rate=0.9,
+                 name=None,
+                 dtype='float32',
+                 reduce_type=None):
         r"""
         Construct
         """
...
@@ -674,7 +880,7 @@ class MAOutputScaleLayer(Layer):
         if name is None:
             name = layer.full_name()
         self._ma_output_scale = \
-            MovingAverageAbsMaxScale(name, moving_rate, dtype)
+            MovingAverageAbsMaxScale(name, moving_rate, dtype, reduce_type)

     def forward(self, *inputs, **kwargs):
         out = self._layer(*inputs, **kwargs)
...
@@ -697,6 +903,7 @@ class FakeQuantMAOutputScaleLayer(Layer):
                  activation_bits=8,
                  moving_rate=0.9,
                  name=None,
+                 reduce_type=None,
                  *args,
                  **kwargs):
...
@@ -708,7 +915,8 @@ class FakeQuantMAOutputScaleLayer(Layer):
             moving_rate=moving_rate,
             quant_bits=activation_bits,
             dtype=self._dtype,
-            quant_on_weight=False)
+            quant_on_weight=False,
+            reduce_type=reduce_type)

     def forward(self, *inputs, **kwargs):
         out = self._layer(*inputs, **kwargs)
...
@@ -723,7 +931,8 @@ def _get_fake_quant_type(quant_type, **kwargs):
     call_args = {
         "name": kwargs.get("name", None),
         "quant_bits": kwargs.get("quant_bits", 8),
-        "dtype": kwargs.get("dtype", "float32")
+        "dtype": kwargs.get("dtype", "float32"),
+        "reduce_type": kwargs.get("reduce_type", None)
     }

     if quant_type == 'abs_max':
...
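The two wrapper layers reproduce the forward path of ColumnParallelLinear / RowParallelLinear (identity or split on the way in, fake quantization of input and weight, F.linear, concat or model-parallel all-reduce on the way out), and they request reduce_type='max' for the weight fake-quant whenever the world size is greater than 1, so every model-parallel rank converges on the same scale. A hedged sketch of wrapping a layer by hand, mirroring what ImperativeQuantAware.quantize() does through fake_quant_wrap_layers; it assumes a multi-GPU launch where fleet.init has set up a model-parallel group of size 2:

import paddle.distributed.fleet as fleet
import paddle.nn.quant.quant_layers as quant_layers

# A model-parallel group must exist before the parallel linears can be built.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

col_linear = fleet.meta_parallel.ColumnParallelLinear(
    10, 8, gather_output=False, has_bias=True)

# Manual wrapping; normally ImperativeQuantAware.quantize() performs this
# replacement for every layer listed in quantizable_layer_type.
quant_col_linear = quant_layers.QuantizedColumnParallelLinear(
    col_linear,
    weight_bits=8,
    activation_bits=8,
    weight_quantize_type='abs_max',
    activation_quantize_type='moving_average_abs_max')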