未验证 提交 6238fd7b 编写于 作者: M minghaoBD 提交者: GitHub

revert cifar compensation on the develop branch (#862)

上级 dfe5d3f7
......@@ -67,6 +67,8 @@ def compress(args):
start_time = time.time()
x_data = data[0]
y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
logits = model(x_data)
loss = F.cross_entropy(logits, y_data)
......
......@@ -145,6 +145,8 @@ def compress(args):
start_time = time.time()
x_data = data[0]
y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
logits = model(x_data)
loss = F.cross_entropy(logits, y_data)
......@@ -178,6 +180,8 @@ def compress(args):
train_reader_cost += time.time() - reader_start
x_data = data[0]
y_data = paddle.to_tensor(data[1])
if args.data == 'cifar10':
y_data = paddle.unsqueeze(y_data, 1)
train_start = time.time()
logits = model(x_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册