Unverified commit e2b5c162, authored by Chang Xu, committed by GitHub

Add QuantizedParallelLinear & Update Uniform (#1694)

Parent 3b2ed2cf
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .parallel_linear import QuantizedColumnParallelLinear, QuantizedRowParallelLinear
__all__ = ["QuantizedColumnParallelLinear", "QuantizedRowParallelLinear"]
\ No newline at end of file
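# --- Usage sketch (illustrative, not part of this commit) ---
# A minimal, hedged example of how the layers exported above are intended to
# be used: register the fleet parallel-linear layers in a QuantConfig, map
# them to their quantized counterparts, and let QAT swap them in. This mirrors
# the unit test added later in this commit; the helper name
# quantize_parallel_model is made up for this sketch, and `model` is assumed
# to be a paddle.nn.Layer built from ColumnParallelLinear / RowParallelLinear
# under an already-initialized fleet hybrid-parallel setup.
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear
from paddle.quantization import QAT, QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
from paddleslim.quant.layers import QuantizedColumnParallelLinear, QuantizedRowParallelLinear


def quantize_parallel_model(model):
    quanter = FakeQuanterWithAbsMaxObserver()
    q_config = QuantConfig(activation=None, weight=None)
    q_config.add_type_config(ColumnParallelLinear, activation=quanter, weight=quanter)
    q_config.add_type_config(RowParallelLinear, activation=quanter, weight=quanter)
    q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear)
    q_config.add_qat_layer_mapping(RowParallelLinear, QuantizedRowParallelLinear)
    qat = QAT(q_config)
    return qat.quantize(model, inplace=True)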
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.nn import Layer
from paddle.nn import functional as F
from paddle.nn.quant.format import ConvertibleQuantedLayer
class QuantizedRowParallelLinear(ConvertibleQuantedLayer):
"""
The computational logic of QuantizedRowParallelLinear is the same as RowParallelLinear.
    The only difference is that its input activations and weight are fake-quantized before the linear computation.
"""
def __init__(self, layer: Layer, q_config):
super().__init__()
# For Linear
self.weight = layer.weight
self.bias = layer.bias
self._name = layer._name
self.input_is_parallel = layer.input_is_parallel
self.is_mp = layer.is_mp
self.model_parallel_group = layer.model_parallel_group
self.linear = layer.linear
# For FakeQuant
self.weight_quanter = None
self.activation_quanter = None
if q_config.weight is not None:
self.weight_quanter = q_config.weight._instance(layer)
if q_config.activation is not None:
self.activation_quanter = q_config.activation._instance(layer)
def forward(self, input):
quant_input = input
quant_weight = self.weight
if self.activation_quanter is not None:
quant_input = self.activation_quanter(input)
if self.weight_quanter is not None:
quant_weight = self.weight_quanter(self.weight)
return self._linear_forward(quant_input, quant_weight)
def _linear_forward(self, input, weight):
if self.input_is_parallel or (not self.is_mp):
input_parallel = input
else:
# split last dim
input_parallel = paddle.distributed.collective._c_split(
input, group=self.model_parallel_group)
if self.is_mp:
output_parallel = self.linear(
input_parallel, weight, name=self._name)
output_ = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
output = output_ + self.bias if self.bias is not None else output_
else:
output = self.linear(
input_parallel, weight, self.bias, name=self._name)
return output
def weights_to_quanters(self):
return [('weight', 'weight_quanter')]
def activation_quanters(self):
return ['activation_quanter']
class QuantizedColumnParallelLinear(ConvertibleQuantedLayer):
"""
The computational logic of QuantizedColumnParallelLinear is the same as ColumnParallelLinear.
    The only difference is that its input activations and weight are fake-quantized before the linear computation.
"""
def __init__(self, layer: Layer, q_config):
super().__init__()
# For Linear
self.weight = layer.weight
self.bias = layer.bias
self._name = layer._name
self.is_mp = layer.is_mp
self.model_parallel_group = layer.model_parallel_group
self.gather_output = layer.gather_output
self.linear = layer.linear
# For FakeQuant
self.weight_quanter = None
self.activation_quanter = None
if q_config.weight is not None:
self.weight_quanter = q_config.weight._instance(layer)
if q_config.activation is not None:
self.activation_quanter = q_config.activation._instance(layer)
def forward(self, input):
quant_input = input
quant_weight = self.weight
if self.activation_quanter is not None:
quant_input = self.activation_quanter(input)
if self.weight_quanter is not None:
quant_weight = self.weight_quanter(self.weight)
return self._linear_forward(quant_input, quant_weight)
def _linear_forward(self, input, weight):
if self.is_mp:
input_parallel = paddle.distributed.collective._c_identity(
input, group=self.model_parallel_group)
else:
input_parallel = input
output_parallel = self.linear(
input_parallel, weight, self.bias, name=self._name)
if self.gather_output and self.is_mp:
output = paddle.distributed.collective._c_concat(
output_parallel, group=self.model_parallel_group)
else:
output = output_parallel
return output
def weights_to_quanters(self):
return [('weight', 'weight_quanter')]
def activation_quanters(self):
return ['activation_quanter']
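# --- Illustrative sketch of "fake quantization" (assumptions noted) ---
# "Fake quantized" above means a tensor is quantized and immediately
# dequantized, so the forward pass stays in floating point while carrying the
# rounding error of low-bit quantization. A minimal abs-max quant-dequant
# sketch for signed 8-bit values; this is a conceptual stand-in, not the exact
# logic of FakeQuanterWithAbsMaxObserver.
import paddle


def fake_quant_dequant(x, bit_length=8):
    qmax = float(2 ** (bit_length - 1) - 1)         # 127 for int8
    abs_max = float(paddle.max(paddle.abs(x)))
    abs_max = 1e-8 if abs_max == 0.0 else abs_max   # avoid division by zero
    scale = abs_max / qmax                          # quantization step size
    return paddle.round(x / scale) * scale          # quantize, then dequantize


# x = paddle.rand([4, 10])
# x_qdq = fake_quant_dequant(x)  # same shape as x, with simulated int8 error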
@@ -63,6 +63,7 @@ class EMDObserverLayer(UniformObserver):
abs_max_value = float(paddle.max(paddle.flatten(inputs)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3
scale_emd = abs_max_value
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
@@ -78,8 +79,8 @@ class EMDObserverLayer(UniformObserver):
emd_loss = float(emd_loss)
if emd_loss <= self._calibration_loss:
self._calibration_loss = emd_loss
return 0, scale
scale_emd = scale
return 0, scale_emd
def cal_thresholds(self):
""" Compute thresholds for MAX function.
@@ -64,6 +64,7 @@ class MSEObserverLayer(UniformObserver):
abs_max_value = float(paddle.max(paddle.abs(inputs.flatten())))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3
scale_mse = abs_max_value
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
@@ -75,8 +76,8 @@ class MSEObserverLayer(UniformObserver):
mse_loss = float(((inputs - quant_dequant_var)**2).mean())
if mse_loss <= self.calibration_loss:
self.calibration_loss = mse_loss
return 0, scale
scale_mse = scale
return 0, scale_mse
def cal_thresholds(self):
""" Compute thresholds for MAX function.
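# --- Sketch of the search-loop change above (EMD/MSE observers) ---
# Previously the return used `scale` directly; the scale_emd / scale_mse
# variables added in this commit track the best scale found over the whole
# 0.3..1.0 sweep and return it after the loop. A self-contained illustration
# of the same pattern (loss_fn is a hypothetical stand-in for the EMD / MSE
# loss, not the actual observer code):
def search_best_scale(abs_max_value, loss_fn):
    best_loss = float("inf")
    best_scale = abs_max_value      # fall back to abs-max if nothing improves
    s = 0.3
    while s <= 1.0:
        scale = s * abs_max_value
        s += 0.02
        loss = loss_fn(scale)
        if loss <= best_loss:
            best_loss = loss
            best_scale = scale      # keep searching instead of returning early
    return best_scale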
@@ -89,7 +89,7 @@ class UniformObserver(BaseObserver):
_max = max(self.max_value(), 0.)
if self._symmetric:
self._scale = max(-_min, _max) / (float(_qmax - _qmin) / 2)
self._scale = max(-_min, _max)
if self._sign:
self._zero_point = 0
else:
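# --- Numeric illustration of the UniformObserver change above ---
# Reading the hunk above, the stored symmetric scale appears to change from
# the per-level step size max(-_min, _max) / ((_qmax - _qmin) / 2) to the
# clipping threshold max(-_min, _max) itself; downstream quantizers must agree
# on which convention the observer reports. A hedged example with made-up
# values for a signed 8-bit range (_qmin = -128, _qmax = 127):
_min, _max = -0.9, 1.2
_qmin, _qmax = -128, 127
step_size_scale = max(-_min, _max) / (float(_qmax - _qmin) / 2)   # ~0.0094
threshold_scale = max(-_min, _max)                                # 1.2
print(step_size_scale, threshold_scale)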
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import unittest
import tempfile
import random
import numpy as np
sys.path.append("../../")
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
import paddle.fluid as fluid
import paddle.nn as nn
from paddle.distributed.utils.launch_utils import find_free_ports, get_cluster
from paddle.quantization import QuantConfig
from paddle.quantization import QAT
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer
from paddle.nn.quant.format import LinearDequanter, LinearQuanter
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from paddleslim.quant.layers import QuantizedColumnParallelLinear, QuantizedRowParallelLinear
import logging
from paddleslim.common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + rank_id)
vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
def get_attr(layer, name):
if getattr(layer, name, None) is not None:
return getattr(layer, name, None)
else:
return get_attr(layer._layer, name)
def get_gpus(selected_gpus):
selected_gpus = [x.strip() for x in selected_gpus.split(',')]
return selected_gpus
def get_cluster_from_args(selected_gpus):
cluster_node_ips = '127.0.0.1'
node_ip = '127.0.0.1'
node_ips = [x.strip() for x in cluster_node_ips.split(',')]
node_ips.index(node_ip)
free_ports = None
free_ports = find_free_ports(len(selected_gpus))
if free_ports is not None:
free_ports = list(free_ports)
trainer_endpoints = []
for ip in node_ips:
trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports])
return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)
def parallel_matmul(lm_output, logit_weights, parallel_output):
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
world_size = hcg.get_model_parallel_world_size()
rank = hcg.get_model_parallel_rank()
if world_size > 1:
input_parallel = paddle.distributed.collective._c_identity(
lm_output, group=model_parallel_group)
logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True)
if parallel_output:
return logits
return paddle.distributed.collective._c_concat(
logits, group=model_parallel_group)
else:
logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
return logits
class SimpleMPNet(nn.Layer):
def __init__(
self,
vocab_size,
hidden_size,
inner_size,
output_size,
np_fc1,
np_fc2,
mp_id, ):
super().__init__()
if mp_id == 0:
init_fc1_data = np_fc1[:, :(inner_size // 2)]
init_fc2_data = np_fc2[:(inner_size // 2), :]
else:
init_fc1_data = np_fc1[:, (inner_size // 2):]
init_fc2_data = np_fc2[(inner_size // 2):, :]
self.linear1 = ColumnParallelLinear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc1_data)),
gather_output=False,
has_bias=True, )
self.linear2 = RowParallelLinear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc2_data)),
input_is_parallel=True,
has_bias=True, )
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)), )
self.embedding = VocabParallelEmbedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=1.0), )
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = parallel_matmul(x, self.embedding.weight, False)
return x
class SimpleDPNet(nn.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1,
np_fc2):
super().__init__()
self.linear1 = paddle.nn.Linear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc1)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)), )
self.linear2 = paddle.nn.Linear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc2)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)), )
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)), )
self.embedding = paddle.nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=1.0), )
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
return x
class TestDistMPTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
self.data_parallel_size = 1
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
self.onnx_format = False
self.check_export_model_accuracy = True
self.diff_threshold = 0.01
self.fuse_conv_bn = False
def train_batch(self, batch, model, optimizer, is_mp):
output = model(batch)
loss = output.mean()
loss.backward() # do backward
optimizer.step() # update parameters
optimizer.clear_grad()
return loss
def build_optimizer(self, model):
optimizer = paddle.optimizer.SGD(
learning_rate=0.001, parameters=model.parameters())
return optimizer
def build_model_optimizer(self, qat):
hcg = fleet.get_hybrid_communicate_group()
        world_size = hcg.get_model_parallel_world_size()
mp_id = hcg.get_model_parallel_rank()
dp_id = hcg.get_data_parallel_rank()
rank_id = dist.get_rank()
set_random_seed(1024, dp_id, rank_id)
np_fc1 = np.ones((hidden_size, inner_size))
np_fc2 = np.ones((inner_size, hidden_size))
model_a = SimpleMPNet(
vocab_size,
hidden_size,
inner_size,
output_size,
np_fc1,
np_fc2,
mp_id, )
model_a = qat.quantize(model_a, inplace=True)
optimizer_a = self.build_optimizer(model_a)
model_a = fleet.distributed_model(model_a)
optimizer_a = fleet.distributed_optimizer(optimizer_a)
model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size,
np_fc1, np_fc2)
model_b = qat.quantize(model_b, inplace=True)
optimizer_b = self.build_optimizer(model_b)
return model_a, optimizer_a, model_b, optimizer_b
def train(self, model_a, optimizer_a, model_b, optimizer_b):
for epoch in range(5):
np_data = np.random.randint(
0,
vocab_size,
(batch_size, seq_length, ), )
batch = paddle.to_tensor(np_data, dtype='int32')
loss_a = self.train_batch(batch, model_a, optimizer_a, True)
loss_b = self.train_batch(batch, model_b, optimizer_b, False)
np.testing.assert_allclose(
loss_a.numpy(), loss_b.numpy(), rtol=1e-6)
def test_mp_model_1(self):
if (not fluid.core.is_compiled_with_cuda() or
fluid.core.get_cuda_device_count() == 0):
return
selected_gpus = get_gpus('0,1')
cluster = None
pod = None
observer = FakeQuanterWithAbsMaxObserver()
q_config = QuantConfig(activation=None, weight=None)
q_config.add_type_config(
ColumnParallelLinear, activation=observer, weight=observer)
q_config.add_type_config(
RowParallelLinear, activation=observer, weight=observer)
q_config.add_type_config(
nn.Linear, activation=observer, weight=observer)
q_config.add_qat_layer_mapping(ColumnParallelLinear,
QuantizedColumnParallelLinear)
q_config.add_qat_layer_mapping(RowParallelLinear,
QuantizedRowParallelLinear)
qat = QAT(q_config)
model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
qat)
self.train(model_a, optimizer_a, model_b, optimizer_b)
def test_mp_model_2(self):
if (not fluid.core.is_compiled_with_cuda() or
fluid.core.get_cuda_device_count() == 0):
return
selected_gpus = get_gpus('0,1')
cluster = None
pod = None
observer = FakeQuanterWithAbsMaxObserver()
q_config = QuantConfig(activation=None, weight=None)
q_config.add_type_config(
ColumnParallelLinear, activation=observer, weight=observer)
q_config.add_type_config(
RowParallelLinear, activation=observer, weight=observer)
q_config.add_type_config(
nn.Linear, activation=observer, weight=observer)
q_config.add_qat_layer_mapping(ColumnParallelLinear,
QuantizedColumnParallelLinear)
q_config.add_qat_layer_mapping(RowParallelLinear,
QuantizedRowParallelLinear)
qat = QAT(q_config)
model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
qat)
self.train(model_a, optimizer_a, model_b, optimizer_b)
if __name__ == "__main__":
unittest.main()
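# --- Note on running the model-parallel test above (hedged) ---
# The MP cases assume a 2-GPU machine and a fleet hybrid setup with
# mp_degree=2; on a host without CUDA they return early. A typical launch,
# assuming this file is saved as test_quant_parallel_linear.py (hypothetical
# file name), would be something like:
#
#   python -m paddle.distributed.launch --gpus "0,1" test_quant_parallel_linear.py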
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../../")
import os
import unittest
import paddle
import tempfile
import numpy as np
from paddle.vision.models import resnet18
from paddle.quantization import QuantConfig
from paddle.quantization import PTQ
from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver
from paddleslim.quant.observers.hist import PercentHistObserverLayer
from paddleslim.quant.observers.kl import KLObserverLayer
from paddleslim.quant.observers.mse import MSEObserverLayer
from paddleslim.quant.observers.avg import AVGObserverLayer
from paddleslim.quant.observers.emd import EMDObserverLayer
from paddle.nn.quant.format import LinearDequanter, LinearQuanter
import logging
from paddleslim.common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'):
super(ImperativeLenet, self).__init__()
self.features = paddle.nn.Sequential(
paddle.nn.Conv2D(
in_channels=1,
out_channels=6,
kernel_size=3,
stride=1,
padding=1),
paddle.nn.AvgPool2D(kernel_size=2, stride=2),
paddle.nn.Conv2D(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0), paddle.nn.AvgPool2D(kernel_size=2, stride=2))
self.fc = paddle.nn.Sequential(
paddle.nn.Linear(in_features=400, out_features=120),
paddle.nn.Linear(in_features=120, out_features=84),
paddle.nn.Linear(in_features=84, out_features=num_classes), )
def forward(self, inputs):
x = self.features(inputs)
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
class TestPTQObserverAcc(unittest.TestCase):
    def __init__(self, observer, observer_type, *args, **kwargs):
        super(TestPTQObserverAcc, self).__init__(*args, **kwargs)
self.observer = observer
self.observer_type = observer_type
def setUp(self):
paddle.set_device("cpu")
self.init_case()
self.dummy_input = paddle.rand([1, 3, 224, 224])
self.temp_dir = tempfile.TemporaryDirectory(dir="./")
self.path = os.path.join(self.temp_dir.name, 'qat')
if not os.path.exists('ILSVRC2012_data_demo'):
os.system(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
seed = 1
np.random.seed(seed)
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
def tearDown(self):
self.temp_dir.cleanup()
def runTest(self):
self.test_convergence()
def init_case(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_type_config(
paddle.nn.Conv2D, activation=self.observer, weight=self.observer)
def _count_layers(self, model, layer_type):
count = 0
for _layer in model.sublayers(True):
if isinstance(_layer, layer_type):
count += 1
return count
def test_convergence(self):
model = ImperativeLenet()
place = paddle.CUDAPlace(0) \
if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
transform = paddle.vision.transforms.Compose([
paddle.vision.transforms.Transpose(),
paddle.vision.transforms.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
train_reader = paddle.io.DataLoader(
train_dataset,
drop_last=True,
places=place,
batch_size=64,
return_list=True)
test_reader = paddle.io.DataLoader(
val_dataset, places=place, batch_size=64, return_list=True)
def train(model):
adam = paddle.optimizer.Adam(
learning_rate=0.0001, parameters=model.parameters())
epoch_num = 1
for epoch in range(epoch_num):
model.train()
for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.loss.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
model.clear_gradients()
if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".
format(epoch, batch_id,
avg_loss.numpy(), acc.numpy()))
def test(model):
model.eval()
avg_acc = [[], []]
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.to_tensor(data[1])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
avg_acc[0].append(acc_top1.numpy())
avg_acc[1].append(acc_top5.numpy())
if batch_id % 100 == 0:
_logger.info(
"Test | step {}: acc1 = {:}, acc5 = {:}".format(
batch_id, acc_top1.numpy(), acc_top5.numpy()))
_logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format(
np.mean(avg_acc[0]), np.mean(avg_acc[1])))
return np.mean(avg_acc[0]), np.mean(avg_acc[1])
def ptq_sample(model):
model.eval()
avg_acc = [[], []]
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.to_tensor(data[1])
label = paddle.reshape(label, [-1, 1])
out = model(img)
if batch_id % 100 == 0:
_logger.info("PTQ sampling | step {}".format(batch_id))
train(model)
top1_1, top5_1 = test(model)
ptq = PTQ(self.q_config)
model.eval()
quant_model = ptq.quantize(model, inplace=False)
ptq_sample(quant_model)
converted_model = ptq.convert(quant_model, inplace=False)
top1_2, top5_2 = test(converted_model)
_logger.info(
"Before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
_logger.info(
"After quantization: top1: {}, top5: {}".format(top1_2, top5_2))
_logger.info("\n")
diff = 0.01
self.assertTrue(
top1_1 - top1_2 < diff,
msg="The acc of quant model is too lower than fp32 model")
_logger.info('done')
return
observer_suite = unittest.TestSuite()
observer_suite.addTest(
TestPTQObserverAcc(
observer=HistObserver(sign=True, symmetric=True),
observer_type=PercentHistObserverLayer))
observer_suite.addTest(
TestPTQObserverAcc(
observer=KLObserver(bins_count=256), observer_type=KLObserverLayer))
observer_suite.addTest(
TestPTQObserverAcc(observer=AVGObserver(), observer_type=AVGObserverLayer))
observer_suite.addTest(
TestPTQObserverAcc(observer=EMDObserver(), observer_type=EMDObserverLayer))
observer_suite.addTest(
TestPTQObserverAcc(observer=MSEObserver(), observer_type=MSEObserverLayer))
if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2)
runner.run(observer_suite)
os.system('rm -rf ILSVRC2012_data_demo.tar.gz')
os.system('rm -rf ILSVRC2012_data_demo')