diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 29f4707124c6c315c35d5bc0b2fa49b70b8b2372..27b634e2ddbdd3bcaaa2a01fb0dc312c8fee18e0 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -48,7 +48,10 @@ class ImperativeQuantAware(object): """ def __init__(self, - quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'], + quantizable_layer_type=[ + 'Conv2D', 'Linear', 'Conv2DTranspose', + 'ColumnParallelLinear', 'RowParallelLinear' + ], weight_quantize_type='abs_max', activation_quantize_type='moving_average_abs_max', weight_bits=8, @@ -431,12 +434,14 @@ class ImperativeQuantizeOutputs(object): parent_layer, sub_name = \ utils.find_parent_layer_and_sub_name(model, cur_name) + reduce_type = None + if isinstance(cur_layer, tuple(utils.fake_quant_output_layers)): cur_quant_layer = quant_layers.FakeQuantMAOutputScaleLayer( - cur_layer, self._moving_rate) + cur_layer, self._moving_rate, reduce_type=reduce_type) else: cur_quant_layer = quant_layers.MAOutputScaleLayer( - cur_layer, self._moving_rate) + cur_layer, self._moving_rate, reduce_type=reduce_type) setattr(parent_layer, sub_name, cur_quant_layer) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index 1ac6eec80d94fd1d09afa641fe8757a2d6f23f0d..fafd8d70c800f8968a27e7c7e08411e8fb1129e5 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -16,36 +16,63 @@ import math import numpy as np import paddle +from paddle.distributed import fleet import paddle.nn.quant.quant_layers as quant_layers from ..utils import _get_op_input_var_names, _get_op_output_var_names, _get_output_name_index, _get_input_name_index layer_name_map = { - 'Conv2DTranspose': paddle.nn.Conv2DTranspose, - 'Conv2D': paddle.nn.Conv2D, - 'Linear': paddle.nn.Linear, - 'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D, - 'AdaptiveMaxPool2D': paddle.nn.AdaptiveMaxPool2D, - 'AvgPool2D': paddle.nn.AvgPool2D, - 'MaxPool2D': paddle.nn.MaxPool2D, - 'Hardswish': paddle.nn.Hardswish, - 'LeakyReLU': paddle.nn.LeakyReLU, - 'PReLU': paddle.nn.PReLU, - 'ReLU': paddle.nn.ReLU, - 'ReLU6': paddle.nn.ReLU6, - 'Sigmoid': paddle.nn.Sigmoid, - 'Softmax': paddle.nn.Softmax, - 'Swish': paddle.nn.Swish, - 'Tanh': paddle.nn.Tanh, - 'Hardswish': paddle.nn.Hardswish, - 'BatchNorm': paddle.nn.BatchNorm, - 'GroupNorm': paddle.nn.GroupNorm, - 'LayerNorm': paddle.nn.LayerNorm, + 'Conv2DTranspose': + paddle.nn.Conv2DTranspose, + 'Conv2D': + paddle.nn.Conv2D, + 'Linear': + paddle.nn.Linear, + 'AdaptiveAvgPool2D': + paddle.nn.AdaptiveAvgPool2D, + 'AdaptiveMaxPool2D': + paddle.nn.AdaptiveMaxPool2D, + 'AvgPool2D': + paddle.nn.AvgPool2D, + 'MaxPool2D': + paddle.nn.MaxPool2D, + 'Hardswish': + paddle.nn.Hardswish, + 'LeakyReLU': + paddle.nn.LeakyReLU, + 'PReLU': + paddle.nn.PReLU, + 'ReLU': + paddle.nn.ReLU, + 'ReLU6': + paddle.nn.ReLU6, + 'Sigmoid': + paddle.nn.Sigmoid, + 'Softmax': + paddle.nn.Softmax, + 'Swish': + paddle.nn.Swish, + 'Tanh': + paddle.nn.Tanh, + 'Hardswish': + paddle.nn.Hardswish, + 'BatchNorm': + paddle.nn.BatchNorm, + 'GroupNorm': + paddle.nn.GroupNorm, + 'LayerNorm': + paddle.nn.LayerNorm, + 'ColumnParallelLinear': + fleet.meta_parallel.parallel_layers.mp_layers.ColumnParallelLinear, + 'RowParallelLinear': + fleet.meta_parallel.parallel_layers.mp_layers.RowParallelLinear } # Apply fake quant for the inputs of these layers fake_quant_input_layers = [ - paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose + paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose, + fleet.meta_parallel.RowParallelLinear, + fleet.meta_parallel.ColumnParallelLinear ] # Apply fake quant for the output of these layers @@ -65,7 +92,9 @@ fake_quant_leaf_layers = [ fake_quant_wrap_layers = [ quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear, - quant_layers.QuantizedConv2DTranspose + quant_layers.QuantizedConv2DTranspose, + quant_layers.QuantizedColumnParallelLinear, + quant_layers.QuantizedRowParallelLinear ] # The weight format of these layers is Cin * Cout * H * W diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 9f899552e698732eaccbbf4c7e3474bbda29152e..be92f4eed5e99ea2318d764008d73c8ba12e5535 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -82,6 +82,7 @@ list(APPEND DIST_TEST_OPS test_collective_alltoall_single) list(APPEND DIST_TEST_OPS test_eager_dist_api) list(APPEND DIST_TEST_OPS test_collective_batch_isend_irecv) list(APPEND DIST_TEST_OPS test_collective_reduce_scatter) +list(APPEND DIST_TEST_OPS test_parallel_dygraph_qat) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) #remove distribute unittests. @@ -352,6 +353,7 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM)) list(REMOVE_ITEM TEST_OPS test_eager_dist_api) list(REMOVE_ITEM TEST_OPS test_collective_batch_isend_irecv) list(REMOVE_ITEM TEST_OPS test_collective_reduce_scatter) + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_qat) elseif(WITH_GPU) if(${CUDNN_VERSION} VERSION_LESS 7100) @@ -1607,6 +1609,7 @@ if(WITH_DISTRIBUTE set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 100) set_tests_properties(test_collective_batch_isend_irecv PROPERTIES TIMEOUT 100) set_tests_properties(test_collective_reduce_scatter PROPERTIES TIMEOUT 100) + set_tests_properties(test_parallel_dygraph_qat PROPERTIES TIMEOUT 120) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 200) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_qat.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_qat.py new file mode 100644 index 0000000000000000000000000000000000000000..aefe03b26108059c617f1d590848ab452a6a62db --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_qat.py @@ -0,0 +1,326 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import os +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +from paddle.io import DataLoader, Dataset +import unittest +import paddle.nn as nn +from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware +from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + rank_id) + + +vocab_size = 20 +hidden_size = 10 +inner_size = 8 +output_size = 10 +seq_length = 2 +batch_size = 4 + + +def get_attr(layer, name): + if getattr(layer, name, None) is not None: + return getattr(layer, name, None) + else: + return get_attr(layer._layer, name) + + +def get_gpus(selected_gpus): + selected_gpus = [x.strip() for x in selected_gpus.split(',')] + return selected_gpus + + +def get_cluster_from_args(selected_gpus): + cluster_node_ips = '127.0.0.1' + node_ip = '127.0.0.1' + + node_ips = [x.strip() for x in cluster_node_ips.split(',')] + + node_ips.index(node_ip) + + free_ports = None + + free_ports = find_free_ports(len(selected_gpus)) + if free_ports is not None: + free_ports = list(free_ports) + + trainer_endpoints = [] + for ip in node_ips: + trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports]) + return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus) + + +def parallel_matmul(lm_output, logit_weights, parallel_output): + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + world_size = hcg.get_model_parallel_world_size() + rank = hcg.get_model_parallel_rank() + + if world_size > 1: + input_parallel = paddle.distributed.collective._c_identity( + lm_output, group=model_parallel_group) + + logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True) + + if parallel_output: + return logits + + return paddle.distributed.collective._c_concat( + logits, group=model_parallel_group) + else: + logits = paddle.matmul(lm_output, logit_weights, transpose_y=True) + return logits + + +class PACT(nn.Layer): + + def __init__(self, init_value=20): + super(PACT, self).__init__() + alpha_attr = paddle.ParamAttr( + name=self.full_name() + ".pact", + initializer=paddle.nn.initializer.Constant(value=init_value)) + self.alpha = self.create_parameter(shape=[1], + attr=alpha_attr, + dtype='float32') + + def forward(self, x): + out_left = paddle.nn.functional.relu(x - self.alpha) + out_right = paddle.nn.functional.relu(-self.alpha - x) + x = x - out_left + out_right + return x + + +class SimpleMPNet(nn.Layer): + + def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1, + np_fc2, mp_id): + super(SimpleMPNet, self).__init__() + + if mp_id == 0: + init_fc1_data = np_fc1[:, :(inner_size // 2)] + init_fc2_data = np_fc2[:(inner_size // 2), :] + else: + init_fc1_data = np_fc1[:, (inner_size // 2):] + init_fc2_data = np_fc2[(inner_size // 2):, :] + + self.linear1 = fleet.meta_parallel.ColumnParallelLinear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(init_fc1_data)), + gather_output=False, + has_bias=True) + + self.linear2 = fleet.meta_parallel.RowParallelLinear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(init_fc2_data)), + input_is_parallel=True, + has_bias=True) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0))) + + self.embedding = fleet.meta_parallel.VocabParallelEmbedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=1.)) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = parallel_matmul(x, get_attr(self.embedding, "weight"), False) + return x + + +class SimpleDPNet(nn.Layer): + + def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1, + np_fc2): + + super(SimpleDPNet, self).__init__() + self.linear1 = paddle.nn.Linear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc1)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0))) + + self.linear2 = paddle.nn.Linear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc2)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0))) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0))) + + self.embedding = paddle.nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=1.)) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = paddle.matmul(x, + get_attr(self.embedding, "weight"), + transpose_y=True) + return x + + +class TestDistMPTraning(unittest.TestCase): + + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + self.data_parallel_size = 1 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": 1 + } + fleet.init(is_collective=True, strategy=strategy) + self.onnx_format = False + self.check_export_model_accuracy = True + self.diff_threshold = 0.01 + self.fuse_conv_bn = False + + def train_batch(self, batch, model, optimizer, is_mp): + output = model(batch) + loss = output.mean() + loss.backward() # do backward + optimizer.step() # update parameters + optimizer.clear_grad() + return loss + + def build_optimizer(self, model): + optimizer = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model.parameters()) + return optimizer + + def build_model_optimizer(self, + weight_quantize_type, + activation_quantize_type, + use_pact=False): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + mp_id = hcg.get_model_parallel_rank() + dp_id = hcg.get_data_parallel_rank() + rank_id = dist.get_rank() + imperative_qat = ImperativeQuantAware( + weight_quantize_type=weight_quantize_type, + activation_quantize_type=activation_quantize_type, + fuse_conv_bn=self.fuse_conv_bn, + act_preprocess_layer=PACT if use_pact else None) + set_random_seed(1024, dp_id, rank_id) + + np_fc1 = np.ones((hidden_size, inner_size)) + np_fc2 = np.ones( + (inner_size, + hidden_size)) #np.random.random_sample((inner_size, hidden_size)) + + model_a = SimpleMPNet(vocab_size, hidden_size, inner_size, output_size, + np_fc1, np_fc2, mp_id) + model_a = imperative_qat.quantize(model_a) + optimizer_a = self.build_optimizer(model_a) + model_a = fleet.distributed_model(model_a) + optimizer_a = fleet.distributed_optimizer(optimizer_a) + + model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size, + np_fc1, np_fc2) + model_b = imperative_qat.quantize(model_b) + optimizer_b = self.build_optimizer(model_b) + + return model_a, optimizer_a, model_b, optimizer_b + + def train(self, model_a, optimizer_a, model_b, optimizer_b): + + for epoch in range(5): + + np_data = np.random.randint(0, vocab_size, ( + batch_size, + seq_length, + )) + batch = paddle.to_tensor(np_data) + loss_a = self.train_batch(batch, model_a, optimizer_a, True) + loss_b = self.train_batch(batch, model_b, optimizer_b, False) + + np.testing.assert_allclose(loss_a.numpy(), + loss_b.numpy(), + rtol=1e-6) + + def test_mp_model_1(self): + if not fluid.core.is_compiled_with_cuda( + ) or fluid.core.get_cuda_device_count() == 0: + return + selected_gpus = get_gpus('0,1') + cluster = None + pod = None + + model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer( + weight_quantize_type='abs_max', + activation_quantize_type='moving_average_abs_max') + self.train(model_a, optimizer_a, model_b, optimizer_b) + + def test_mp_model_2(self): + if not fluid.core.is_compiled_with_cuda( + ) or fluid.core.get_cuda_device_count() == 0: + return + selected_gpus = get_gpus('0,1') + cluster = None + pod = None + + model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer( + weight_quantize_type='channel_wise_abs_max', + activation_quantize_type='moving_average_abs_max', + use_pact=True) + self.train(model_a, optimizer_a, model_b, optimizer_b) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_qat.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_qat.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b2da46740ddfd98678877c491664e389b4ae5c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_qat.py @@ -0,0 +1,141 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import time +import paddle +import paddle.fluid as fluid +import copy +import os +import subprocess + +from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc + + +def get_cluster_from_args(selected_gpus): + cluster_node_ips = '127.0.0.1' + node_ip = '127.0.0.1' + + node_ips = [x.strip() for x in cluster_node_ips.split(',')] + + node_ips.index(node_ip) + + free_ports = None + + free_ports = find_free_ports(len(selected_gpus)) + if free_ports is not None: + free_ports = list(free_ports) + + trainer_endpoints = [] + for ip in node_ips: + trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports]) + return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus) + + +def get_gpus(selected_gpus): + selected_gpus = [x.strip() for x in selected_gpus.split(',')] + return selected_gpus + + +def start_local_trainers(cluster, + pod, + training_script, + training_script_args, + eager_mode=True, + log_dir=None): + current_env = copy.copy(os.environ.copy()) + #paddle broadcast ncclUniqueId use socket, and + #proxy maybe make trainers unreachable, so delete them. + #if we set them to "", grpc will log error message "bad uri" + #so just delete them. + current_env.pop("http_proxy", None) + current_env.pop("https_proxy", None) + + procs = [] + for t in pod.trainers: + proc_env = { + "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]), + "PADDLE_TRAINER_ID": "%d" % t.rank, + "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint, + "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) + } + + if not eager_mode: + proc_env["FLAGS_enable_eager_mode"] = "%d" % 0 + + current_env.update(proc_env) + + print("trainer proc env:{}".format(current_env)) + + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + cmd = "python -m coverage run --branch -p " + training_script + else: + cmd = "python -u " + training_script + + print("start trainer proc:{} env:{}".format(cmd, proc_env)) + + fn = None + + proc = subprocess.Popen(cmd.split(" "), env=current_env) + + tp = TrainerProc() + tp.proc = proc + tp.rank = t.rank + tp.log_fn = fn + tp.cmd = cmd + + procs.append(tp) + + return procs + + +class TestMultipleGpus(unittest.TestCase): + + def run_2gpu(self, target_file_name, eager_mode=True): + if not fluid.core.is_compiled_with_cuda( + ) or fluid.core.get_cuda_device_count() == 0: + return + + selected_gpus = get_gpus('0,1') + cluster = None + pod = None + + cluster, pod = get_cluster_from_args(selected_gpus) + + procs = start_local_trainers(cluster, + pod, + eager_mode=eager_mode, + training_script=target_file_name, + training_script_args=[]) + + while True: + alive = watch_local_trainers(procs, cluster.trainers_endpoints()) + + if not alive: + print("Local procs complete, POD info:{}".format(pod)) + break + time.sleep(3) + + +class TestDataParallelQAT(TestMultipleGpus): + + def test_multiple_gpus_qat(self): + self.run_2gpu('hybrid_parallel_qat.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/quant/quant_layers.py b/python/paddle/nn/quant/quant_layers.py index 62fe8087c4fdb7262b478e66976c518e60554e27..b2fc03b1b9003068b5c02dbd234f55d507a58305 100644 --- a/python/paddle/nn/quant/quant_layers.py +++ b/python/paddle/nn/quant/quant_layers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle from paddle.framework import core from paddle.fluid import dygraph_utils from paddle.utils import unique_name @@ -37,6 +38,8 @@ __all__ = [ 'MAOutputScaleLayer', 'FakeQuantMAOutputScaleLayer', 'QuantStub', + 'QuantizedRowParallelLinear', + 'QuantizedColumnParallelLinear', ] _logger = get_logger(__name__, @@ -58,10 +61,12 @@ class FakeQuantAbsMax(Layer): name=None, quant_bits=8, dtype='float32', - quant_on_weight=False): + quant_on_weight=False, + reduce_type=None): super(FakeQuantAbsMax, self).__init__() self._quant_bits = quant_bits self._name = name + self._reduce_type = reduce_type scale_prefix = "{}.scale".format( name) if name else 'quant_dequant.scale' self._scale_name = unique_name.generate(scale_prefix) @@ -86,6 +91,10 @@ class FakeQuantAbsMax(Layer): dtype=input.dtype, persistable=False) out_scale = self._scale + if self._reduce_type == "max": + paddle.distributed.all_reduce( + out_scale, op=paddle.distributed.ReduceOp.MAX) + if not out_scale: out_scale = _varbase_creator( type=core.VarDesc.VarType.LOD_TENSOR, @@ -139,11 +148,12 @@ class FakeQuantMovingAverageAbsMax(Layer): name=None, moving_rate=0.9, quant_bits=8, - dtype='float32'): + dtype='float32', + reduce_type=None): super(FakeQuantMovingAverageAbsMax, self).__init__() self._moving_rate = moving_rate self._quant_bits = quant_bits - + self._reduce_type = reduce_type scale_prefix = "{}.scale".format( name) if name else 'quant_dequant.scale' scale_attr = ParamAttr(name=unique_name.generate(scale_prefix), @@ -184,12 +194,17 @@ class FakeQuantMovingAverageAbsMax(Layer): shape=input.shape, dtype=input.dtype, persistable=False) + if self._reduce_type == "max": + paddle.distributed.all_reduce( + self._scale, op=paddle.distributed.ReduceOp.MAX) + state = self._state if self.training else None accum = self._accum if self.training else None out, _, _, _ = _C_ops.fake_quantize_dequantize_moving_average_abs_max( input, self._scale, accum, state, quant_out, self._scale, state, accum, *attrs) + return out check_variable_and_dtype(input, 'input', ['float32'], @@ -231,7 +246,8 @@ class FakeQuantChannelWiseAbsMax(Layer): quant_bits=8, quant_axis=0, dtype='float32', - quant_on_weight=False): + quant_on_weight=False, + reduce_type=None): assert quant_on_weight == True, "Channel_wise only can be used on weight quantization." super(FakeQuantChannelWiseAbsMax, self).__init__() self._quant_bits = quant_bits @@ -239,6 +255,7 @@ class FakeQuantChannelWiseAbsMax(Layer): self._dtype = dtype self._name = name self._channel_num = channel_num + self._reduce_type = reduce_type scale_prefix = "{}.scale".format( name) if name else 'quant_dequant.scale' self._scale_name = unique_name.generate(scale_prefix) @@ -265,6 +282,9 @@ class FakeQuantChannelWiseAbsMax(Layer): persistable=False) out_scale = self._scale + if self._reduce_type == "max": + paddle.distributed.all_reduce( + out_scale, op=paddle.distributed.ReduceOp.MAX) if out_scale is None: out_scale = _varbase_creator( type=core.VarDesc.VarType.LOD_TENSOR, @@ -309,7 +329,11 @@ class FakeQuantChannelWiseAbsMax(Layer): class MovingAverageAbsMaxScale(Layer): - def __init__(self, name=None, moving_rate=0.9, dtype='float32'): + def __init__(self, + name=None, + moving_rate=0.9, + dtype='float32', + reduce_type=None): r""" MovingAverageMaxScale layer is used to calculating the output quantization scale of Layer. Its computational formula is described as below: @@ -319,7 +343,7 @@ class MovingAverageAbsMaxScale(Layer): """ super(MovingAverageAbsMaxScale, self).__init__() self._moving_rate = moving_rate - + self._reduce_type = reduce_type scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale' scale_name = unique_name.generate(scale_prefix) scale_attr = ParamAttr(name=scale_name, @@ -352,13 +376,18 @@ class MovingAverageAbsMaxScale(Layer): if in_dynamic_mode(): attrs = ('moving_rate', self._moving_rate, 'is_test', not self.training) - state = self._state if self.training else None - accum = self._accum if self.training else None + quant_out = _varbase_creator(type=input.type, name="{}.tmp".format(input.name), shape=input.shape, dtype=input.dtype, persistable=False) + if self._reduce_type == "max": + paddle.distributed.all_reduce( + self._scale, op=paddle.distributed.ReduceOp.MAX) + + state = self._state if self.training else None + accum = self._accum if self.training else None out, _, _, _ = _C_ops.moving_average_abs_max_scale( input, accum, state, quant_out, self._scale, state, accum, @@ -659,13 +688,190 @@ class QuantizedLinear(Layer): return out +class QuantizedColumnParallelLinear(Layer): + + def __init__(self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None): + super(QuantizedColumnParallelLinear, self).__init__() + ''' + + ''' + assert weight_quant_layer is None, "When quantizing ColumnParallelLinear, weight_quant_layer should be None." + assert act_quant_layer is None, "When quantizing ColumnParallelLinear, act_quant_layer should be None." + + self.weight = getattr(layer, 'weight') + self.bias = getattr(layer, 'bias') + self.name = getattr(layer, '_name') + # For FakeQuant + self._linear_quant_axis = 1 + + self.is_mp = getattr(layer, 'is_mp') + self.model_parallel_group = getattr(layer, 'model_parallel_group') + self.gather_output = getattr(layer, 'gather_output') + + self._fake_quant_weight = _get_fake_quant_type( + weight_quantize_type, + name=self.weight.name, + moving_rate=moving_rate, + quant_bits=weight_bits, + dtype=self._dtype, + quant_on_weight=True, + channel_num=self.weight.shape[self._linear_quant_axis], + quant_axis=self._linear_quant_axis, + reduce_type='max' + if paddle.distributed.get_world_size() > 1 else None) + + self._fake_quant_input = _get_fake_quant_type( + activation_quantize_type, + name=layer.full_name(), + moving_rate=moving_rate, + quant_bits=activation_bits, + dtype=self._dtype, + quant_on_weight=False, + reduce_type=None) + + self._act_preprocess = act_pre_layer( + ) if act_pre_layer is not None else None + self._weight_preprocess = weight_pre_layer( + ) if weight_pre_layer is not None else None + + def forward(self, input): + if self.is_mp: + input_parallel = paddle.distributed.collective._c_identity( + input, group=self.model_parallel_group) + else: + input_parallel = input + + if self._act_preprocess is not None: + input_parallel = self._act_preprocess(input_parallel) + quant_input = self._fake_quant_input(input_parallel) + + weight = self.weight + if self._weight_preprocess is not None: + weight = self._weight_preprocess(self.weight) + quant_weight = self._fake_quant_weight(weight) + + output_parallel = F.linear(x=quant_input, + weight=quant_weight, + bias=self.bias, + name=self.name) + + if self.gather_output and self.is_mp: + output = paddle.distributed.collective._c_concat( + output_parallel, group=self.model_parallel_group) + else: + output = output_parallel + return output + + +class QuantizedRowParallelLinear(Layer): + + def __init__(self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None): + super(QuantizedRowParallelLinear, self).__init__() + assert weight_quant_layer is None, "When quantizing RowParallelLinear, weight_quant_layer cannot defined by yourself." + assert act_quant_layer is None, "When quantizing RowParallelLinear, act_quant_layer cannot defined by yourself." + + # For Linear + self.weight = getattr(layer, 'weight') + self.bias = getattr(layer, 'bias') + self.name = getattr(layer, '_name') + # For FakeQuant + self._linear_quant_axis = 1 + + self.input_is_parallel = getattr(layer, 'input_is_parallel') + self.is_mp = getattr(layer, 'is_mp') + self.model_parallel_group = getattr(layer, 'model_parallel_group') + + self._fake_quant_weight = _get_fake_quant_type( + weight_quantize_type, + name=self.weight.name, + moving_rate=moving_rate, + quant_bits=weight_bits, + dtype=self._dtype, + quant_on_weight=True, + channel_num=self.weight.shape[self._linear_quant_axis], + quant_axis=self._linear_quant_axis, + reduce_type='max' + if paddle.distributed.get_world_size() > 1 else None) + + self._fake_quant_input = _get_fake_quant_type( + activation_quantize_type, + name=layer.full_name(), + moving_rate=moving_rate, + quant_bits=activation_bits, + dtype=self._dtype, + quant_on_weight=False, + reduce_type='max' + if paddle.distributed.get_world_size() > 1 else None) + + self._act_preprocess = act_pre_layer( + ) if act_pre_layer is not None else None + self._weight_preprocess = weight_pre_layer( + ) if weight_pre_layer is not None else None + + def forward(self, input): + if self.input_is_parallel or (not self.is_mp): + input_parallel = input + else: + # split last dim + input_parallel = paddle.distributed.collective._c_split( + input, group=self.model_parallel_group) + + if self._act_preprocess is not None: + input_parallel = self._act_preprocess(input_parallel) + quant_input = self._fake_quant_input(input_parallel) + + weight = self.weight + if self._weight_preprocess is not None: + weight = self._weight_preprocess(self.weight) + quant_weight = self._fake_quant_weight(weight) + + output_parallel = F.linear(x=quant_input, + weight=quant_weight, + name=self.name) + if self.is_mp: + output_ = paddle.distributed.collective._mp_allreduce( + output_parallel, + group=self.model_parallel_group, + use_calc_stream=True, + use_model_parallel=True) + else: + output_ = output_parallel + output = output_ + self.bias if self.bias is not None else output_ + return output + + class MAOutputScaleLayer(Layer): """ Add MovingAverageMaxScale layer to the behind of the input layer. Calculate the scale (moving average abs max) for the output of the input layer. """ - def __init__(self, layer=None, moving_rate=0.9, name=None, dtype='float32'): + def __init__(self, + layer=None, + moving_rate=0.9, + name=None, + dtype='float32', + reduce_type=None): r""" Construct """ @@ -674,7 +880,7 @@ class MAOutputScaleLayer(Layer): if name is None: name = layer.full_name() self._ma_output_scale = \ - MovingAverageAbsMaxScale(name, moving_rate, dtype) + MovingAverageAbsMaxScale(name, moving_rate, dtype, reduce_type) def forward(self, *inputs, **kwargs): out = self._layer(*inputs, **kwargs) @@ -697,6 +903,7 @@ class FakeQuantMAOutputScaleLayer(Layer): activation_bits=8, moving_rate=0.9, name=None, + reduce_type=None, *args, **kwargs): @@ -708,7 +915,8 @@ class FakeQuantMAOutputScaleLayer(Layer): moving_rate=moving_rate, quant_bits=activation_bits, dtype=self._dtype, - quant_on_weight=False) + quant_on_weight=False, + reduce_type=reduce_type) def forward(self, *inputs, **kwargs): out = self._layer(*inputs, **kwargs) @@ -723,7 +931,8 @@ def _get_fake_quant_type(quant_type, **kwargs): call_args = { "name": kwargs.get("name", None), "quant_bits": kwargs.get("quant_bits", 8), - "dtype": kwargs.get("dtype", "float32") + "dtype": kwargs.get("dtype", "float32"), + "reduce_type": kwargs.get("reduce_type", None) } if quant_type == 'abs_max':