diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index 7671e5f871ce6769fc51876d1fa2e5f0af63d904..315e3b4321a544e77795c43d493873fcf46e1930 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -112,10 +112,6 @@ def main(config, device, logger, vdl_writer): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - # prepare to quant - quanter = QAT(config=quant_config, act_preprocess=PACT) - quanter.quantize(model) - if config['Global']['distributed']: model = paddle.DataParallel(model) @@ -136,31 +132,15 @@ def main(config, device, logger, vdl_writer): logger.info('train dataloader has {} iters, valid dataloader has {} iters'. format(len(train_dataloader), len(valid_dataloader))) + quanter = QAT(config=quant_config, act_preprocess=PACT) + quanter.quantize(model) + # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer) -def test_reader(config, device, logger): - loader = build_dataloader(config, 'Train', device, logger) - import time - starttime = time.time() - count = 0 - try: - for data in loader(): - count += 1 - if count % 1 == 0: - batch_time = time.time() - starttime - starttime = time.time() - logger.info("reader: {}, {}, {}".format( - count, len(data[0]), batch_time)) - except Exception as e: - logger.info(e) - logger.info("finish reader: {}, Success!".format(count)) - - if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess(is_train=True) main(config, device, logger, vdl_writer) - # test_reader(config, device, logger)