diff --git a/paddleslim/quant/layers/__init__.py b/paddleslim/quant/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..34e5a1ea510da11c5619ad896c21ebb5bfaa6aa9
--- /dev/null
+++ b/paddleslim/quant/layers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .parallel_linear import QuantizedColumnParallelLinear, QuantizedRowParallelLinear
+
+__all__ = ["QuantizedColumnParallelLinear", "QuantizedRowParallelLinear"]
\ No newline at end of file
diff --git a/paddleslim/quant/layers/parallel_linear.py b/paddleslim/quant/layers/parallel_linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb06996255a9d95f87e0bb743814e8cf98cfda2
--- /dev/null
+++ b/paddleslim/quant/layers/parallel_linear.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.nn import Layer
+from paddle.nn import functional as F
+
+from paddle.nn.quant.format import ConvertibleQuantedLayer
+
+
+class QuantizedRowParallelLinear(ConvertibleQuantedLayer):
+    """
+    The computational logic of QuantizedRowParallelLinear is the same as RowParallelLinear.
+    The only difference is that its inputs are all fake quantized.
+ """ + + def __init__(self, layer: Layer, q_config): + super().__init__() + # For Linear + self.weight = layer.weight + self.bias = layer.bias + self._name = layer._name + self.input_is_parallel = layer.input_is_parallel + self.is_mp = layer.is_mp + self.model_parallel_group = layer.model_parallel_group + self.linear = layer.linear + + # For FakeQuant + self.weight_quanter = None + self.activation_quanter = None + if q_config.weight is not None: + self.weight_quanter = q_config.weight._instance(layer) + if q_config.activation is not None: + self.activation_quanter = q_config.activation._instance(layer) + + def forward(self, input): + quant_input = input + quant_weight = self.weight + if self.activation_quanter is not None: + quant_input = self.activation_quanter(input) + if self.weight_quanter is not None: + quant_weight = self.weight_quanter(self.weight) + return self._linear_forward(quant_input, quant_weight) + + def _linear_forward(self, input, weight): + if self.input_is_parallel or (not self.is_mp): + input_parallel = input + else: + # split last dim + input_parallel = paddle.distributed.collective._c_split( + input, group=self.model_parallel_group) + + if self.is_mp: + output_parallel = self.linear( + input_parallel, weight, name=self._name) + output_ = paddle.distributed.collective._mp_allreduce( + output_parallel, + group=self.model_parallel_group, + use_calc_stream=True, + use_model_parallel=True) + output = output_ + self.bias if self.bias is not None else output_ + else: + output = self.linear( + input_parallel, weight, self.bias, name=self._name) + return output + + def weights_to_quanters(self): + return [('weight', 'weight_quanter')] + + def activation_quanters(self): + return ['activation_quanter'] + + +class QuantizedColumnParallelLinear(ConvertibleQuantedLayer): + """ + The computational logic of QuantizedColumnParallelLinear is the same as ColumnParallelLinear. + The only difference is that its inputs are all fake quantized. 
+ """ + + def __init__(self, layer: Layer, q_config): + super().__init__() + # For Linear + self.weight = layer.weight + self.bias = layer.bias + self._name = layer._name + self.is_mp = layer.is_mp + self.model_parallel_group = layer.model_parallel_group + self.gather_output = layer.gather_output + self.linear = layer.linear + + # For FakeQuant + self.weight_quanter = None + self.activation_quanter = None + if q_config.weight is not None: + self.weight_quanter = q_config.weight._instance(layer) + if q_config.activation is not None: + self.activation_quanter = q_config.activation._instance(layer) + + def forward(self, input): + quant_input = input + quant_weight = self.weight + if self.activation_quanter is not None: + quant_input = self.activation_quanter(input) + if self.weight_quanter is not None: + quant_weight = self.weight_quanter(self.weight) + return self._linear_forward(quant_input, quant_weight) + + def _linear_forward(self, input, weight): + if self.is_mp: + input_parallel = paddle.distributed.collective._c_identity( + input, group=self.model_parallel_group) + else: + input_parallel = input + + output_parallel = self.linear( + input_parallel, weight, self.bias, name=self._name) + + if self.gather_output and self.is_mp: + output = paddle.distributed.collective._c_concat( + output_parallel, group=self.model_parallel_group) + else: + output = output_parallel + return output + + def weights_to_quanters(self): + return [('weight', 'weight_quanter')] + + def activation_quanters(self): + return ['activation_quanter'] diff --git a/paddleslim/quant/observers/emd.py b/paddleslim/quant/observers/emd.py index eeaf348fdddbba62e7617a21ecce344f0aee1189..02bea81a588158720b362081176a2f324788bac4 100644 --- a/paddleslim/quant/observers/emd.py +++ b/paddleslim/quant/observers/emd.py @@ -63,6 +63,7 @@ class EMDObserverLayer(UniformObserver): abs_max_value = float(paddle.max(paddle.flatten(inputs))) abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value s = 0.3 + scale_emd = abs_max_value while s <= 1.0: scale = s * abs_max_value s += 0.02 @@ -78,8 +79,8 @@ class EMDObserverLayer(UniformObserver): emd_loss = float(emd_loss) if emd_loss <= self._calibration_loss: self._calibration_loss = emd_loss - - return 0, scale + scale_emd = scale + return 0, scale_emd def cal_thresholds(self): """ Compute thresholds for MAX function. diff --git a/paddleslim/quant/observers/mse.py b/paddleslim/quant/observers/mse.py index 2641fadd45bdb705e76121640e128b2ceff43850..6deab94ad833cf18b3cbb1113e52b2dfcf827e0f 100644 --- a/paddleslim/quant/observers/mse.py +++ b/paddleslim/quant/observers/mse.py @@ -64,6 +64,7 @@ class MSEObserverLayer(UniformObserver): abs_max_value = float(paddle.max(paddle.abs(inputs.flatten()))) abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value s = 0.3 + scale_mse = abs_max_value while s <= 1.0: scale = s * abs_max_value s += 0.02 @@ -75,8 +76,8 @@ class MSEObserverLayer(UniformObserver): mse_loss = float(((inputs - quant_dequant_var)**2).mean()) if mse_loss <= self.calibration_loss: self.calibration_loss = mse_loss - - return 0, scale + scale_mse = scale + return 0, scale_mse def cal_thresholds(self): """ Compute thresholds for MAX function. 
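Reviewer note on the two observer hunks above: before this change, EMDObserverLayer and MSEObserverLayer returned whatever `scale` the search loop ended on rather than the scale that actually achieved the lowest calibration loss; the patch tracks the best candidate in `scale_emd` / `scale_mse` (initialized to `abs_max_value` as a fallback) and returns it. The snippet below is a minimal, NumPy-only sketch of that search to illustrate the fix; the 8-bit quantize-dequantize round trip is an assumed stand-in for the observer's real fake-quant op, not PaddleSlim's actual kernel.

import numpy as np


def search_best_scale(x, bits=8):
    """Linear search over candidate scales, keeping the one with minimal MSE."""
    abs_max_value = float(np.max(np.abs(x)))
    abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
    bnt = (1 << (bits - 1)) - 1  # 127 for signed 8-bit (assumed for illustration)
    best_loss = float("inf")
    best_scale = abs_max_value  # fallback, mirrors the scale_emd/scale_mse init
    s = 0.3
    while s <= 1.0:
        scale = s * abs_max_value
        s += 0.02
        # simulated quantize-dequantize with the candidate scale
        q = np.clip(np.round(x / scale * bnt), -bnt - 1, bnt)
        x_dq = q / bnt * scale
        loss = float(((x - x_dq)**2).mean())
        if loss <= best_loss:
            best_loss = loss
            best_scale = scale  # remember the winning scale, not the last one tried
    return best_scale


print(search_best_scale(np.random.randn(1024).astype("float32")))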
diff --git a/paddleslim/quant/observers/uniform.py b/paddleslim/quant/observers/uniform.py index 216418a91c1d5bb053966e509da456cb5f17b991..d874fa687cc9924e604d7db696df2f4ab237aff6 100644 --- a/paddleslim/quant/observers/uniform.py +++ b/paddleslim/quant/observers/uniform.py @@ -89,7 +89,7 @@ class UniformObserver(BaseObserver): _max = max(self.max_value(), 0.) if self._symmetric: - self._scale = max(-_min, _max) / (float(_qmax - _qmin) / 2) + self._scale = max(-_min, _max) if self._sign: self._zero_point = 0 else: diff --git a/tests/distribution/test_quant_parallel_layers.py b/tests/distribution/test_quant_parallel_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..212a723b462818ba7769d7ef79d77090e0f8b8c2 --- /dev/null +++ b/tests/distribution/test_quant_parallel_layers.py @@ -0,0 +1,337 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import os +import unittest +import paddle +import tempfile +import random +import numpy as np +sys.path.append("../../") +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +import paddle.fluid as fluid +import paddle.nn as nn +from paddle.distributed.utils.launch_utils import find_free_ports, get_cluster +from paddle.quantization import QuantConfig +from paddle.quantization import QAT +from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver +from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer +from paddle.nn.quant.format import LinearDequanter, LinearQuanter +from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding +from paddleslim.quant.layers import QuantizedColumnParallelLinear, QuantizedRowParallelLinear +import logging +from paddleslim.common import get_logger +_logger = get_logger(__name__, level=logging.INFO) + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + rank_id) + + +vocab_size = 20 +hidden_size = 10 +inner_size = 8 +output_size = 10 +seq_length = 2 +batch_size = 4 + + +def get_attr(layer, name): + if getattr(layer, name, None) is not None: + return getattr(layer, name, None) + else: + return get_attr(layer._layer, name) + + +def get_gpus(selected_gpus): + selected_gpus = [x.strip() for x in selected_gpus.split(',')] + return selected_gpus + + +def get_cluster_from_args(selected_gpus): + cluster_node_ips = '127.0.0.1' + node_ip = '127.0.0.1' + + node_ips = [x.strip() for x in cluster_node_ips.split(',')] + + node_ips.index(node_ip) + + free_ports = None + + free_ports = find_free_ports(len(selected_gpus)) + if free_ports is not None: + free_ports = list(free_ports) + + trainer_endpoints = [] + for ip in node_ips: + trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports]) + return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus) + + +def 
parallel_matmul(lm_output, logit_weights, parallel_output): + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + world_size = hcg.get_model_parallel_world_size() + rank = hcg.get_model_parallel_rank() + + if world_size > 1: + input_parallel = paddle.distributed.collective._c_identity( + lm_output, group=model_parallel_group) + + logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True) + + if parallel_output: + return logits + + return paddle.distributed.collective._c_concat( + logits, group=model_parallel_group) + else: + logits = paddle.matmul(lm_output, logit_weights, transpose_y=True) + return logits + + +class SimpleMPNet(nn.Layer): + def __init__( + self, + vocab_size, + hidden_size, + inner_size, + output_size, + np_fc1, + np_fc2, + mp_id, ): + super().__init__() + + if mp_id == 0: + init_fc1_data = np_fc1[:, :(inner_size // 2)] + init_fc2_data = np_fc2[:(inner_size // 2), :] + else: + init_fc1_data = np_fc1[:, (inner_size // 2):] + init_fc2_data = np_fc2[(inner_size // 2):, :] + + self.linear1 = ColumnParallelLinear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(init_fc1_data)), + gather_output=False, + has_bias=True, ) + + self.linear2 = RowParallelLinear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(init_fc2_data)), + input_is_parallel=True, + has_bias=True, ) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), ) + + self.embedding = VocabParallelEmbedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=1.0), ) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = parallel_matmul(x, self.embedding.weight, False) + return x + + +class SimpleDPNet(nn.Layer): + def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1, + np_fc2): + + super().__init__() + self.linear1 = paddle.nn.Linear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc1)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), ) + + self.linear2 = paddle.nn.Linear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc2)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), ) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), ) + + self.embedding = paddle.nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=1.0), ) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = paddle.matmul(x, self.embedding.weight, transpose_y=True) + return x + + +class TestDistMPTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + self.data_parallel_size = 1 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": 
self.model_parallel_size, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + self.onnx_format = False + self.check_export_model_accuracy = True + self.diff_threshold = 0.01 + self.fuse_conv_bn = False + + def train_batch(self, batch, model, optimizer, is_mp): + output = model(batch) + loss = output.mean() + loss.backward() # do backward + optimizer.step() # update parameters + optimizer.clear_grad() + return loss + + def build_optimizer(self, model): + optimizer = paddle.optimizer.SGD( + learning_rate=0.001, parameters=model.parameters()) + return optimizer + + def build_model_optimizer(self, qat): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + mp_id = hcg.get_model_parallel_rank() + dp_id = hcg.get_data_parallel_rank() + rank_id = dist.get_rank() + set_random_seed(1024, dp_id, rank_id) + + np_fc1 = np.ones((hidden_size, inner_size)) + np_fc2 = np.ones((inner_size, hidden_size)) + + model_a = SimpleMPNet( + vocab_size, + hidden_size, + inner_size, + output_size, + np_fc1, + np_fc2, + mp_id, ) + model_a = qat.quantize(model_a, inplace=True) + optimizer_a = self.build_optimizer(model_a) + model_a = fleet.distributed_model(model_a) + optimizer_a = fleet.distributed_optimizer(optimizer_a) + + model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size, + np_fc1, np_fc2) + model_b = qat.quantize(model_b, inplace=True) + optimizer_b = self.build_optimizer(model_b) + + return model_a, optimizer_a, model_b, optimizer_b + + def train(self, model_a, optimizer_a, model_b, optimizer_b): + + for epoch in range(5): + + np_data = np.random.randint( + 0, + vocab_size, + (batch_size, seq_length, ), ) + batch = paddle.to_tensor(np_data, dtype='int32') + loss_a = self.train_batch(batch, model_a, optimizer_a, True) + loss_b = self.train_batch(batch, model_b, optimizer_b, False) + + np.testing.assert_allclose( + loss_a.numpy(), loss_b.numpy(), rtol=1e-6) + + def test_mp_model_1(self): + if (not fluid.core.is_compiled_with_cuda() or + fluid.core.get_cuda_device_count() == 0): + return + selected_gpus = get_gpus('0,1') + cluster = None + pod = None + observer = FakeQuanterWithAbsMaxObserver() + q_config = QuantConfig(activation=None, weight=None) + q_config.add_type_config( + ColumnParallelLinear, activation=observer, weight=observer) + q_config.add_type_config( + RowParallelLinear, activation=observer, weight=observer) + q_config.add_type_config( + nn.Linear, activation=observer, weight=observer) + q_config.add_qat_layer_mapping(ColumnParallelLinear, + QuantizedColumnParallelLinear) + q_config.add_qat_layer_mapping(RowParallelLinear, + QuantizedRowParallelLinear) + qat = QAT(q_config) + model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer( + qat) + self.train(model_a, optimizer_a, model_b, optimizer_b) + + def test_mp_model_2(self): + if (not fluid.core.is_compiled_with_cuda() or + fluid.core.get_cuda_device_count() == 0): + return + selected_gpus = get_gpus('0,1') + cluster = None + pod = None + observer = FakeQuanterWithAbsMaxObserver() + q_config = QuantConfig(activation=None, weight=None) + q_config.add_type_config( + ColumnParallelLinear, activation=observer, weight=observer) + q_config.add_type_config( + RowParallelLinear, activation=observer, weight=observer) + q_config.add_type_config( + nn.Linear, activation=observer, weight=observer) + q_config.add_qat_layer_mapping(ColumnParallelLinear, + QuantizedColumnParallelLinear) + q_config.add_qat_layer_mapping(RowParallelLinear, + 
QuantizedRowParallelLinear) + qat = QAT(q_config) + + model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer( + qat) + self.train(model_a, optimizer_a, model_b, optimizer_b) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/quantization/test_observers_acc.py b/tests/quantization/test_observers_acc.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b0368520116f13552be5dfe21aaebcc02ec25a --- /dev/null +++ b/tests/quantization/test_observers_acc.py @@ -0,0 +1,242 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append("../../") +import os +import unittest +import paddle +import tempfile +import numpy as np + +from paddle.vision.models import resnet18 +from paddle.quantization import QuantConfig +from paddle.quantization import PTQ + +from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver +from paddleslim.quant.observers.hist import PercentHistObserverLayer +from paddleslim.quant.observers.kl import KLObserverLayer +from paddleslim.quant.observers.mse import MSEObserverLayer +from paddleslim.quant.observers.avg import AVGObserverLayer +from paddleslim.quant.observers.emd import EMDObserverLayer +from paddleslim.quant.observers.kl import KLObserverLayer +from paddle.nn.quant.format import LinearDequanter, LinearQuanter + +import logging +from paddleslim.common import get_logger +_logger = get_logger(__name__, level=logging.INFO) + + +class ImperativeLenet(paddle.nn.Layer): + def __init__(self, num_classes=10, classifier_activation='softmax'): + super(ImperativeLenet, self).__init__() + self.features = paddle.nn.Sequential( + paddle.nn.Conv2D( + in_channels=1, + out_channels=6, + kernel_size=3, + stride=1, + padding=1), + paddle.nn.AvgPool2D(kernel_size=2, stride=2), + paddle.nn.Conv2D( + in_channels=6, + out_channels=16, + kernel_size=5, + stride=1, + padding=0), paddle.nn.AvgPool2D(kernel_size=2, stride=2)) + + self.fc = paddle.nn.Sequential( + paddle.nn.Linear(in_features=400, out_features=120), + paddle.nn.Linear(in_features=120, out_features=84), + paddle.nn.Linear(in_features=84, out_features=num_classes), ) + + def forward(self, inputs): + x = self.features(inputs) + + x = paddle.flatten(x, 1) + x = self.fc(x) + return x + + +class TestPTQObserverAcc(unittest.TestCase): + def __init__(self, observer, observer_type, *args, **kvargs): + super(TestPTQObserverAcc, self).__init__(*args, **kvargs) + self.observer = observer + self.observer_type = observer_type + + def setUp(self): + paddle.set_device("cpu") + self.init_case() + self.dummy_input = paddle.rand([1, 3, 224, 224]) + self.temp_dir = tempfile.TemporaryDirectory(dir="./") + self.path = os.path.join(self.temp_dir.name, 'qat') + if not os.path.exists('ILSVRC2012_data_demo'): + os.system( + 'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz' + ) + os.system('tar -xf ILSVRC2012_data_demo.tar.gz') + seed = 1 + 
np.random.seed(seed) + paddle.static.default_main_program().random_seed = seed + paddle.static.default_startup_program().random_seed = seed + + def tearDown(self): + self.temp_dir.cleanup() + + def runTest(self): + self.test_convergence() + + def init_case(self): + self.q_config = QuantConfig(activation=None, weight=None) + self.q_config.add_type_config( + paddle.nn.Conv2D, activation=self.observer, weight=self.observer) + + def _count_layers(self, model, layer_type): + count = 0 + for _layer in model.sublayers(True): + if isinstance(_layer, layer_type): + count += 1 + return count + + def test_convergence(self): + model = ImperativeLenet() + place = paddle.CUDAPlace(0) \ + if paddle.is_compiled_with_cuda() else paddle.CPUPlace() + + transform = paddle.vision.transforms.Compose([ + paddle.vision.transforms.Transpose(), + paddle.vision.transforms.Normalize([127.5], [127.5]) + ]) + + train_dataset = paddle.vision.datasets.MNIST( + mode='train', backend='cv2', transform=transform) + val_dataset = paddle.vision.datasets.MNIST( + mode='test', backend='cv2', transform=transform) + + train_reader = paddle.io.DataLoader( + train_dataset, + drop_last=True, + places=place, + batch_size=64, + return_list=True) + test_reader = paddle.io.DataLoader( + val_dataset, places=place, batch_size=64, return_list=True) + + def train(model): + adam = paddle.optimizer.Adam( + learning_rate=0.0001, parameters=model.parameters()) + epoch_num = 1 + for epoch in range(epoch_num): + model.train() + for batch_id, data in enumerate(train_reader): + img = paddle.to_tensor(data[0]) + label = paddle.to_tensor(data[1]) + img = paddle.reshape(img, [-1, 1, 28, 28]) + label = paddle.reshape(label, [-1, 1]) + + out = model(img) + acc = paddle.metric.accuracy(out, label) + loss = paddle.nn.functional.loss.cross_entropy(out, label) + avg_loss = paddle.mean(loss) + avg_loss.backward() + adam.minimize(avg_loss) + model.clear_gradients() + if batch_id % 100 == 0: + _logger.info( + "Train | At epoch {} step {}: loss = {:}, acc= {:}". 
+ format(epoch, batch_id, + avg_loss.numpy(), acc.numpy())) + + def test(model): + model.eval() + avg_acc = [[], []] + for batch_id, data in enumerate(test_reader): + img = paddle.to_tensor(data[0]) + img = paddle.reshape(img, [-1, 1, 28, 28]) + label = paddle.to_tensor(data[1]) + label = paddle.reshape(label, [-1, 1]) + + out = model(img) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + avg_acc[0].append(acc_top1.numpy()) + avg_acc[1].append(acc_top5.numpy()) + if batch_id % 100 == 0: + _logger.info( + "Test | step {}: acc1 = {:}, acc5 = {:}".format( + batch_id, acc_top1.numpy(), acc_top5.numpy())) + + _logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format( + np.mean(avg_acc[0]), np.mean(avg_acc[1]))) + return np.mean(avg_acc[0]), np.mean(avg_acc[1]) + + def ptq_sample(model): + model.eval() + avg_acc = [[], []] + for batch_id, data in enumerate(test_reader): + img = paddle.to_tensor(data[0]) + img = paddle.reshape(img, [-1, 1, 28, 28]) + label = paddle.to_tensor(data[1]) + label = paddle.reshape(label, [-1, 1]) + + out = model(img) + + if batch_id % 100 == 0: + _logger.info("PTQ sampling | step {}".format(batch_id)) + + train(model) + top1_1, top5_1 = test(model) + ptq = PTQ(self.q_config) + model.eval() + quant_model = ptq.quantize(model, inplace=False) + + ptq_sample(quant_model) + converted_model = ptq.convert(quant_model, inplace=False) + top1_2, top5_2 = test(converted_model) + + _logger.info( + "Before quantization: top1: {}, top5: {}".format(top1_1, top5_1)) + _logger.info( + "After quantization: top1: {}, top5: {}".format(top1_2, top5_2)) + _logger.info("\n") + + diff = 0.01 + self.assertTrue( + top1_1 - top1_2 < diff, + msg="The acc of quant model is too lower than fp32 model") + _logger.info('done') + return + + +observer_suite = unittest.TestSuite() +observer_suite.addTest( + TestPTQObserverAcc( + observer=HistObserver(sign=True, symmetric=True), + observer_type=PercentHistObserverLayer)) +observer_suite.addTest( + TestPTQObserverAcc( + observer=KLObserver(bins_count=256), observer_type=KLObserverLayer)) + +observer_suite.addTest( + TestPTQObserverAcc(observer=AVGObserver(), observer_type=AVGObserverLayer)) +observer_suite.addTest( + TestPTQObserverAcc(observer=EMDObserver(), observer_type=EMDObserverLayer)) +observer_suite.addTest( + TestPTQObserverAcc(observer=MSEObserver(), observer_type=MSEObserverLayer)) + +if __name__ == '__main__': + runner = unittest.TextTestRunner(verbosity=2) + runner.run(observer_suite) + os.system('rm -rf ILSVRC2012_data_demo.tar.gz') + os.system('rm -rf ILSVRC2012_data_demo')
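For reference, a minimal usage sketch distilled from the model-parallel test above: it registers the new quantized parallel layers as QAT mappings so that ColumnParallelLinear / RowParallelLinear are swapped for their fake-quantized counterparts by `qat.quantize()`. It assumes a fleet hybrid-parallel environment has already been initialized (e.g. `fleet.init(is_collective=True, strategy=...)` with `mp_degree` > 1) and that `model` is built elsewhere, so those parts are only hinted at in comments.

import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear,
                                                    RowParallelLinear)
from paddle.quantization import QAT, QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver

from paddleslim.quant.layers import (QuantizedColumnParallelLinear,
                                     QuantizedRowParallelLinear)

# fake-quanter used for both activations and weights, as in the test
quanter = FakeQuanterWithAbsMaxObserver()

q_config = QuantConfig(activation=None, weight=None)
q_config.add_type_config(ColumnParallelLinear, activation=quanter, weight=quanter)
q_config.add_type_config(RowParallelLinear, activation=quanter, weight=quanter)
q_config.add_type_config(nn.Linear, activation=quanter, weight=quanter)

# map the parallel linears to the quantized wrappers added in this patch
q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear)
q_config.add_qat_layer_mapping(RowParallelLinear, QuantizedRowParallelLinear)

qat = QAT(q_config)
# model = qat.quantize(model, inplace=True)  # `model` is a hypothetical mp network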