未验证 提交 da0f5a6d 编写于 作者: M minghaoBD 提交者: GitHub

fix cifar10 dataloader and model loading, test=develop (#825)

* fix cifar10 dataloader and model loading, test=develop

* update readme

* update readme
上级 5438acae
...@@ -89,9 +89,11 @@ python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshol ...@@ -89,9 +89,11 @@ python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshol
## 推理: ## 推理:
```bash ```bash
python3.7 evalualte.py --pruned_model models/ --data imagenet python3.7 evaluate.py --pruned_model models/model-pruned.pdparams --data imagenet
``` ```
**注意**,上述`pruned_model` 参数应该指向pdparams文件。
剪裁训练代码示例: 剪裁训练代码示例:
```python ```python
model = mobilenet_v1(num_classes=class_dim, pretrained=True) model = mobilenet_v1(num_classes=class_dim, pretrained=True)
......
...@@ -67,8 +67,6 @@ def compress(args): ...@@ -67,8 +67,6 @@ def compress(args):
start_time = time.time() start_time = time.time()
x_data = data[0] x_data = data[0]
y_data = paddle.to_tensor(data[1]) y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
logits = model(x_data) logits = model(x_data)
loss = F.cross_entropy(logits, y_data) loss = F.cross_entropy(logits, y_data)
......
...@@ -145,8 +145,6 @@ def compress(args): ...@@ -145,8 +145,6 @@ def compress(args):
start_time = time.time() start_time = time.time()
x_data = data[0] x_data = data[0]
y_data = paddle.to_tensor(data[1]) y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
logits = model(x_data) logits = model(x_data)
loss = F.cross_entropy(logits, y_data) loss = F.cross_entropy(logits, y_data)
...@@ -180,8 +178,6 @@ def compress(args): ...@@ -180,8 +178,6 @@ def compress(args):
train_reader_cost += time.time() - reader_start train_reader_cost += time.time() - reader_start
x_data = data[0] x_data = data[0]
y_data = paddle.to_tensor(data[1]) y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
train_start = time.time() train_start = time.time()
logits = model(x_data) logits = model(x_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册