From c4ee247a3456b118fdcdd42fc0788ddd703c536d Mon Sep 17 00:00:00 2001 From: zhouzj <41366441+zzjjay@users.noreply.github.com> Date: Mon, 8 May 2023 17:29:54 +0800 Subject: [PATCH] Fix bug of loading saved model and tests. (#1744) * fix bug of loading saved model. * fix distill demo. * fix flops.py. * fix tests. --- demo/distillation/distill.py | 9 ++++----- paddleslim/analysis/flops.py | 9 +++++++-- paddleslim/auto_compression/compressor.py | 13 ++++++++++--- tests/dygraph/test_flops.py | 5 ++--- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/demo/distillation/distill.py b/demo/distillation/distill.py index 43d4a43f..2adeae58 100644 --- a/demo/distillation/distill.py +++ b/demo/distillation/distill.py @@ -97,8 +97,8 @@ def compress(args): raise ValueError("{} is not supported.".format(args.data)) image_shape = [int(m) for m in image_shape.split(",")] - assert args.model in model_list, "{} is not in lists: {}".format(args.model, - model_list) + assert args.model in model_list, "{} is not in lists: {}".format( + args.model, model_list) student_program = paddle.static.Program() s_startup = paddle.static.Program() places = paddle.static.cuda_places( @@ -202,7 +202,7 @@ def compress(args): _logger.info( "train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}". format(epoch_id, step_id, - lr.get_lr(), loss_1[0], loss_2[0], loss_3[0])) + lr.get_lr(), loss_1, loss_2, loss_3)) lr.step() val_acc1s = [] val_acc5s = [] @@ -216,8 +216,7 @@ def compress(args): if step_id % args.log_period == 0: _logger.info( "valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}". - format(epoch_id, step_id, val_loss[0], val_acc1[0], - val_acc5[0])) + format(epoch_id, step_id, val_loss, val_acc1, val_acc5)) if args.save_inference: paddle.static.save_inference_model( os.path.join("./saved_models", str(epoch_id)), [image], [out], diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py index 04756313..2554cc15 100644 --- a/paddleslim/analysis/flops.py +++ b/paddleslim/analysis/flops.py @@ -84,7 +84,8 @@ def _graph_flops(graph, only_conv=True, detail=False): output_shape = op.outputs("Out")[0].shape() _, c_out, h_out, w_out = output_shape k_size = op.attr("ksize") - flops += h_out * w_out * c_out * (k_size[0]**2) + if op.attr('pooling_type') == 'avg': + flops += (h_out * w_out * c_out * (k_size[0]**2) * 2) elif op.type() in ['mul', 'matmul', 'matmul_v2']: x_shape = list(op.inputs("X")[0].shape()) @@ -101,7 +102,11 @@ def _graph_flops(graph, only_conv=True, detail=False): input_shape = list(op.inputs("X")[0].shape()) if input_shape[0] == -1: input_shape[0] = 1 - flops += np.product(input_shape) + if op.type() == 'batch_norm': + op_flops = np.product(input_shape) * 2 + else: + op_flops = np.product(input_shape) + flops += op_flops if detail: return flops, params2flops diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index b54d3b1d..ed0a53c8 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -570,7 +570,7 @@ class AutoCompression: tmp_base_name = "_".join([prefix, str(os.getppid()), s_datetime]) tmp_dir = os.path.join(base_dir, tmp_base_name) if not os.path.exists(tmp_dir): - os.makedirs(tmp_dir) + os.makedirs(tmp_dir, exist_ok=True) return tmp_dir def compress(self): @@ -609,10 +609,17 @@ class AutoCompression: shutil.rmtree(self.tmp_dir) if self.eval_function is not None and self.final_metric < 0.0: + model_filename = None + if self.model_filename is None: + model_filename = "model.pdmodel" + elif self.model_filename.endswith(".pdmodel"): + model_filename = self.model_filename + else: + model_filename = self.model_filename + '.pdmodel' + [inference_program, feed_target_names, fetch_targets]= load_inference_model( \ final_model_path, \ - model_filename=self.model_filename, params_filename=self.params_filename, - executor=self._exe) + model_filename=model_filename, executor=self._exe) self.final_metric = self.eval_function( self._exe, inference_program, feed_target_names, fetch_targets) diff --git a/tests/dygraph/test_flops.py b/tests/dygraph/test_flops.py index 7fe1da8d..a870b476 100644 --- a/tests/dygraph/test_flops.py +++ b/tests/dygraph/test_flops.py @@ -73,9 +73,8 @@ class TestFLOPsCase2(unittest.TestCase): def add_cases(suite): suite.addTest( - TestFlops( - net=paddle.vision.models.mobilenet_v1, gt=11792896.0)) - suite.addTest(TestFlops(net=paddle.vision.models.resnet50, gt=83872768.0)) + TestFlops(net=paddle.vision.models.mobilenet_v1, gt=12920832.0)) + suite.addTest(TestFlops(net=paddle.vision.models.resnet50, gt=86112768.0)) suite.addTest(TestFLOPsCase1()) suite.addTest(TestFLOPsCase2()) -- GitLab