Unverified commit 0a76937e authored by gmm, committed by GitHub

Merge branch 'develop' into no_paddle

@@ -100,17 +100,32 @@ class DARTSearch(object):
     def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
                         epoch):
         objs = AvgrageMeter()
-        ce_losses = AvgrageMeter()
-        kd_losses = AvgrageMeter()
-        e_losses = AvgrageMeter()
+        top1 = AvgrageMeter()
+        top5 = AvgrageMeter()
         self.model.train()
-        step_id = 0
-        for train_data, valid_data in zip(train_loader(), valid_loader()):
+        for step_id, (
+                train_data,
+                valid_data) in enumerate(zip(train_loader(), valid_loader())):
+            train_image, train_label = train_data
+            valid_image, valid_label = valid_data
+            train_image = to_variable(train_image)
+            train_label = to_variable(train_label)
+            train_label.stop_gradient = True
+            valid_image = to_variable(valid_image)
+            valid_label = to_variable(valid_label)
+            valid_label.stop_gradient = True
+            n = train_image.shape[0]
             if epoch >= self.epochs_no_archopt:
-                architect.step(train_data, valid_data)
+                architect.step(train_image, train_label, valid_image,
+                               valid_label)
 
-            loss, ce_loss, kd_loss, e_loss = self.model.loss(train_data)
+            logits = self.model(train_image)
+            prec1 = fluid.layers.accuracy(input=logits, label=train_label, k=1)
+            prec5 = fluid.layers.accuracy(input=logits, label=train_label, k=5)
+            loss = fluid.layers.reduce_mean(
+                fluid.layers.softmax_with_cross_entropy(logits, train_label))
             if self.use_data_parallel:
                 loss = self.model.scale_loss(loss)
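Note that architect.step now receives the unpacked image/label tensors rather than the raw batch tuples. For orientation, a minimal sketch of what a first-order DARTS architect step typically does with these four arguments (hypothetical Architect class, written against the fluid dygraph API used above; not necessarily this repo's exact implementation):

import paddle.fluid as fluid

class Architect(object):
    """Hypothetical first-order DARTS architect: takes one gradient step
    on the validation loss w.r.t. the architecture parameters only."""

    def __init__(self, model, arch_optimizer):
        self.model = model
        # Assumed constructed with parameter_list=model.arch_parameters(),
        # so minimize() touches only the alphas, not the weights.
        self.arch_optimizer = arch_optimizer

    def step(self, train_image, train_label, valid_image, valid_label):
        # First-order approximation: the train batch is unused and the
        # alphas are updated directly on the validation loss.
        logits = self.model(valid_image)
        loss = fluid.layers.reduce_mean(
            fluid.layers.softmax_with_cross_entropy(logits, valid_label))
        loss.backward()
        self.arch_optimizer.minimize(loss)
        self.model.clear_gradients()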
@@ -122,22 +137,18 @@ class DARTSearch(object):
             optimizer.minimize(loss)
             self.model.clear_gradients()
 
-            batch_size = train_data[0].shape[0]
-            objs.update(loss.numpy(), batch_size)
-            ce_losses.update(ce_loss.numpy(), batch_size)
-            kd_losses.update(kd_loss.numpy(), batch_size)
-            e_losses.update(e_loss.numpy(), batch_size)
+            objs.update(loss.numpy(), n)
+            top1.update(prec1.numpy(), n)
+            top5.update(prec5.numpy(), n)
 
             if step_id % self.log_freq == 0:
                 #logger.info("Train Epoch {}, Step {}, loss {:.6f}; ce: {:.6f}; kd: {:.6f}; e: {:.6f}".format(
                 #    epoch, step_id, objs.avg[0], ce_losses.avg[0], kd_losses.avg[0], e_losses.avg[0]))
                 logger.info(
-                    "Train Epoch {}, Step {}, loss {}; ce: {}; kd: {}; e: {}".
-                    format(epoch, step_id,
-                           loss.numpy(),
-                           ce_loss.numpy(), kd_loss.numpy(), e_loss.numpy()))
-            step_id += 1
-        return objs.avg[0]
+                    "Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
+                    format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[0]))
+        return top1.avg[0]
 
     def valid_one_epoch(self, valid_loader, epoch):
         objs = AvgrageMeter()
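The meters accumulate batch-weighted running averages, and loss.numpy()/prec1.numpy() return shape-(1,) arrays in dygraph mode, which is why the log line indexes objs.avg[0]. A minimal sketch of the AvgrageMeter behavior this code relies on (inferred from usage; see the repo for the actual class):

class AvgrageMeter(object):
    """Running average weighted by sample count n."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        # val may be a shape-(1,) numpy array, so avg is too;
        # callers index it as avg[0].
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt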
@@ -145,7 +156,7 @@ class DARTSearch(object):
         top5 = AvgrageMeter()
         self.model.eval()
-        for step_id, valid_data in enumerate(valid_loader):
+        for step_id, (image, label) in enumerate(valid_loader):
             image = to_variable(image)
             label = to_variable(label)
             n = image.shape[0]
@@ -235,12 +246,14 @@ class DARTSearch(object):
             genotype = get_genotype(base_model)
             logger.info('genotype = %s', genotype)
 
-            self.train_one_epoch(train_loader, valid_loader, architect,
-                                 optimizer, epoch)
+            train_top1 = self.train_one_epoch(train_loader, valid_loader,
+                                              architect, optimizer, epoch)
+            logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))
 
             if epoch == self.num_epochs - 1:
-                # valid_top1 = self.valid_one_epoch(valid_loader, epoch)
-                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch, 1))
+                valid_top1 = self.valid_one_epoch(valid_loader, epoch)
+                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch,
+                                                                valid_top1))
             if save_parameters:
                 fluid.save_dygraph(
                     self.model.state_dict(),
......
@@ -542,7 +542,7 @@ class depthwise_conv2d(PruneWorker):
             self._visit(filter_var, 0)
             new_groups = filter_var.shape()[0] - len(pruned_idx)
-            op.set_attr("groups", new_groups)
+            self.op.set_attr("groups", new_groups)
             for op in filter_var.outputs():
                 self._prune_op(op, filter_var, 0, pruned_idx)
......
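This fix matters because the local name op is rebound by for-loops in the same method (including the one right below), so the old code could set "groups" on whichever op the name last pointed at; self.op unambiguously targets the depthwise conv this worker owns. The invariant being maintained is that a depthwise conv's groups attribute must track its filter count. Illustration with hypothetical shapes:

# A depthwise conv with 32 channels has filter shape [32, 1, 3, 3]
# and groups == 32.
filter_shape = [32, 1, 3, 3]
pruned_idx = [0, 5, 9, 17]          # channel indices removed by the pruner

# After pruning, groups must equal the remaining filter count,
# otherwise the op definition becomes inconsistent.
new_groups = filter_shape[0] - len(pruned_idx)
assert new_groups == 28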
@@ -33,7 +33,7 @@ class TestPrune(unittest.TestCase):
             sum2 = conv4 + sum1
             conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
             conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
-        self.assertTrue(1597440 == flops(main_program))
+        self.assertTrue(792576 == flops(main_program))
 
 if __name__ == '__main__':
......
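The expected constant roughly halves (1597440 to 792576), consistent with the flops() accounting changing, e.g. to count one multiply-accumulate per weight application instead of a multiply plus an add. A generic sketch of per-layer conv FLOPs counting under that convention (illustrative only, not paddleslim's exact formula):

def conv_flops(c_in, c_out, k_h, k_w, h_out, w_out, groups=1):
    """One multiply-accumulate per weight application, counted once."""
    return (c_in // groups) * k_h * k_w * c_out * h_out * w_out

# Hypothetical layer: 8 input channels, 8 output channels, 3x3 kernel,
# 16x16 output feature map.
print(conv_flops(8, 8, 3, 3, 16, 16))  # 147456 MACs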
@@ -57,7 +57,7 @@ class TestPrune(unittest.TestCase):
         conv_op = graph.var("conv4_weights").outputs()[0]
         walker = conv2d_walker(conv_op, [])
         walker.prune(graph.var("conv4_weights"), pruned_axis=0, pruned_idx=[])
-        print walker.pruned_params
+        print(walker.pruned_params)
 
 if __name__ == '__main__':
......