diff --git a/demo/dygraph/unstructured_pruning/evaluate.py b/demo/dygraph/unstructured_pruning/evaluate.py index e8827d1f5e7a6ec6ee20ff6134f996966064dd34..d2c6aa56c13cc2e0cae84f00a8f71e66b6bff05b 100644 --- a/demo/dygraph/unstructured_pruning/evaluate.py +++ b/demo/dygraph/unstructured_pruning/evaluate.py @@ -67,6 +67,8 @@ def compress(args): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) + if args.data == 'cifar10': + y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) diff --git a/demo/dygraph/unstructured_pruning/train.py b/demo/dygraph/unstructured_pruning/train.py index 7bd3f479b9ce3243b9922a50365ca3b47cbf4ea7..1c859684c394d00d7d1fd2c2c6bf5499a735c251 100644 --- a/demo/dygraph/unstructured_pruning/train.py +++ b/demo/dygraph/unstructured_pruning/train.py @@ -145,6 +145,8 @@ def compress(args): start_time = time.time() x_data = data[0] y_data = paddle.to_tensor(data[1]) + if args.data == 'cifar10': + y_data = paddle.unsqueeze(y_data, 1) logits = model(x_data) loss = F.cross_entropy(logits, y_data) @@ -178,6 +180,8 @@ def compress(args): train_reader_cost += time.time() - reader_start x_data = data[0] y_data = paddle.to_tensor(data[1]) + if args.data == 'cifar10': + y_data = paddle.unsqueeze(y_data, 1) train_start = time.time() logits = model(x_data)