From bcb13c271c96cdc58916911922bb29700d163802 Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Sun, 29 Sep 2019 13:17:27 +0800 Subject: [PATCH] Fix slim resnet prefix (#3446) * fix resnet prefix * fix data path --- PaddleSlim/classification/distillation/compress.py | 7 ++----- PaddleSlim/classification/models/resnet.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/PaddleSlim/classification/distillation/compress.py b/PaddleSlim/classification/distillation/compress.py index f8a721fb..8c6ac9ae 100644 --- a/PaddleSlim/classification/distillation/compress.py +++ b/PaddleSlim/classification/distillation/compress.py @@ -106,15 +106,12 @@ def compress(args): fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist) - val_reader = paddle.batch( - reader.val(data_dir='../data/ILSVRC2012'), batch_size=args.batch_size) + val_reader = paddle.batch(reader.val(), batch_size=args.batch_size) val_feed_list = [('image', image.name), ('label', label.name)] val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', acc_top5.name)] train_reader = paddle.batch( - reader.train(data_dir='../data/ILSVRC2012'), - batch_size=args.batch_size, - drop_last=True) + reader.train(), batch_size=args.batch_size, drop_last=True) train_feed_list = [('image', image.name), ('label', label.name)] train_fetch_list = [('loss', avg_cost.name)] diff --git a/PaddleSlim/classification/models/resnet.py b/PaddleSlim/classification/models/resnet.py index b40c7bf4..0adf4276 100644 --- a/PaddleSlim/classification/models/resnet.py +++ b/PaddleSlim/classification/models/resnet.py @@ -29,7 +29,7 @@ class ResNet(): def net(self, input, class_dim=1000, conv1_name='conv1', fc_name=None): layers = self.layers - prefix_name = self.prefix_name + '_' + prefix_name = self.prefix_name if self.prefix_name is '' else self.prefix_name + '_' supported_layers = [34, 50, 101, 152] assert layers in supported_layers, \ "supported layers are {} but input layer is {}".format(supported_layers, layers) -- GitLab