diff --git a/example/vgg16_cifar10/train.py b/example/vgg16_cifar10/train.py index a4aa587c3dd63ef7eaa403a29189cf595b69c523..87cea2af03c603cfc07bc23ca4e9d77401f4c1de 100644 --- a/example/vgg16_cifar10/train.py +++ b/example/vgg16_cifar10/train.py @@ -68,7 +68,8 @@ if __name__ == '__main__': lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=50000 // cfg.batch_size) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) dataset = dataset.create_dataset(args_opt.data_path, cfg.epoch_size) batch_num = dataset.get_dataset_size()