From b151afe31f900f9ad95f5ae54f08b9b3f652119f Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Wed, 14 Jul 2021 10:30:24 +0800 Subject: [PATCH] fix cifar10 dataloader and model loading, test=develop (#825) (#844) --- demo/dygraph/unstructured_pruning/README.md | 4 +++- demo/dygraph/unstructured_pruning/evaluate.py | 2 -- demo/dygraph/unstructured_pruning/train.py | 4 ---- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/demo/dygraph/unstructured_pruning/README.md b/demo/dygraph/unstructured_pruning/README.md index 3a861dff..a1068542 100644 --- a/demo/dygraph/unstructured_pruning/README.md +++ b/demo/dygraph/unstructured_pruning/README.md @@ -89,9 +89,11 @@ python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshol ## 推理: ```bash -python3.7 evalualte.py --pruned_model models/ --data imagenet +python3.7 evaluate.py --pruned_model models/model-pruned.pdparams --data imagenet ``` +**注意**,上述`pruned_model` 参数应该指向pdparams文件。 + 剪裁训练代码示例: ```python model = mobilenet_v1(num_classes=class_dim, pretrained=True) diff --git a/demo/dygraph/unstructured_pruning/evaluate.py b/demo/dygraph/unstructured_pruning/evaluate.py index d2c6aa56..e8827d1f 100644 --- a/demo/dygraph/unstructured_pruning/evaluate.py +++ b/demo/dygraph/unstructured_pruning/evaluate.py @@ -67,8 +67,6 @@ def compress(args): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) - if args.data == 'cifar10': - y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) diff --git a/demo/dygraph/unstructured_pruning/train.py b/demo/dygraph/unstructured_pruning/train.py index 39b6804b..5cf71789 100644 --- a/demo/dygraph/unstructured_pruning/train.py +++ b/demo/dygraph/unstructured_pruning/train.py @@ -145,8 +145,6 @@ def compress(args): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) - if args.data == 'cifar10': - y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) @@ -180,8 +178,6 @@ def compress(args): train_reader_cost += time.time() - reader_start x_data = data[0] y_data = paddle.to_tensor(data[1]) - if args.data == 'cifar10': - y_data = paddle.unsqueeze(y_data, 1) train_start = time.time() logits = model(x_data) -- GitLab