diff --git a/demo/dygraph/unstructured_pruning/README.md b/demo/dygraph/unstructured_pruning/README.md index 3a861dffc742c66a5bf4193dcad6d74e30da394f..a1068542587da7f1131a2590380c20475b44f62f 100644 --- a/demo/dygraph/unstructured_pruning/README.md +++ b/demo/dygraph/unstructured_pruning/README.md @@ -89,9 +89,11 @@ python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshol ## 推理: ```bash -python3.7 evalualte.py --pruned_model models/ --data imagenet +python3.7 evaluate.py --pruned_model models/model-pruned.pdparams --data imagenet ``` +**注意**,上述`pruned_model` 参数应该指向pdparams文件。 + 剪裁训练代码示例: ```python model = mobilenet_v1(num_classes=class_dim, pretrained=True) diff --git a/demo/dygraph/unstructured_pruning/evaluate.py b/demo/dygraph/unstructured_pruning/evaluate.py index d2c6aa56c13cc2e0cae84f00a8f71e66b6bff05b..e8827d1f5e7a6ec6ee20ff6134f996966064dd34 100644 --- a/demo/dygraph/unstructured_pruning/evaluate.py +++ b/demo/dygraph/unstructured_pruning/evaluate.py @@ -67,8 +67,6 @@ def compress(args): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) - if args.data == 'cifar10': - y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) diff --git a/demo/dygraph/unstructured_pruning/train.py b/demo/dygraph/unstructured_pruning/train.py index 39b6804bfca609db4506d4df17e96040623b3233..5cf71789220da3e1e3672f9bb06ce8252aa716fc 100644 --- a/demo/dygraph/unstructured_pruning/train.py +++ b/demo/dygraph/unstructured_pruning/train.py @@ -145,8 +145,6 @@ def compress(args): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) - if args.data == 'cifar10': - y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) @@ -180,8 +178,6 @@ def compress(args): train_reader_cost += time.time() - reader_start x_data = data[0] y_data = paddle.to_tensor(data[1]) - if args.data == 'cifar10': - y_data = paddle.unsqueeze(y_data, 1) train_start = time.time() logits = model(x_data)