diff --git a/paddleslim/nas/darts/train_search.py b/paddleslim/nas/darts/train_search.py
index 38a6a3c9b3a76a2159e7a040b415c9ed497b64b8..ce37a44fe9b1a35a2adff58fcdaf13321f7d8ec2 100644
--- a/paddleslim/nas/darts/train_search.py
+++ b/paddleslim/nas/darts/train_search.py
@@ -100,17 +100,32 @@ class DARTSearch(object):
     def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
                         epoch):
         objs = AvgrageMeter()
-        ce_losses = AvgrageMeter()
-        kd_losses = AvgrageMeter()
-        e_losses = AvgrageMeter()
+        top1 = AvgrageMeter()
+        top5 = AvgrageMeter()
         self.model.train()

-        step_id = 0
-        for train_data, valid_data in zip(train_loader(), valid_loader()):
+        for step_id, (
+                train_data,
+                valid_data) in enumerate(zip(train_loader(), valid_loader())):
+            train_image, train_label = train_data
+            valid_image, valid_label = valid_data
+            train_image = to_variable(train_image)
+            train_label = to_variable(train_label)
+            train_label.stop_gradient = True
+            valid_image = to_variable(valid_image)
+            valid_label = to_variable(valid_label)
+            valid_label.stop_gradient = True
+            n = train_image.shape[0]
+
             if epoch >= self.epochs_no_archopt:
-                architect.step(train_data, valid_data)
+                architect.step(train_image, train_label, valid_image,
+                               valid_label)

-            loss, ce_loss, kd_loss, e_loss = self.model.loss(train_data)
+            logits = self.model(train_image)
+            prec1 = fluid.layers.accuracy(input=logits, label=train_label, k=1)
+            prec5 = fluid.layers.accuracy(input=logits, label=train_label, k=5)
+            loss = fluid.layers.reduce_mean(
+                fluid.layers.softmax_with_cross_entropy(logits, train_label))

             if self.use_data_parallel:
                 loss = self.model.scale_loss(loss)
@@ -122,22 +137,18 @@ class DARTSearch(object):
             optimizer.minimize(loss)
             self.model.clear_gradients()

-            batch_size = train_data[0].shape[0]
-            objs.update(loss.numpy(), batch_size)
-            ce_losses.update(ce_loss.numpy(), batch_size)
-            kd_losses.update(kd_loss.numpy(), batch_size)
-            e_losses.update(e_loss.numpy(), batch_size)
+            objs.update(loss.numpy(), n)
+            top1.update(prec1.numpy(), n)
+            top5.update(prec5.numpy(), n)

             if step_id % self.log_freq == 0:
                 #logger.info("Train Epoch {}, Step {}, loss {:.6f}; ce: {:.6f}; kd: {:.6f}; e: {:.6f}".format(
                 #    epoch, step_id, objs.avg[0], ce_losses.avg[0], kd_losses.avg[0], e_losses.avg[0]))
                 logger.info(
-                    "Train Epoch {}, Step {}, loss {}; ce: {}; kd: {}; e: {}".
-                    format(epoch, step_id,
-                           loss.numpy(),
-                           ce_loss.numpy(), kd_loss.numpy(), e_loss.numpy()))
-            step_id += 1
-        return objs.avg[0]
+                    "Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
+                    format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
+                        0]))
+        return top1.avg[0]

     def valid_one_epoch(self, valid_loader, epoch):
         objs = AvgrageMeter()
@@ -145,7 +156,7 @@ class DARTSearch(object):
         top5 = AvgrageMeter()
         self.model.eval()

-        for step_id, valid_data in enumerate(valid_loader):
+        for step_id, (image, label) in enumerate(valid_loader):
             image = to_variable(image)
             label = to_variable(label)
             n = image.shape[0]
@@ -235,12 +246,14 @@ class DARTSearch(object):
             genotype = get_genotype(base_model)
             logger.info('genotype = %s', genotype)

-            self.train_one_epoch(train_loader, valid_loader, architect,
-                                 optimizer, epoch)
+            train_top1 = self.train_one_epoch(train_loader, valid_loader,
+                                              architect, optimizer, epoch)
+            logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))

             if epoch == self.num_epochs - 1:
-                # valid_top1 = self.valid_one_epoch(valid_loader, epoch)
-                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch, 1))
+                valid_top1 = self.valid_one_epoch(valid_loader, epoch)
+                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch,
+                                                                valid_top1))
             if save_parameters:
                 fluid.save_dygraph(
                     self.model.state_dict(),
diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py
index 9407bdda02d13efeb59345f1d292cd6b0b9f432d..cbb1cc30466b6b1899af27be5b3a7e4d8366fb1d 100644
--- a/paddleslim/prune/prune_walker.py
+++ b/paddleslim/prune/prune_walker.py
@@ -542,7 +542,7 @@ class depthwise_conv2d(PruneWorker):
             self._visit(filter_var, 0)

             new_groups = filter_var.shape()[0] - len(pruned_idx)
-            op.set_attr("groups", new_groups)
+            self.op.set_attr("groups", new_groups)

         for op in filter_var.outputs():
             self._prune_op(op, filter_var, 0, pruned_idx)
diff --git a/tests/test_flops.py b/tests/test_flops.py
index cd16b8618d0271e6a0b7e609f8820e16c380b9db..9d50ebc573a4320c1343874309ce75815bd53bd2 100644
--- a/tests/test_flops.py
+++ b/tests/test_flops.py
@@ -33,7 +33,7 @@ class TestPrune(unittest.TestCase):
             sum2 = conv4 + sum1
         conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
         conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
-        self.assertTrue(1597440 == flops(main_program))
+        self.assertTrue(792576 == flops(main_program))


 if __name__ == '__main__':
diff --git a/tests/test_prune_walker.py b/tests/test_prune_walker.py
index b80f6903904dd0bfbb75495247336bc64c089e3a..6db1155ca05a1425def66611aea0ddc3de78c200 100644
--- a/tests/test_prune_walker.py
+++ b/tests/test_prune_walker.py
@@ -57,7 +57,7 @@ class TestPrune(unittest.TestCase):
         conv_op = graph.var("conv4_weights").outputs()[0]
         walker = conv2d_walker(conv_op, [])
         walker.prune(graph.var("conv4_weights"), pruned_axis=0, pruned_idx=[])
-        print walker.pruned_params
+        print(walker.pruned_params)


 if __name__ == '__main__':
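
Note on the new return value: train_one_epoch now logs and returns
top1.avg[0], a batch-size-weighted mean of the per-batch top-1 accuracies,
so the train_acc printed per epoch is the exact epoch-level accuracy rather
than an average of batch averages. A minimal sketch of that bookkeeping,
assuming AvgrageMeter follows the usual DARTS-style implementation
(running sum / count, with avg indexable as avg[0]):

    import numpy as np

    class AvgrageMeter(object):
        """Sketch only; assumes a running, batch-size-weighted mean."""

        def __init__(self):
            self.sum = np.zeros(1)
            self.cnt = 0
            self.avg = np.zeros(1)

        def update(self, val, n=1):
            # val is a per-batch mean (loss or accuracy); weighting by the
            # batch size n makes the final avg an exact epoch mean even
            # when the last batch is smaller.
            self.sum += val * n
            self.cnt += n
            self.avg = self.sum / self.cnt

    top1 = AvgrageMeter()
    top1.update(np.array([0.50]), n=32)  # batch of 32 at 50% top-1
    top1.update(np.array([0.75]), n=16)  # batch of 16 at 75% top-1
    print(top1.avg[0])                   # 0.5833..., the epoch accuracy

This is also why n = train_image.shape[0] replaces the old
batch_size = train_data[0].shape[0]: both are the per-batch sample count,
just read after the explicit unpacking.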
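The per-batch prec1/prec5 values come from fluid.layers.accuracy with k=1
and k=5, applied directly to the logits; since softmax is monotonic,
ranking logits gives the same top-k result as ranking probabilities. For
reference, an equivalent check in plain NumPy (topk_accuracy is a
hypothetical helper, not part of the patch):

    import numpy as np

    def topk_accuracy(logits, labels, k):
        # Fraction of rows whose true label is among the k largest logits.
        topk = np.argsort(logits, axis=1)[:, -k:]
        hits = (topk == labels.reshape(-1, 1)).any(axis=1)
        return hits.mean()

    logits = np.array([[0.1, 0.7, 0.2],
                       [0.6, 0.3, 0.1]])
    labels = np.array([1, 2])
    print(topk_accuracy(logits, labels, k=1))  # 0.5: row 0 hits, row 1 misses
    print(topk_accuracy(logits, labels, k=2))  # 0.5: row 1's label ranks last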
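On the prune_walker.py one-liner: inside depthwise_conv2d, the bare name op
is not the operator this worker was built around; at that point op is only
bound by a for loop, so the write could fail outright or land on a stale
loop variable. self.op pins the groups update to the worker's own depthwise
conv. A toy reproduction of the pitfall, with hypothetical names:

    class Op(object):
        def __init__(self, type_name):
            self.type_name = type_name
            self.attrs = {}

        def set_attr(self, key, value):
            self.attrs[key] = value

    class Worker(object):
        def __init__(self, op):
            self.op = op  # the depthwise conv this worker owns

        def prune(self, downstream_ops, new_groups):
            # Fixed form: explicit self.op. A bare `op` here would raise
            # UnboundLocalError (it is only bound by the loop below) or,
            # had an earlier loop bound it, silently update the wrong op.
            self.op.set_attr("groups", new_groups)
            for op in downstream_ops:
                pass  # stand-in for self._prune_op(op, ...)

    conv = Op("depthwise_conv2d")
    Worker(conv).prune([Op("batch_norm"), Op("relu")], new_groups=4)
    print(conv.attrs)  # {'groups': 4}: written to the intended operator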