未验证 提交 0a76937e 编写于 作者: G gmm 提交者: GitHub

Merge branch 'develop' into no_paddle

......@@ -100,17 +100,32 @@ class DARTSearch(object):
def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
epoch):
objs = AvgrageMeter()
ce_losses = AvgrageMeter()
kd_losses = AvgrageMeter()
e_losses = AvgrageMeter()
top1 = AvgrageMeter()
top5 = AvgrageMeter()
self.model.train()
step_id = 0
for train_data, valid_data in zip(train_loader(), valid_loader()):
for step_id, (
train_data,
valid_data) in enumerate(zip(train_loader(), valid_loader())):
train_image, train_label = train_data
valid_image, valid_label = valid_data
train_image = to_variable(train_image)
train_label = to_variable(train_label)
train_label.stop_gradient = True
valid_image = to_variable(valid_image)
valid_label = to_variable(valid_label)
valid_label.stop_gradient = True
n = train_image.shape[0]
if epoch >= self.epochs_no_archopt:
architect.step(train_data, valid_data)
architect.step(train_image, train_label, valid_image,
valid_label)
loss, ce_loss, kd_loss, e_loss = self.model.loss(train_data)
logits = self.model(train_image)
prec1 = fluid.layers.accuracy(input=logits, label=train_label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=train_label, k=5)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, train_label))
if self.use_data_parallel:
loss = self.model.scale_loss(loss)
......@@ -122,22 +137,18 @@ class DARTSearch(object):
optimizer.minimize(loss)
self.model.clear_gradients()
batch_size = train_data[0].shape[0]
objs.update(loss.numpy(), batch_size)
ce_losses.update(ce_loss.numpy(), batch_size)
kd_losses.update(kd_loss.numpy(), batch_size)
e_losses.update(e_loss.numpy(), batch_size)
objs.update(loss.numpy(), n)
top1.update(prec1.numpy(), n)
top5.update(prec5.numpy(), n)
if step_id % self.log_freq == 0:
#logger.info("Train Epoch {}, Step {}, loss {:.6f}; ce: {:.6f}; kd: {:.6f}; e: {:.6f}".format(
# epoch, step_id, objs.avg[0], ce_losses.avg[0], kd_losses.avg[0], e_losses.avg[0]))
logger.info(
"Train Epoch {}, Step {}, loss {}; ce: {}; kd: {}; e: {}".
format(epoch, step_id,
loss.numpy(),
ce_loss.numpy(), kd_loss.numpy(), e_loss.numpy()))
step_id += 1
return objs.avg[0]
"Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
0]))
return top1.avg[0]
def valid_one_epoch(self, valid_loader, epoch):
objs = AvgrageMeter()
......@@ -145,7 +156,7 @@ class DARTSearch(object):
top5 = AvgrageMeter()
self.model.eval()
for step_id, valid_data in enumerate(valid_loader):
for step_id, (image, label) in enumerate(valid_loader):
image = to_variable(image)
label = to_variable(label)
n = image.shape[0]
......@@ -235,12 +246,14 @@ class DARTSearch(object):
genotype = get_genotype(base_model)
logger.info('genotype = %s', genotype)
self.train_one_epoch(train_loader, valid_loader, architect,
optimizer, epoch)
train_top1 = self.train_one_epoch(train_loader, valid_loader,
architect, optimizer, epoch)
logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))
if epoch == self.num_epochs - 1:
# valid_top1 = self.valid_one_epoch(valid_loader, epoch)
logger.info("Epoch {}, valid_acc {:.6f}".format(epoch, 1))
valid_top1 = self.valid_one_epoch(valid_loader, epoch)
logger.info("Epoch {}, valid_acc {:.6f}".format(epoch,
valid_top1))
if save_parameters:
fluid.save_dygraph(
self.model.state_dict(),
......
......@@ -542,7 +542,7 @@ class depthwise_conv2d(PruneWorker):
self._visit(filter_var, 0)
new_groups = filter_var.shape()[0] - len(pruned_idx)
op.set_attr("groups", new_groups)
self.op.set_attr("groups", new_groups)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
......
......@@ -33,7 +33,7 @@ class TestPrune(unittest.TestCase):
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
self.assertTrue(1597440 == flops(main_program))
self.assertTrue(792576 == flops(main_program))
if __name__ == '__main__':
......
......@@ -57,7 +57,7 @@ class TestPrune(unittest.TestCase):
conv_op = graph.var("conv4_weights").outputs()[0]
walker = conv2d_walker(conv_op, [])
walker.prune(graph.var("conv4_weights"), pruned_axis=0, pruned_idx=[])
print walker.pruned_params
print(walker.pruned_params)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册