diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index f7acb185add5d40b749e7442111891869dfaeb22..a44a8b5809c7b082e39edd1dbde1ecbe665abd54 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -161,7 +161,13 @@ def main(config, device, logger, vdl_writer): if config["Global"]["pretrained_model"] is not None: pre_best_model_dict = load_model(config, model) - quanter = QAT(config=quant_config, act_preprocess=PACT) + freeze_params = False + if config['Architecture']["algorithm"] in ["Distillation"]: + for key in config['Architecture']["Models"]: + freeze_params = freeze_params or config['Architecture']['Models'][ + key].get('freeze_params', False) + act = None if freeze_params else 'PACT' + quanter = QAT(config=quant_config, act_preprocess=act) quanter.quantize(model) if config['Global']['distributed']: