Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
236ad4fc
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
236ad4fc
编写于
8月 12, 2022
作者:
C
Chang Xu
提交者:
GitHub
8月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Quant Row&Column ParallelLinear (#44869)
上级
4eec94dd
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
750 addition
and
37 deletion
+750
-37
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
.../paddle/fluid/contrib/slim/quantization/imperative/qat.py
+8
-3
python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
...addle/fluid/contrib/slim/quantization/imperative/utils.py
+51
-22
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/unittests/hybrid_parallel_qat.py
python/paddle/fluid/tests/unittests/hybrid_parallel_qat.py
+326
-0
python/paddle/fluid/tests/unittests/test_parallel_dygraph_qat.py
...paddle/fluid/tests/unittests/test_parallel_dygraph_qat.py
+141
-0
python/paddle/nn/quant/quant_layers.py
python/paddle/nn/quant/quant_layers.py
+221
-12
未找到文件。
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
浏览文件 @
236ad4fc
...
...
@@ -48,7 +48,10 @@ class ImperativeQuantAware(object):
"""
def
__init__
(
self
,
quantizable_layer_type
=
[
'Conv2D'
,
'Linear'
,
'Conv2DTranspose'
],
quantizable_layer_type
=
[
'Conv2D'
,
'Linear'
,
'Conv2DTranspose'
,
'ColumnParallelLinear'
,
'RowParallelLinear'
],
weight_quantize_type
=
'abs_max'
,
activation_quantize_type
=
'moving_average_abs_max'
,
weight_bits
=
8
,
...
...
@@ -431,12 +434,14 @@ class ImperativeQuantizeOutputs(object):
parent_layer
,
sub_name
=
\
utils
.
find_parent_layer_and_sub_name
(
model
,
cur_name
)
reduce_type
=
None
if
isinstance
(
cur_layer
,
tuple
(
utils
.
fake_quant_output_layers
)):
cur_quant_layer
=
quant_layers
.
FakeQuantMAOutputScaleLayer
(
cur_layer
,
self
.
_moving_rate
)
cur_layer
,
self
.
_moving_rate
,
reduce_type
=
reduce_type
)
else
:
cur_quant_layer
=
quant_layers
.
MAOutputScaleLayer
(
cur_layer
,
self
.
_moving_rate
)
cur_layer
,
self
.
_moving_rate
,
reduce_type
=
reduce_type
)
setattr
(
parent_layer
,
sub_name
,
cur_quant_layer
)
...
...
python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
浏览文件 @
236ad4fc
...
...
@@ -16,36 +16,63 @@ import math
import
numpy
as
np
import
paddle
from
paddle.distributed
import
fleet
import
paddle.nn.quant.quant_layers
as
quant_layers
from
..utils
import
_get_op_input_var_names
,
_get_op_output_var_names
,
_get_output_name_index
,
_get_input_name_index
layer_name_map
=
{
'Conv2DTranspose'
:
paddle
.
nn
.
Conv2DTranspose
,
'Conv2D'
:
paddle
.
nn
.
Conv2D
,
'Linear'
:
paddle
.
nn
.
Linear
,
'AdaptiveAvgPool2D'
:
paddle
.
nn
.
AdaptiveAvgPool2D
,
'AdaptiveMaxPool2D'
:
paddle
.
nn
.
AdaptiveMaxPool2D
,
'AvgPool2D'
:
paddle
.
nn
.
AvgPool2D
,
'MaxPool2D'
:
paddle
.
nn
.
MaxPool2D
,
'Hardswish'
:
paddle
.
nn
.
Hardswish
,
'LeakyReLU'
:
paddle
.
nn
.
LeakyReLU
,
'PReLU'
:
paddle
.
nn
.
PReLU
,
'ReLU'
:
paddle
.
nn
.
ReLU
,
'ReLU6'
:
paddle
.
nn
.
ReLU6
,
'Sigmoid'
:
paddle
.
nn
.
Sigmoid
,
'Softmax'
:
paddle
.
nn
.
Softmax
,
'Swish'
:
paddle
.
nn
.
Swish
,
'Tanh'
:
paddle
.
nn
.
Tanh
,
'Hardswish'
:
paddle
.
nn
.
Hardswish
,
'BatchNorm'
:
paddle
.
nn
.
BatchNorm
,
'GroupNorm'
:
paddle
.
nn
.
GroupNorm
,
'LayerNorm'
:
paddle
.
nn
.
LayerNorm
,
'Conv2DTranspose'
:
paddle
.
nn
.
Conv2DTranspose
,
'Conv2D'
:
paddle
.
nn
.
Conv2D
,
'Linear'
:
paddle
.
nn
.
Linear
,
'AdaptiveAvgPool2D'
:
paddle
.
nn
.
AdaptiveAvgPool2D
,
'AdaptiveMaxPool2D'
:
paddle
.
nn
.
AdaptiveMaxPool2D
,
'AvgPool2D'
:
paddle
.
nn
.
AvgPool2D
,
'MaxPool2D'
:
paddle
.
nn
.
MaxPool2D
,
'Hardswish'
:
paddle
.
nn
.
Hardswish
,
'LeakyReLU'
:
paddle
.
nn
.
LeakyReLU
,
'PReLU'
:
paddle
.
nn
.
PReLU
,
'ReLU'
:
paddle
.
nn
.
ReLU
,
'ReLU6'
:
paddle
.
nn
.
ReLU6
,
'Sigmoid'
:
paddle
.
nn
.
Sigmoid
,
'Softmax'
:
paddle
.
nn
.
Softmax
,
'Swish'
:
paddle
.
nn
.
Swish
,
'Tanh'
:
paddle
.
nn
.
Tanh
,
'Hardswish'
:
paddle
.
nn
.
Hardswish
,
'BatchNorm'
:
paddle
.
nn
.
BatchNorm
,
'GroupNorm'
:
paddle
.
nn
.
GroupNorm
,
'LayerNorm'
:
paddle
.
nn
.
LayerNorm
,
'ColumnParallelLinear'
:
fleet
.
meta_parallel
.
parallel_layers
.
mp_layers
.
ColumnParallelLinear
,
'RowParallelLinear'
:
fleet
.
meta_parallel
.
parallel_layers
.
mp_layers
.
RowParallelLinear
}
# Apply fake quant for the inputs of these layers
fake_quant_input_layers
=
[
paddle
.
nn
.
Conv2D
,
paddle
.
nn
.
Linear
,
paddle
.
nn
.
Conv2DTranspose
paddle
.
nn
.
Conv2D
,
paddle
.
nn
.
Linear
,
paddle
.
nn
.
Conv2DTranspose
,
fleet
.
meta_parallel
.
RowParallelLinear
,
fleet
.
meta_parallel
.
ColumnParallelLinear
]
# Apply fake quant for the output of these layers
...
...
@@ -65,7 +92,9 @@ fake_quant_leaf_layers = [
fake_quant_wrap_layers
=
[
quant_layers
.
QuantizedConv2D
,
quant_layers
.
QuantizedLinear
,
quant_layers
.
QuantizedConv2DTranspose
quant_layers
.
QuantizedConv2DTranspose
,
quant_layers
.
QuantizedColumnParallelLinear
,
quant_layers
.
QuantizedRowParallelLinear
]
# The weight format of these layers is Cin * Cout * H * W
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
236ad4fc
...
...
@@ -82,6 +82,7 @@ list(APPEND DIST_TEST_OPS test_collective_alltoall_single)
list
(
APPEND DIST_TEST_OPS test_eager_dist_api
)
list
(
APPEND DIST_TEST_OPS test_collective_batch_isend_irecv
)
list
(
APPEND DIST_TEST_OPS test_collective_reduce_scatter
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_qat
)
set
(
MIXED_DIST_TEST_OPS
${
DIST_TEST_OPS
}
)
#remove distribute unittests.
...
...
@@ -352,6 +353,7 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM))
list
(
REMOVE_ITEM TEST_OPS test_eager_dist_api
)
list
(
REMOVE_ITEM TEST_OPS test_collective_batch_isend_irecv
)
list
(
REMOVE_ITEM TEST_OPS test_collective_reduce_scatter
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_qat
)
elseif
(
WITH_GPU
)
if
(
${
CUDNN_VERSION
}
VERSION_LESS 7100
)
...
...
@@ -1607,6 +1609,7 @@ if(WITH_DISTRIBUTE
set_tests_properties
(
test_eager_dist_api PROPERTIES TIMEOUT 100
)
set_tests_properties
(
test_collective_batch_isend_irecv PROPERTIES TIMEOUT 100
)
set_tests_properties
(
test_collective_reduce_scatter PROPERTIES TIMEOUT 100
)
set_tests_properties
(
test_parallel_dygraph_qat PROPERTIES TIMEOUT 120
)
if
(
${
NCCL_VERSION
}
VERSION_GREATER_EQUAL 2212
)
set_tests_properties
(
test_parallel_dygraph_sparse_embedding
PROPERTIES TIMEOUT 200
)
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_qat.py
0 → 100644
浏览文件 @
236ad4fc
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
os
import
paddle
import
numpy
as
np
import
random
import
paddle.distributed
as
dist
import
paddle.fluid
as
fluid
import
paddle.distributed.fleet
as
fleet
from
paddle.io
import
DataLoader
,
Dataset
import
unittest
import
paddle.nn
as
nn
from
paddle.fluid.contrib.slim.quantization
import
ImperativeQuantAware
from
paddle.distributed.utils
import
find_free_ports
,
watch_local_trainers
,
get_cluster
,
TrainerProc
def
set_random_seed
(
seed
,
dp_id
,
rank_id
):
"""Set random seed for reproducability."""
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
+
dp_id
)
paddle
.
seed
(
seed
+
rank_id
)
vocab_size
=
20
hidden_size
=
10
inner_size
=
8
output_size
=
10
seq_length
=
2
batch_size
=
4
def
get_attr
(
layer
,
name
):
if
getattr
(
layer
,
name
,
None
)
is
not
None
:
return
getattr
(
layer
,
name
,
None
)
else
:
return
get_attr
(
layer
.
_layer
,
name
)
def
get_gpus
(
selected_gpus
):
selected_gpus
=
[
x
.
strip
()
for
x
in
selected_gpus
.
split
(
','
)]
return
selected_gpus
def
get_cluster_from_args
(
selected_gpus
):
cluster_node_ips
=
'127.0.0.1'
node_ip
=
'127.0.0.1'
node_ips
=
[
x
.
strip
()
for
x
in
cluster_node_ips
.
split
(
','
)]
node_ips
.
index
(
node_ip
)
free_ports
=
None
free_ports
=
find_free_ports
(
len
(
selected_gpus
))
if
free_ports
is
not
None
:
free_ports
=
list
(
free_ports
)
trainer_endpoints
=
[]
for
ip
in
node_ips
:
trainer_endpoints
.
append
([
"%s:%d"
%
(
ip
,
port
)
for
port
in
free_ports
])
return
get_cluster
(
node_ips
,
node_ip
,
trainer_endpoints
,
selected_gpus
)
def
parallel_matmul
(
lm_output
,
logit_weights
,
parallel_output
):
hcg
=
fleet
.
get_hybrid_communicate_group
()
model_parallel_group
=
hcg
.
get_model_parallel_group
()
world_size
=
hcg
.
get_model_parallel_world_size
()
rank
=
hcg
.
get_model_parallel_rank
()
if
world_size
>
1
:
input_parallel
=
paddle
.
distributed
.
collective
.
_c_identity
(
lm_output
,
group
=
model_parallel_group
)
logits
=
paddle
.
matmul
(
input_parallel
,
logit_weights
,
transpose_y
=
True
)
if
parallel_output
:
return
logits
return
paddle
.
distributed
.
collective
.
_c_concat
(
logits
,
group
=
model_parallel_group
)
else
:
logits
=
paddle
.
matmul
(
lm_output
,
logit_weights
,
transpose_y
=
True
)
return
logits
class
PACT
(
nn
.
Layer
):
def
__init__
(
self
,
init_value
=
20
):
super
(
PACT
,
self
).
__init__
()
alpha_attr
=
paddle
.
ParamAttr
(
name
=
self
.
full_name
()
+
".pact"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
init_value
))
self
.
alpha
=
self
.
create_parameter
(
shape
=
[
1
],
attr
=
alpha_attr
,
dtype
=
'float32'
)
def
forward
(
self
,
x
):
out_left
=
paddle
.
nn
.
functional
.
relu
(
x
-
self
.
alpha
)
out_right
=
paddle
.
nn
.
functional
.
relu
(
-
self
.
alpha
-
x
)
x
=
x
-
out_left
+
out_right
return
x
class
SimpleMPNet
(
nn
.
Layer
):
def
__init__
(
self
,
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
,
mp_id
):
super
(
SimpleMPNet
,
self
).
__init__
()
if
mp_id
==
0
:
init_fc1_data
=
np_fc1
[:,
:(
inner_size
//
2
)]
init_fc2_data
=
np_fc2
[:(
inner_size
//
2
),
:]
else
:
init_fc1_data
=
np_fc1
[:,
(
inner_size
//
2
):]
init_fc2_data
=
np_fc2
[(
inner_size
//
2
):,
:]
self
.
linear1
=
fleet
.
meta_parallel
.
ColumnParallelLinear
(
hidden_size
,
inner_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
init_fc1_data
)),
gather_output
=
False
,
has_bias
=
True
)
self
.
linear2
=
fleet
.
meta_parallel
.
RowParallelLinear
(
inner_size
,
hidden_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
init_fc2_data
)),
input_is_parallel
=
True
,
has_bias
=
True
)
self
.
linear3
=
paddle
.
nn
.
Linear
(
hidden_size
,
output_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)))
self
.
embedding
=
fleet
.
meta_parallel
.
VocabParallelEmbedding
(
vocab_size
,
hidden_size
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
1.
))
def
forward
(
self
,
x
):
x
=
self
.
embedding
(
x
)
x
=
self
.
linear1
(
x
)
x
=
self
.
linear2
(
x
)
x
=
self
.
linear3
(
x
)
x
=
parallel_matmul
(
x
,
get_attr
(
self
.
embedding
,
"weight"
),
False
)
return
x
class
SimpleDPNet
(
nn
.
Layer
):
def
__init__
(
self
,
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
):
super
(
SimpleDPNet
,
self
).
__init__
()
self
.
linear1
=
paddle
.
nn
.
Linear
(
hidden_size
,
inner_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
np_fc1
)),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)))
self
.
linear2
=
paddle
.
nn
.
Linear
(
inner_size
,
hidden_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
np_fc2
)),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)))
self
.
linear3
=
paddle
.
nn
.
Linear
(
hidden_size
,
output_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)))
self
.
embedding
=
paddle
.
nn
.
Embedding
(
vocab_size
,
hidden_size
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
1.
))
def
forward
(
self
,
x
):
x
=
self
.
embedding
(
x
)
x
=
self
.
linear1
(
x
)
x
=
self
.
linear2
(
x
)
x
=
self
.
linear3
(
x
)
x
=
paddle
.
matmul
(
x
,
get_attr
(
self
.
embedding
,
"weight"
),
transpose_y
=
True
)
return
x
class
TestDistMPTraning
(
unittest
.
TestCase
):
def
setUp
(
self
):
strategy
=
fleet
.
DistributedStrategy
()
self
.
model_parallel_size
=
2
self
.
data_parallel_size
=
1
strategy
.
hybrid_configs
=
{
"dp_degree"
:
self
.
data_parallel_size
,
"mp_degree"
:
self
.
model_parallel_size
,
"pp_degree"
:
1
}
fleet
.
init
(
is_collective
=
True
,
strategy
=
strategy
)
self
.
onnx_format
=
False
self
.
check_export_model_accuracy
=
True
self
.
diff_threshold
=
0.01
self
.
fuse_conv_bn
=
False
def
train_batch
(
self
,
batch
,
model
,
optimizer
,
is_mp
):
output
=
model
(
batch
)
loss
=
output
.
mean
()
loss
.
backward
()
# do backward
optimizer
.
step
()
# update parameters
optimizer
.
clear_grad
()
return
loss
def
build_optimizer
(
self
,
model
):
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.001
,
parameters
=
model
.
parameters
())
return
optimizer
def
build_model_optimizer
(
self
,
weight_quantize_type
,
activation_quantize_type
,
use_pact
=
False
):
hcg
=
fleet
.
get_hybrid_communicate_group
()
word_size
=
hcg
.
get_model_parallel_world_size
()
mp_id
=
hcg
.
get_model_parallel_rank
()
dp_id
=
hcg
.
get_data_parallel_rank
()
rank_id
=
dist
.
get_rank
()
imperative_qat
=
ImperativeQuantAware
(
weight_quantize_type
=
weight_quantize_type
,
activation_quantize_type
=
activation_quantize_type
,
fuse_conv_bn
=
self
.
fuse_conv_bn
,
act_preprocess_layer
=
PACT
if
use_pact
else
None
)
set_random_seed
(
1024
,
dp_id
,
rank_id
)
np_fc1
=
np
.
ones
((
hidden_size
,
inner_size
))
np_fc2
=
np
.
ones
(
(
inner_size
,
hidden_size
))
#np.random.random_sample((inner_size, hidden_size))
model_a
=
SimpleMPNet
(
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
,
mp_id
)
model_a
=
imperative_qat
.
quantize
(
model_a
)
optimizer_a
=
self
.
build_optimizer
(
model_a
)
model_a
=
fleet
.
distributed_model
(
model_a
)
optimizer_a
=
fleet
.
distributed_optimizer
(
optimizer_a
)
model_b
=
SimpleDPNet
(
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
)
model_b
=
imperative_qat
.
quantize
(
model_b
)
optimizer_b
=
self
.
build_optimizer
(
model_b
)
return
model_a
,
optimizer_a
,
model_b
,
optimizer_b
def
train
(
self
,
model_a
,
optimizer_a
,
model_b
,
optimizer_b
):
for
epoch
in
range
(
5
):
np_data
=
np
.
random
.
randint
(
0
,
vocab_size
,
(
batch_size
,
seq_length
,
))
batch
=
paddle
.
to_tensor
(
np_data
)
loss_a
=
self
.
train_batch
(
batch
,
model_a
,
optimizer_a
,
True
)
loss_b
=
self
.
train_batch
(
batch
,
model_b
,
optimizer_b
,
False
)
np
.
testing
.
assert_allclose
(
loss_a
.
numpy
(),
loss_b
.
numpy
(),
rtol
=
1e-6
)
def
test_mp_model_1
(
self
):
if
not
fluid
.
core
.
is_compiled_with_cuda
(
)
or
fluid
.
core
.
get_cuda_device_count
()
==
0
:
return
selected_gpus
=
get_gpus
(
'0,1'
)
cluster
=
None
pod
=
None
model_a
,
optimizer_a
,
model_b
,
optimizer_b
=
self
.
build_model_optimizer
(
weight_quantize_type
=
'abs_max'
,
activation_quantize_type
=
'moving_average_abs_max'
)
self
.
train
(
model_a
,
optimizer_a
,
model_b
,
optimizer_b
)
def
test_mp_model_2
(
self
):
if
not
fluid
.
core
.
is_compiled_with_cuda
(
)
or
fluid
.
core
.
get_cuda_device_count
()
==
0
:
return
selected_gpus
=
get_gpus
(
'0,1'
)
cluster
=
None
pod
=
None
model_a
,
optimizer_a
,
model_b
,
optimizer_b
=
self
.
build_model_optimizer
(
weight_quantize_type
=
'channel_wise_abs_max'
,
activation_quantize_type
=
'moving_average_abs_max'
,
use_pact
=
True
)
self
.
train
(
model_a
,
optimizer_a
,
model_b
,
optimizer_b
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_parallel_dygraph_qat.py
0 → 100644
浏览文件 @
236ad4fc
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
time
import
paddle
import
paddle.fluid
as
fluid
import
copy
import
os
import
subprocess
from
paddle.distributed.utils
import
find_free_ports
,
watch_local_trainers
,
get_cluster
,
TrainerProc
def
get_cluster_from_args
(
selected_gpus
):
cluster_node_ips
=
'127.0.0.1'
node_ip
=
'127.0.0.1'
node_ips
=
[
x
.
strip
()
for
x
in
cluster_node_ips
.
split
(
','
)]
node_ips
.
index
(
node_ip
)
free_ports
=
None
free_ports
=
find_free_ports
(
len
(
selected_gpus
))
if
free_ports
is
not
None
:
free_ports
=
list
(
free_ports
)
trainer_endpoints
=
[]
for
ip
in
node_ips
:
trainer_endpoints
.
append
([
"%s:%d"
%
(
ip
,
port
)
for
port
in
free_ports
])
return
get_cluster
(
node_ips
,
node_ip
,
trainer_endpoints
,
selected_gpus
)
def
get_gpus
(
selected_gpus
):
selected_gpus
=
[
x
.
strip
()
for
x
in
selected_gpus
.
split
(
','
)]
return
selected_gpus
def
start_local_trainers
(
cluster
,
pod
,
training_script
,
training_script_args
,
eager_mode
=
True
,
log_dir
=
None
):
current_env
=
copy
.
copy
(
os
.
environ
.
copy
())
#paddle broadcast ncclUniqueId use socket, and
#proxy maybe make trainers unreachable, so delete them.
#if we set them to "", grpc will log error message "bad uri"
#so just delete them.
current_env
.
pop
(
"http_proxy"
,
None
)
current_env
.
pop
(
"https_proxy"
,
None
)
procs
=
[]
for
t
in
pod
.
trainers
:
proc_env
=
{
"FLAGS_selected_gpus"
:
"%s"
%
","
.
join
([
str
(
g
)
for
g
in
t
.
gpus
]),
"PADDLE_TRAINER_ID"
:
"%d"
%
t
.
rank
,
"PADDLE_CURRENT_ENDPOINT"
:
"%s"
%
t
.
endpoint
,
"PADDLE_TRAINERS_NUM"
:
"%d"
%
cluster
.
trainers_nranks
(),
"PADDLE_TRAINER_ENDPOINTS"
:
","
.
join
(
cluster
.
trainers_endpoints
())
}
if
not
eager_mode
:
proc_env
[
"FLAGS_enable_eager_mode"
]
=
"%d"
%
0
current_env
.
update
(
proc_env
)
print
(
"trainer proc env:{}"
.
format
(
current_env
))
if
os
.
getenv
(
'WITH_COVERAGE'
,
'OFF'
)
==
'ON'
:
cmd
=
"python -m coverage run --branch -p "
+
training_script
else
:
cmd
=
"python -u "
+
training_script
print
(
"start trainer proc:{} env:{}"
.
format
(
cmd
,
proc_env
))
fn
=
None
proc
=
subprocess
.
Popen
(
cmd
.
split
(
" "
),
env
=
current_env
)
tp
=
TrainerProc
()
tp
.
proc
=
proc
tp
.
rank
=
t
.
rank
tp
.
log_fn
=
fn
tp
.
cmd
=
cmd
procs
.
append
(
tp
)
return
procs
class
TestMultipleGpus
(
unittest
.
TestCase
):
def
run_2gpu
(
self
,
target_file_name
,
eager_mode
=
True
):
if
not
fluid
.
core
.
is_compiled_with_cuda
(
)
or
fluid
.
core
.
get_cuda_device_count
()
==
0
:
return
selected_gpus
=
get_gpus
(
'0,1'
)
cluster
=
None
pod
=
None
cluster
,
pod
=
get_cluster_from_args
(
selected_gpus
)
procs
=
start_local_trainers
(
cluster
,
pod
,
eager_mode
=
eager_mode
,
training_script
=
target_file_name
,
training_script_args
=
[])
while
True
:
alive
=
watch_local_trainers
(
procs
,
cluster
.
trainers_endpoints
())
if
not
alive
:
print
(
"Local procs complete, POD info:{}"
.
format
(
pod
))
break
time
.
sleep
(
3
)
class
TestDataParallelQAT
(
TestMultipleGpus
):
def
test_multiple_gpus_qat
(
self
):
self
.
run_2gpu
(
'hybrid_parallel_qat.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/nn/quant/quant_layers.py
浏览文件 @
236ad4fc
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
from
paddle.framework
import
core
from
paddle.fluid
import
dygraph_utils
from
paddle.utils
import
unique_name
...
...
@@ -37,6 +38,8 @@ __all__ = [
'MAOutputScaleLayer'
,
'FakeQuantMAOutputScaleLayer'
,
'QuantStub'
,
'QuantizedRowParallelLinear'
,
'QuantizedColumnParallelLinear'
,
]
_logger
=
get_logger
(
__name__
,
...
...
@@ -58,10 +61,12 @@ class FakeQuantAbsMax(Layer):
name
=
None
,
quant_bits
=
8
,
dtype
=
'float32'
,
quant_on_weight
=
False
):
quant_on_weight
=
False
,
reduce_type
=
None
):
super
(
FakeQuantAbsMax
,
self
).
__init__
()
self
.
_quant_bits
=
quant_bits
self
.
_name
=
name
self
.
_reduce_type
=
reduce_type
scale_prefix
=
"{}.scale"
.
format
(
name
)
if
name
else
'quant_dequant.scale'
self
.
_scale_name
=
unique_name
.
generate
(
scale_prefix
)
...
...
@@ -86,6 +91,10 @@ class FakeQuantAbsMax(Layer):
dtype
=
input
.
dtype
,
persistable
=
False
)
out_scale
=
self
.
_scale
if
self
.
_reduce_type
==
"max"
:
paddle
.
distributed
.
all_reduce
(
out_scale
,
op
=
paddle
.
distributed
.
ReduceOp
.
MAX
)
if
not
out_scale
:
out_scale
=
_varbase_creator
(
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
...
...
@@ -139,11 +148,12 @@ class FakeQuantMovingAverageAbsMax(Layer):
name
=
None
,
moving_rate
=
0.9
,
quant_bits
=
8
,
dtype
=
'float32'
):
dtype
=
'float32'
,
reduce_type
=
None
):
super
(
FakeQuantMovingAverageAbsMax
,
self
).
__init__
()
self
.
_moving_rate
=
moving_rate
self
.
_quant_bits
=
quant_bits
self
.
_reduce_type
=
reduce_type
scale_prefix
=
"{}.scale"
.
format
(
name
)
if
name
else
'quant_dequant.scale'
scale_attr
=
ParamAttr
(
name
=
unique_name
.
generate
(
scale_prefix
),
...
...
@@ -184,12 +194,17 @@ class FakeQuantMovingAverageAbsMax(Layer):
shape
=
input
.
shape
,
dtype
=
input
.
dtype
,
persistable
=
False
)
if
self
.
_reduce_type
==
"max"
:
paddle
.
distributed
.
all_reduce
(
self
.
_scale
,
op
=
paddle
.
distributed
.
ReduceOp
.
MAX
)
state
=
self
.
_state
if
self
.
training
else
None
accum
=
self
.
_accum
if
self
.
training
else
None
out
,
_
,
_
,
_
=
_C_ops
.
fake_quantize_dequantize_moving_average_abs_max
(
input
,
self
.
_scale
,
accum
,
state
,
quant_out
,
self
.
_scale
,
state
,
accum
,
*
attrs
)
return
out
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
],
...
...
@@ -231,7 +246,8 @@ class FakeQuantChannelWiseAbsMax(Layer):
quant_bits
=
8
,
quant_axis
=
0
,
dtype
=
'float32'
,
quant_on_weight
=
False
):
quant_on_weight
=
False
,
reduce_type
=
None
):
assert
quant_on_weight
==
True
,
"Channel_wise only can be used on weight quantization."
super
(
FakeQuantChannelWiseAbsMax
,
self
).
__init__
()
self
.
_quant_bits
=
quant_bits
...
...
@@ -239,6 +255,7 @@ class FakeQuantChannelWiseAbsMax(Layer):
self
.
_dtype
=
dtype
self
.
_name
=
name
self
.
_channel_num
=
channel_num
self
.
_reduce_type
=
reduce_type
scale_prefix
=
"{}.scale"
.
format
(
name
)
if
name
else
'quant_dequant.scale'
self
.
_scale_name
=
unique_name
.
generate
(
scale_prefix
)
...
...
@@ -265,6 +282,9 @@ class FakeQuantChannelWiseAbsMax(Layer):
persistable
=
False
)
out_scale
=
self
.
_scale
if
self
.
_reduce_type
==
"max"
:
paddle
.
distributed
.
all_reduce
(
out_scale
,
op
=
paddle
.
distributed
.
ReduceOp
.
MAX
)
if
out_scale
is
None
:
out_scale
=
_varbase_creator
(
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
...
...
@@ -309,7 +329,11 @@ class FakeQuantChannelWiseAbsMax(Layer):
class
MovingAverageAbsMaxScale
(
Layer
):
def
__init__
(
self
,
name
=
None
,
moving_rate
=
0.9
,
dtype
=
'float32'
):
def
__init__
(
self
,
name
=
None
,
moving_rate
=
0.9
,
dtype
=
'float32'
,
reduce_type
=
None
):
r
"""
MovingAverageMaxScale layer is used to calculating the output quantization
scale of Layer. Its computational formula is described as below:
...
...
@@ -319,7 +343,7 @@ class MovingAverageAbsMaxScale(Layer):
"""
super
(
MovingAverageAbsMaxScale
,
self
).
__init__
()
self
.
_moving_rate
=
moving_rate
self
.
_reduce_type
=
reduce_type
scale_prefix
=
'{}.scale'
.
format
(
name
)
if
name
else
'outscale.scale'
scale_name
=
unique_name
.
generate
(
scale_prefix
)
scale_attr
=
ParamAttr
(
name
=
scale_name
,
...
...
@@ -352,13 +376,18 @@ class MovingAverageAbsMaxScale(Layer):
if
in_dynamic_mode
():
attrs
=
(
'moving_rate'
,
self
.
_moving_rate
,
'is_test'
,
not
self
.
training
)
state
=
self
.
_state
if
self
.
training
else
None
accum
=
self
.
_accum
if
self
.
training
else
None
quant_out
=
_varbase_creator
(
type
=
input
.
type
,
name
=
"{}.tmp"
.
format
(
input
.
name
),
shape
=
input
.
shape
,
dtype
=
input
.
dtype
,
persistable
=
False
)
if
self
.
_reduce_type
==
"max"
:
paddle
.
distributed
.
all_reduce
(
self
.
_scale
,
op
=
paddle
.
distributed
.
ReduceOp
.
MAX
)
state
=
self
.
_state
if
self
.
training
else
None
accum
=
self
.
_accum
if
self
.
training
else
None
out
,
_
,
_
,
_
=
_C_ops
.
moving_average_abs_max_scale
(
input
,
accum
,
state
,
quant_out
,
self
.
_scale
,
state
,
accum
,
...
...
@@ -659,13 +688,190 @@ class QuantizedLinear(Layer):
return
out
class
QuantizedColumnParallelLinear
(
Layer
):
def
__init__
(
self
,
layer
,
weight_bits
=
8
,
activation_bits
=
8
,
moving_rate
=
0.9
,
weight_quantize_type
=
'abs_max'
,
activation_quantize_type
=
'abs_max'
,
weight_pre_layer
=
None
,
act_pre_layer
=
None
,
weight_quant_layer
=
None
,
act_quant_layer
=
None
):
super
(
QuantizedColumnParallelLinear
,
self
).
__init__
()
'''
'''
assert
weight_quant_layer
is
None
,
"When quantizing ColumnParallelLinear, weight_quant_layer should be None."
assert
act_quant_layer
is
None
,
"When quantizing ColumnParallelLinear, act_quant_layer should be None."
self
.
weight
=
getattr
(
layer
,
'weight'
)
self
.
bias
=
getattr
(
layer
,
'bias'
)
self
.
name
=
getattr
(
layer
,
'_name'
)
# For FakeQuant
self
.
_linear_quant_axis
=
1
self
.
is_mp
=
getattr
(
layer
,
'is_mp'
)
self
.
model_parallel_group
=
getattr
(
layer
,
'model_parallel_group'
)
self
.
gather_output
=
getattr
(
layer
,
'gather_output'
)
self
.
_fake_quant_weight
=
_get_fake_quant_type
(
weight_quantize_type
,
name
=
self
.
weight
.
name
,
moving_rate
=
moving_rate
,
quant_bits
=
weight_bits
,
dtype
=
self
.
_dtype
,
quant_on_weight
=
True
,
channel_num
=
self
.
weight
.
shape
[
self
.
_linear_quant_axis
],
quant_axis
=
self
.
_linear_quant_axis
,
reduce_type
=
'max'
if
paddle
.
distributed
.
get_world_size
()
>
1
else
None
)
self
.
_fake_quant_input
=
_get_fake_quant_type
(
activation_quantize_type
,
name
=
layer
.
full_name
(),
moving_rate
=
moving_rate
,
quant_bits
=
activation_bits
,
dtype
=
self
.
_dtype
,
quant_on_weight
=
False
,
reduce_type
=
None
)
self
.
_act_preprocess
=
act_pre_layer
(
)
if
act_pre_layer
is
not
None
else
None
self
.
_weight_preprocess
=
weight_pre_layer
(
)
if
weight_pre_layer
is
not
None
else
None
def
forward
(
self
,
input
):
if
self
.
is_mp
:
input_parallel
=
paddle
.
distributed
.
collective
.
_c_identity
(
input
,
group
=
self
.
model_parallel_group
)
else
:
input_parallel
=
input
if
self
.
_act_preprocess
is
not
None
:
input_parallel
=
self
.
_act_preprocess
(
input_parallel
)
quant_input
=
self
.
_fake_quant_input
(
input_parallel
)
weight
=
self
.
weight
if
self
.
_weight_preprocess
is
not
None
:
weight
=
self
.
_weight_preprocess
(
self
.
weight
)
quant_weight
=
self
.
_fake_quant_weight
(
weight
)
output_parallel
=
F
.
linear
(
x
=
quant_input
,
weight
=
quant_weight
,
bias
=
self
.
bias
,
name
=
self
.
name
)
if
self
.
gather_output
and
self
.
is_mp
:
output
=
paddle
.
distributed
.
collective
.
_c_concat
(
output_parallel
,
group
=
self
.
model_parallel_group
)
else
:
output
=
output_parallel
return
output
class
QuantizedRowParallelLinear
(
Layer
):
def
__init__
(
self
,
layer
,
weight_bits
=
8
,
activation_bits
=
8
,
moving_rate
=
0.9
,
weight_quantize_type
=
'abs_max'
,
activation_quantize_type
=
'abs_max'
,
weight_pre_layer
=
None
,
act_pre_layer
=
None
,
weight_quant_layer
=
None
,
act_quant_layer
=
None
):
super
(
QuantizedRowParallelLinear
,
self
).
__init__
()
assert
weight_quant_layer
is
None
,
"When quantizing RowParallelLinear, weight_quant_layer cannot defined by yourself."
assert
act_quant_layer
is
None
,
"When quantizing RowParallelLinear, act_quant_layer cannot defined by yourself."
# For Linear
self
.
weight
=
getattr
(
layer
,
'weight'
)
self
.
bias
=
getattr
(
layer
,
'bias'
)
self
.
name
=
getattr
(
layer
,
'_name'
)
# For FakeQuant
self
.
_linear_quant_axis
=
1
self
.
input_is_parallel
=
getattr
(
layer
,
'input_is_parallel'
)
self
.
is_mp
=
getattr
(
layer
,
'is_mp'
)
self
.
model_parallel_group
=
getattr
(
layer
,
'model_parallel_group'
)
self
.
_fake_quant_weight
=
_get_fake_quant_type
(
weight_quantize_type
,
name
=
self
.
weight
.
name
,
moving_rate
=
moving_rate
,
quant_bits
=
weight_bits
,
dtype
=
self
.
_dtype
,
quant_on_weight
=
True
,
channel_num
=
self
.
weight
.
shape
[
self
.
_linear_quant_axis
],
quant_axis
=
self
.
_linear_quant_axis
,
reduce_type
=
'max'
if
paddle
.
distributed
.
get_world_size
()
>
1
else
None
)
self
.
_fake_quant_input
=
_get_fake_quant_type
(
activation_quantize_type
,
name
=
layer
.
full_name
(),
moving_rate
=
moving_rate
,
quant_bits
=
activation_bits
,
dtype
=
self
.
_dtype
,
quant_on_weight
=
False
,
reduce_type
=
'max'
if
paddle
.
distributed
.
get_world_size
()
>
1
else
None
)
self
.
_act_preprocess
=
act_pre_layer
(
)
if
act_pre_layer
is
not
None
else
None
self
.
_weight_preprocess
=
weight_pre_layer
(
)
if
weight_pre_layer
is
not
None
else
None
def
forward
(
self
,
input
):
if
self
.
input_is_parallel
or
(
not
self
.
is_mp
):
input_parallel
=
input
else
:
# split last dim
input_parallel
=
paddle
.
distributed
.
collective
.
_c_split
(
input
,
group
=
self
.
model_parallel_group
)
if
self
.
_act_preprocess
is
not
None
:
input_parallel
=
self
.
_act_preprocess
(
input_parallel
)
quant_input
=
self
.
_fake_quant_input
(
input_parallel
)
weight
=
self
.
weight
if
self
.
_weight_preprocess
is
not
None
:
weight
=
self
.
_weight_preprocess
(
self
.
weight
)
quant_weight
=
self
.
_fake_quant_weight
(
weight
)
output_parallel
=
F
.
linear
(
x
=
quant_input
,
weight
=
quant_weight
,
name
=
self
.
name
)
if
self
.
is_mp
:
output_
=
paddle
.
distributed
.
collective
.
_mp_allreduce
(
output_parallel
,
group
=
self
.
model_parallel_group
,
use_calc_stream
=
True
,
use_model_parallel
=
True
)
else
:
output_
=
output_parallel
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
return
output
class
MAOutputScaleLayer
(
Layer
):
"""
Add MovingAverageMaxScale layer to the behind of the input layer.
Calculate the scale (moving average abs max) for the output of the input layer.
"""
def
__init__
(
self
,
layer
=
None
,
moving_rate
=
0.9
,
name
=
None
,
dtype
=
'float32'
):
def
__init__
(
self
,
layer
=
None
,
moving_rate
=
0.9
,
name
=
None
,
dtype
=
'float32'
,
reduce_type
=
None
):
r
"""
Construct
"""
...
...
@@ -674,7 +880,7 @@ class MAOutputScaleLayer(Layer):
if
name
is
None
:
name
=
layer
.
full_name
()
self
.
_ma_output_scale
=
\
MovingAverageAbsMaxScale
(
name
,
moving_rate
,
dtype
)
MovingAverageAbsMaxScale
(
name
,
moving_rate
,
dtype
,
reduce_type
)
def
forward
(
self
,
*
inputs
,
**
kwargs
):
out
=
self
.
_layer
(
*
inputs
,
**
kwargs
)
...
...
@@ -697,6 +903,7 @@ class FakeQuantMAOutputScaleLayer(Layer):
activation_bits
=
8
,
moving_rate
=
0.9
,
name
=
None
,
reduce_type
=
None
,
*
args
,
**
kwargs
):
...
...
@@ -708,7 +915,8 @@ class FakeQuantMAOutputScaleLayer(Layer):
moving_rate
=
moving_rate
,
quant_bits
=
activation_bits
,
dtype
=
self
.
_dtype
,
quant_on_weight
=
False
)
quant_on_weight
=
False
,
reduce_type
=
reduce_type
)
def
forward
(
self
,
*
inputs
,
**
kwargs
):
out
=
self
.
_layer
(
*
inputs
,
**
kwargs
)
...
...
@@ -723,7 +931,8 @@ def _get_fake_quant_type(quant_type, **kwargs):
call_args
=
{
"name"
:
kwargs
.
get
(
"name"
,
None
),
"quant_bits"
:
kwargs
.
get
(
"quant_bits"
,
8
),
"dtype"
:
kwargs
.
get
(
"dtype"
,
"float32"
)
"dtype"
:
kwargs
.
get
(
"dtype"
,
"float32"
),
"reduce_type"
:
kwargs
.
get
(
"reduce_type"
,
None
)
}
if
quant_type
==
'abs_max'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录