Unverified commit aaf021c9, authored by Leo Chen and committed by GitHub

Integrate QAT into distributed optimizer (#54241)

* Support AMP program for onnx QAT API

* Integrate QAT into distributed optimizer

* Reduce the size of test data and increase time limit

* Use logger and reduce time limit of unittests

* Rename and move unittest into fleet test

* Test qat_init API
Parent acf4a2ae
@@ -984,6 +984,59 @@ class DistributedStrategy:
else:
    logger.warning("asp should have value of bool type")
@property
def qat(self):
"""
Indicating whether we are using quantization aware training
Default Value: False
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.qat = True # by default this is false
"""
return self.strategy.qat
@qat.setter
@is_strict_auto
def qat(self, flag):
assert isinstance(flag, bool), "qat should have value of bool type"
self.strategy.qat = flag
@property
def qat_configs(self):
"""
Set quantization-aware training configurations. In general, qat has several
settings that can be configured through a dict.
**Notes**:
channel_wise_abs_max(bool): Whether to use `per_channel` quantization training. Default is True.
weight_bits(int): quantization bit number for weight. Default is 8.
activation_bits(int): quantization bit number for activation. Default is 8.
not_quant_pattern(list[str]): When the skip pattern is detected in an op's name scope,
the corresponding op will not be quantized.
algo(str): Other quantization training algorithm.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.qat = True
strategy.qat_configs = {
"channel_wise_abs_max": True,
"weight_bits": 8,
"activation_bits: 8,
"not_quant_pattern": ['skip_quant']}
"""
return get_msg_dict(self.strategy.qat_configs)
@qat_configs.setter
def qat_configs(self, configs):
check_configs_key(self.strategy.qat_configs, configs, "qat_configs")
assign_configs_value(self.strategy.qat_configs, configs)
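For context, a minimal sketch (not part of this diff) of how the two strategy fields above are used together in a collective static-graph job; the optimizer and loss below are illustrative placeholders:

```python
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.qat = True
strategy.qat_configs = {
    "channel_wise_abs_max": True,
    "weight_bits": 8,
    "activation_bits": 8,
    "not_quant_pattern": ["skip_quant"],
}

# `inner_opt` is the user-defined optimizer and `loss` the loss variable of a
# static-graph program (placeholders here); the QAT meta optimizer rewrites
# the program during minimize().
# optimizer = fleet.distributed_optimizer(inner_opt, strategy=strategy)
# optimizer.minimize(loss)
```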
@property
def recompute(self):
    """
......
@@ -1170,6 +1170,38 @@ class Fleet:
amp_optimizer = self._get_amp_optimizer()
return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)
def _get_qat_optimizer(self):
# imitate target optimizer retrieval
qat_optimizer = None
for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
if hasattr(optimizer, 'qat_init'):
qat_optimizer = optimizer
break
if qat_optimizer is None:
if hasattr(self.user_defined_optimizer, 'qat_init'):
qat_optimizer = self.user_defined_optimizer
assert (
qat_optimizer is not None
), "qat_init can only be used when the qat(quantization aware training) strategy is turned on."
return qat_optimizer
def qat_init(self, place, scope=None, test_program=None):
"""
Init the QAT training, e.g. insert quant-dequant (QDQ) ops and scale variables.
Args:
    place(CUDAPlace): The place used to initialize scale parameters.
    scope(Scope): The scope used to find parameters and variables.
    test_program(Program): The program used for testing.
"""
qat_optimizer = self._get_qat_optimizer()
return qat_optimizer.qat_init(
place, scope=scope, test_program=test_program
)
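A hedged usage sketch of `qat_init` (mirroring the fleet unit test added below): the evaluation program is cloned from a training program built under the qat strategy and then instrumented with for-test quant ops; `train_prog` and `optimizer` are placeholders from such a setup:

```python
# Sketch only: assumes `train_prog` was built and minimized under a strategy
# with strategy.qat = True, and `optimizer` is the fleet distributed optimizer.
place = (
    paddle.CUDAPlace(0)
    if paddle.is_compiled_with_cuda()
    else paddle.CPUPlace()
)
eval_prog = train_prog.clone(for_test=True)
optimizer.qat_init(
    place, scope=paddle.static.global_scope(), test_program=eval_prog
)
```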
def _final_strategy(self):
    if "valid_strategy" not in self._context:
        print(
......
@@ -14,6 +14,7 @@
from .amp_optimizer import AMPOptimizer
from .asp_optimizer import ASPOptimizer
from .qat_optimizer import QATOptimizer
from .recompute_optimizer import RecomputeOptimizer
from .gradient_merge_optimizer import GradientMergeOptimizer
from .ps_optimizer import ParameterServerOptimizer
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
from paddle.static.quantization.quanter import (
_quant_config_default,
quant_aware,
)
from .meta_optimizer_base import MetaOptimizerBase
class QATOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super().__init__(optimizer)
self.inner_opt = optimizer
# we do not allow a meta optimizer to be the inner optimizer currently
self.meta_optimizers_white_list = [
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
"GraphExecutionOptimizer",
"RecomputeOptimizer",
"GradientMergeOptimizer",
]
self.meta_optimizers_black_list = []
def _set_basic_info(
self, loss, role_maker, user_defined_optimizer, user_defined_strategy
):
super()._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy
)
def _can_apply(self):
if not self.role_maker._is_collective:
return False
if self.user_defined_strategy.qat:
return True
return False
def _disable_strategy(self, dist_strategy):
dist_strategy.qat = False
dist_strategy.qat_configs = {}
def _enable_strategy(self, dist_strategy, context):
dist_strategy.qat = True
dist_strategy.qat_configs = {
'channel_wise_abs_max': True,
'weight_bits': 8,
'activation_bits': 8,
'not_quant_pattern': [],
'algo': "",
}
def _gen_qat_config(self):
# Align the config with the auto_parallel quantization pass
config = self.user_defined_strategy.qat_configs
qat_config = copy.deepcopy(_quant_config_default)
qat_config['quantize_op_types'] = [
'conv2d',
'depthwise_conv2d',
'mul',
'matmul',
'matmul_v2',
]
qat_config['weight_quantize_type'] = (
'channel_wise_abs_max'
if config['channel_wise_abs_max']
else 'abs_max'
)
qat_config['weight_bits'] = config['weight_bits']
qat_config['activation_bits'] = config['activation_bits']
qat_config['not_quant_pattern'] = list(config['not_quant_pattern'])
return qat_config
def _replace_program(self, main_program, refer_program):
main_program._rebuild_from_desc(refer_program.desc)
def minimize_impl(
self, loss, startup_program=None, parameter_list=None, no_grad_set=None
):
optimize_ops, params_grads = self.inner_opt.minimize(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set,
)
device = paddle.device.get_device()
place = paddle.set_device(device)
qat_config = self._gen_qat_config()
qat_program = quant_aware(
loss.block.program, place, config=qat_config, return_program=True
)
self._replace_program(loss.block.program, qat_program)
return optimize_ops, params_grads
def qat_init(self, place, scope=None, test_program=None):
if test_program is not None:
qat_config = self._gen_qat_config()
qat_program = quant_aware(
test_program,
place,
scope=scope,
config=qat_config,
for_test=True,
return_program=True,
)
self._replace_program(test_program, qat_program)
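For reference, a sketch of the quanter config that `_gen_qat_config()` derives when the default values set by `_enable_strategy()` are in effect; only the keys this optimizer overrides are listed, the rest keep their values from `_quant_config_default`:

```python
# Illustrative only: overrides applied on top of _quant_config_default for
# qat_configs = {'channel_wise_abs_max': True, 'weight_bits': 8,
#                'activation_bits': 8, 'not_quant_pattern': [], 'algo': ""}.
derived_overrides = {
    "quantize_op_types": [
        "conv2d",
        "depthwise_conv2d",
        "mul",
        "matmul",
        "matmul_v2",
    ],
    "weight_quantize_type": "channel_wise_abs_max",
    "weight_bits": 8,
    "activation_bits": 8,
    "not_quant_pattern": [],
}
```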
@@ -2544,7 +2544,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
if name in self.processed_vars:
    continue
is_weight = (
-    True if var_node.name() in self.persistable_vars else False
+    True
+    if var_node.name() in self.persistable_vars
+    or var_node.name() in self.persistable_cast_output_vars
+    else False
)
# if var node is weight and weight_preprocess_func is not None,
@@ -2645,7 +2648,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
for var_node in op.inputs:
    if var_node.name() not in op.input_arg_names():
        continue
-    if var_node.name() in self.persistable_vars:
+    if (
+        var_node.name() in self.persistable_vars
+        or var_node.name() in self.persistable_cast_output_vars
+    ):
        has_weight = True
return has_weight
@@ -2748,6 +2754,16 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
]
ops = graph.all_op_nodes()
# Mark the outputs of cast ops whose input is a weight (needed for AMP programs)
self.persistable_cast_output_vars = []
for op in graph.all_op_nodes():
if (
op.name() == "cast"
and op.inputs[0].name() in self.persistable_vars
):
self.persistable_cast_output_vars.append(op.outputs[0].name())
# Do the preprocessing of quantization, such as skipping some ops
# for not being quantized.
for op in ops:
......
@@ -575,6 +575,11 @@ if((WITH_GPU OR WITH_XPU) AND LOCAL_ALL_PLAT)
    test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS
    "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
endif()
if((WITH_GPU OR WITH_XPU) AND LOCAL_ALL_PLAT)
py_test_modules(
test_fleet_qat_meta_optimizer MODULES test_fleet_qat_meta_optimizer ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
endif()
if(WITH_NCCL)
  if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
    if((WITH_GPU) AND LOCAL_ALL_PLAT)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import fluid, nn
from paddle.distributed import fleet
paddle.enable_static()
fleet.init(is_collective=True)
class SimpleNet(nn.Layer):
def __init__(self, input_size, output_size):
super().__init__()
self.linear1 = nn.Linear(input_size, output_size)
self.linear2 = nn.Linear(input_size, output_size)
self.linear3 = nn.Linear(input_size, output_size)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x
class TestFleetWithQAT(unittest.TestCase):
def setUp(self):
self.input_size = 4096
self.output_size = 4096
self.batch_size = 8
def setup_strategy(self, strategy):
strategy.qat = True
def generate_program(self, strategy):
train_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
input_x = paddle.static.data(
name='X',
shape=[self.batch_size, self.input_size],
dtype='float32',
)
input_y = paddle.static.data(
name='Y',
shape=[self.batch_size, self.output_size],
dtype='float32',
)
model = SimpleNet(self.input_size, self.output_size)
mse = paddle.nn.MSELoss()
out = model(input_x)
loss = mse(out, input_y)
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy
)
optimizer.minimize(loss)
return train_prog, startup_prog, input_x, input_y, optimizer
def execute_program(self, train_prog, startup_prog, input_x, input_y):
place = (
fluid.CUDAPlace(0)
if paddle.fluid.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place)
exe.run(startup_prog)
data = (
np.random.randn(self.batch_size, self.input_size),
np.random.randn(self.batch_size, self.output_size),
)
exe.run(train_prog, feed=feeder.feed([data]))
def valid_program(self, train_prog, eval_prog):
ops_type = [op.type for op in train_prog.block(0).ops]
self.assertEqual(
ops_type.count('matmul_v2'), 3
) # SimpleNet has 3 linear layers
self.assertEqual(ops_type.count('quantize_linear'), 6)
# Each of the three linear layers contributes two quantize_linear ops (six in total).
self.assertEqual(
ops_type.count('dequantize_linear'), 6
)  # Each quantize op is followed by a dequantize op (fake quantization), so the counts match.
def test_fleet_with_qat(self):
dist_strategy = paddle.distributed.fleet.DistributedStrategy()
self.setup_strategy(dist_strategy)
(
train_prog,
startup_prog,
input_x,
input_y,
optimizer,
) = self.generate_program(dist_strategy)
place = (
fluid.CUDAPlace(0)
if paddle.fluid.is_compiled_with_cuda()
else fluid.CPUPlace()
)
eval_prog = train_prog.clone(for_test=True)
optimizer.qat_init(
place, scope=paddle.static.global_scope(), test_program=eval_prog
)
self.execute_program(train_prog, startup_prog, input_x, input_y)
self.valid_program(train_prog, eval_prog)
class TestFleetWithAMPQAT(TestFleetWithQAT):
def setup_strategy(self, strategy):
strategy.qat = True
strategy.amp = True
def valid_program(self, train_prog, eval_prog):
ops_type = [op.type for op in train_prog.block(0).ops]
self.assertEqual(
ops_type.count('matmul_v2'), 3
) # SimpleNet has 3 linear layers
self.assertEqual(ops_type.count('quantize_linear'), 6)
# Each of the three linear layers contributes two quantize_linear ops (six in total).
self.assertEqual(
ops_type.count('dequantize_linear'), 6
)  # Each quantize op is followed by a dequantize op (fake quantization), so the counts match.
if __name__ == "__main__":
unittest.main()
@@ -231,6 +231,7 @@ if(WIN32)
  list(REMOVE_ITEM TEST_OPS test_quant_post_quant_aware)
  list(REMOVE_ITEM TEST_OPS test_quant_aware_user_defined)
  list(REMOVE_ITEM TEST_OPS test_quant_aware_config)
list(REMOVE_ITEM TEST_OPS test_quant_amp)
endif()
@@ -488,10 +489,11 @@ if(NOT WIN32)
  set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120)
  set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT
                       120)
-  set_tests_properties(test_quant_aware PROPERTIES TIMEOUT 900)
+  set_tests_properties(test_quant_aware PROPERTIES TIMEOUT 200)
-  set_tests_properties(test_quant_post_quant_aware PROPERTIES TIMEOUT 900)
+  set_tests_properties(test_quant_post_quant_aware PROPERTIES TIMEOUT 200)
-  set_tests_properties(test_quant_aware_user_defined PROPERTIES TIMEOUT 900)
+  set_tests_properties(test_quant_aware_user_defined PROPERTIES TIMEOUT 200)
-  set_tests_properties(test_quant_aware_config PROPERTIES TIMEOUT 900)
+  set_tests_properties(test_quant_aware_config PROPERTIES TIMEOUT 200)
set_tests_properties(test_quant_amp PROPERTIES TIMEOUT 200)
endif()
set_tests_properties(test_graph PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import unittest
import numpy as np
from test_quant_aware import MobileNet
import paddle
from paddle.static.quantization.quanter import convert, quant_aware
logging.basicConfig(level="INFO", format="%(message)s")
class TestQuantAMP(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def generate_config(self):
config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'onnx_format': True,
}
return config
def test_accuracy(self):
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
image = paddle.static.data(
name='image', shape=[None, 1, 28, 28], dtype='float32'
)
label = paddle.static.data(
name='label', shape=[None, 1], dtype='int64'
)
model = MobileNet()
out = model.net(input=image, class_dim=10)
cost = paddle.nn.functional.loss.cross_entropy(
input=out, label=label
)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
weight_decay=paddle.regularizer.L2Decay(4e-5),
)
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
)
optimizer.minimize(avg_cost)
val_prog = main_prog.clone(for_test=True)
place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
def transform(x):
return np.reshape(x, [1, 28, 28])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform
)
test_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform
)
batch_size = 64 if os.environ.get('DATASET') == 'full' else 8
train_loader = paddle.io.DataLoader(
train_dataset,
places=place,
feed_list=[image, label],
drop_last=True,
return_list=False,
batch_size=batch_size,
)
valid_loader = paddle.io.DataLoader(
test_dataset,
places=place,
feed_list=[image, label],
batch_size=batch_size,
return_list=False,
)
def train(program):
iter = 0
stop_iter = None if os.environ.get('DATASET') == 'full' else 10
for data in train_loader():
cost, top1, top5 = exe.run(
program,
feed=data,
fetch_list=[avg_cost, acc_top1, acc_top5],
)
iter += 1
if iter % 100 == 0:
logging.info(
'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format(
iter, cost, top1, top5
)
)
if stop_iter is not None and iter == stop_iter:
break
def test(program):
iter = 0
stop_iter = None if os.environ.get('DATASET') == 'full' else 10
result = [[], [], []]
for data in valid_loader():
cost, top1, top5 = exe.run(
program,
feed=data,
fetch_list=[avg_cost, acc_top1, acc_top5],
)
iter += 1
if iter % 100 == 0:
logging.info(
'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format(
iter, cost, top1, top5
)
)
result[0].append(cost)
result[1].append(top1)
result[2].append(top5)
if stop_iter is not None and iter == stop_iter:
break
logging.info(
' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
np.mean(result[0]), np.mean(result[1]), np.mean(result[2])
)
)
return np.mean(result[1]), np.mean(result[2])
train(main_prog)
top1_1, top5_1 = test(main_prog)
config = self.generate_config()
quant_train_prog = quant_aware(
main_prog, place, config, for_test=False, return_program=True
)
quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
train(quant_train_prog)
convert_eval_prog = convert(quant_eval_prog, place, config)
top1_2, top5_2 = test(convert_eval_prog)
# values before quantization and after quantization should be close
logging.info(f"before quantization: top1: {top1_1}, top5: {top5_1}")
logging.info(f"after quantization: top1: {top1_2}, top5: {top5_2}")
if __name__ == '__main__':
unittest.main()
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import unittest
@@ -21,6 +22,8 @@ import paddle
from paddle.nn.initializer import KaimingUniform
from paddle.static.quantization.quanter import convert, quant_aware
logging.basicConfig(level="INFO", format="%(message)s")
train_parameters = {
    "input_size": [3, 224, 224],
    "input_mean": [0.485, 0.456, 0.406],
@@ -299,7 +302,7 @@ class TestQuantAwareCase(StaticCase):
    )
    iter += 1
    if iter % 100 == 0:
-        print(
+        logging.info(
            'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format(
                iter, cost, top1, top5
            )
@@ -319,7 +322,7 @@ class TestQuantAwareCase(StaticCase):
    )
    iter += 1
    if iter % 100 == 0:
-        print(
+        logging.info(
            'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format(
                iter, cost, top1, top5
            )
@@ -329,7 +332,7 @@ class TestQuantAwareCase(StaticCase):
    result[2].append(top5)
    if stop_iter is not None and iter == stop_iter:
        break
-    print(
+    logging.info(
        ' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
            np.mean(result[0]), np.mean(result[1]), np.mean(result[2])
        )
@@ -355,8 +358,8 @@ class TestQuantAwareCase(StaticCase):
top1_2, top5_2 = test(convert_eval_prog)
# values before quantization and after quantization should be close
-print(f"before quantization: top1: {top1_1}, top5: {top5_1}")
+logging.info(f"before quantization: top1: {top1_1}, top5: {top5_1}")
-print(f"after quantization: top1: {top1_2}, top5: {top5_2}")
+logging.info(f"after quantization: top1: {top1_2}, top5: {top5_2}")
convert_op_nums_1, convert_quant_op_nums_1 = self.get_convert_op_number(
    convert_eval_prog
......
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import unittest
@@ -20,6 +22,8 @@ from test_quant_aware import MobileNet
import paddle
from paddle.static.quantization.quanter import convert, quant_aware
logging.basicConfig(level="INFO", format="%(message)s")
class TestQuantAwareBase(unittest.TestCase):
    def setUp(self):
@@ -107,7 +111,7 @@ class TestQuantAwareBase(unittest.TestCase):
    )
    iter += 1
    if iter % 100 == 0:
-        print(
+        logging.info(
            'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format(
                iter, cost, top1, top5
            )
@@ -127,7 +131,7 @@ class TestQuantAwareBase(unittest.TestCase):
    )
    iter += 1
    if iter % 100 == 0:
-        print(
+        logging.info(
            'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format(
                iter, cost, top1, top5
            )
@@ -137,7 +141,7 @@ class TestQuantAwareBase(unittest.TestCase):
    result[2].append(top5)
    if stop_iter is not None and iter == stop_iter:
        break
-    print(
+    logging.info(
        ' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
            np.mean(result[0]), np.mean(result[1]), np.mean(result[2])
        )
@@ -164,8 +168,8 @@ class TestQuantAwareBase(unittest.TestCase):
top1_2, top5_2 = test(convert_eval_prog)
# values before quantization and after quantization should be close
-print(f"before quantization: top1: {top1_1}, top5: {top5_1}")
+logging.info(f"before quantization: top1: {top1_1}, top5: {top5_1}")
-print(f"after quantization: top1: {top1_2}, top5: {top5_2}")
+logging.info(f"after quantization: top1: {top1_2}, top5: {top5_2}")
class TestQuantAwareNone(TestQuantAwareBase):
......
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import unittest
@@ -20,6 +21,8 @@ from test_quant_aware import MobileNet, StaticCase
import paddle
from paddle.static.quantization.quanter import convert, quant_aware
logging.basicConfig(level="INFO", format="%(message)s")
def pact(x):
    helper = paddle.fluid.layer_helper.LayerHelper("pact", **locals())
@@ -123,7 +126,7 @@ class TestQuantAwareCase1(StaticCase):
    )
    iter += 1
    if iter % 100 == 0:
-        print(
+        logging.info(
            'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format(
                iter, cost, top1, top5
            )
@@ -143,7 +146,7 @@ class TestQuantAwareCase1(StaticCase):
    )
    iter += 1
    if iter % 100 == 0:
-        print(
+        logging.info(
            'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format(
                iter, cost, top1, top5
            )
@@ -153,7 +156,7 @@ class TestQuantAwareCase1(StaticCase):
    result[2].append(top5)
    if stop_iter is not None and iter == stop_iter:
        break
-    print(
+    logging.info(
        ' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
            np.mean(result[0]), np.mean(result[1]), np.mean(result[2])
        )
@@ -184,8 +187,8 @@ class TestQuantAwareCase1(StaticCase):
quant_eval_prog = convert(quant_eval_prog, place, config)
top1_2, top5_2 = test(quant_eval_prog)
# values before quantization and after quantization should be close
-print(f"before quantization: top1: {top1_1}, top5: {top5_1}")
+logging.info(f"before quantization: top1: {top1_1}, top5: {top5_1}")
-print(f"after quantization: top1: {top1_2}, top5: {top5_2}")
+logging.info(f"after quantization: top1: {top1_2}, top5: {top5_2}")
if __name__ == '__main__':
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import unittest
@@ -24,6 +25,7 @@ from paddle.static.quantization.quanter import convert, quant_aware
np.random.seed(0)
random.seed(0)
paddle.seed(0)
logging.basicConfig(level="INFO", format="%(message)s")
class RandomDataset(paddle.io.Dataset):
@@ -106,7 +108,7 @@ class TestQuantPostQuantAwareCase1(StaticCase):
    )
    iter += 1
    if iter % 100 == 0:
-        print(
+        logging.info(
            'train iter={}, avg loss {}, acc_top1 {}'.format(
                iter, cost, top1
            )
@@ -121,14 +123,14 @@ class TestQuantPostQuantAwareCase1(StaticCase):
    )
    iter += 1
    if iter % 100 == 0:
-        print(
+        logging.info(
            'eval iter={}, avg loss {}, acc_top1 {}'.format(
                iter, cost, top1
            )
        )
    result[0].append(cost)
    result[1].append(top1)
-print(
+logging.info(
    ' avg loss {}, acc_top1 {}'.format(
        np.mean(result[0]), np.mean(result[1])
    )
@@ -180,8 +182,8 @@ class TestQuantPostQuantAwareCase1(StaticCase):
quant_eval_prog = convert(quant_eval_prog, place, config)
top1_2 = test(quant_eval_prog)
# values before quantization and after quantization should be close
-print(f"before quantization: top1: {top1_1}")
+logging.info(f"before quantization: top1: {top1_1}")
-print(f"after quantization: top1: {top1_2}")
+logging.info(f"after quantization: top1: {top1_2}")
if __name__ == '__main__':
......