From 6aa211fbb6c20e714f32e9d4952527e6ecd60245 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Wed, 29 Dec 2021 01:08:22 +0800
Subject: [PATCH] add amp train for mv3 (#5445)

---
 tutorials/mobilenetv3_prod/Step6/train.py | 50 ++++++++++++++++++-----
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/tutorials/mobilenetv3_prod/Step6/train.py b/tutorials/mobilenetv3_prod/Step6/train.py
index be4ccd3d..e1b10c49 100755
--- a/tutorials/mobilenetv3_prod/Step6/train.py
+++ b/tutorials/mobilenetv3_prod/Step6/train.py
@@ -26,7 +26,9 @@ def train_one_epoch(
         data_loader,
         device,
         epoch,
-        print_freq, ):
+        print_freq,
+        amp_level=None,
+        scaler=None):
     model.train()
     # training log
     train_reader_cost = 0.0
@@ -40,10 +42,20 @@ def train_one_epoch(
     for batch_idx, (image, target) in enumerate(data_loader):
         train_reader_cost += time.time() - reader_start
         train_start = time.time()
-        output = model(image)
-        loss = criterion(output, target)
-        loss.backward()
-        optimizer.step()
+
+        if amp_level is not None:
+            with paddle.amp.auto_cast(level=amp_level):
+                output = model(image)
+                loss = criterion(output, target)
+            scaled = scaler.scale(loss)
+            scaled.backward()
+            scaler.minimize(optimizer, scaled)
+        else:
+            output = model(image)
+            loss = criterion(output, target)
+            loss.backward()
+            optimizer.step()
+
         optimizer.clear_grad()
         train_run_cost += time.time() - train_start
         acc = utils.accuracy(output, target, topk=(1, 5))
@@ -73,15 +85,20 @@ def train_one_epoch(
         reader_start = time.time()
 
 
-def evaluate(model, criterion, data_loader, device, print_freq=100):
+def evaluate(model, criterion, data_loader, device, print_freq=100, amp_level=None):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = 'Test:'
     with paddle.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq,
                                                      header):
-            output = model(image)
-            loss = criterion(output, target)
+            if amp_level is not None:
+                with paddle.amp.auto_cast(level=amp_level):
+                    output = model(image)
+                    loss = criterion(output, target)
+            else:
+                output = model(image)
+                loss = criterion(output, target)
 
             acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
             # FIXME need to take into account that the datasets
@@ -202,13 +219,20 @@ def main(args):
         model.set_state_dict(layer_state_dict)
         opt_state_dict = paddle.load(os.path.join(args.resume, '.pdopt'))
         optimizer.load_state_dict(opt_state_dict)
+
+    scaler = None
+    if args.amp_level is not None:
+        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+        if args.amp_level == 'O2':
+            model = paddle.amp.decorate(models=model, level='O2')
+
     # multi cards
     if paddle.distributed.get_world_size() > 1:
         model = paddle.DataParallel(model)
 
     if args.test_only and paddle.distributed.get_rank() == 0:
-        top1 = evaluate(model, criterion, data_loader_test, device=device)
+        top1 = evaluate(model, criterion, data_loader_test, device=device, amp_level=args.amp_level)
         return top1
 
     print("Start training")
@@ -217,10 +241,10 @@
 
     for epoch in range(args.start_epoch, args.epochs):
         train_one_epoch(model, criterion, optimizer, data_loader, device,
-                        epoch, args.print_freq)
+                        epoch, args.print_freq, args.amp_level, scaler)
         lr_scheduler.step()
         if paddle.distributed.get_rank() == 0:
-            top1 = evaluate(model, criterion, data_loader_test, device=device)
+            top1 = evaluate(model, criterion, data_loader_test, device=device, amp_level=args.amp_level)
             if args.output_dir:
                 paddle.save(model.state_dict(),
                             os.path.join(args.output_dir,
@@ -260,6 +284,10 @@ def get_args_parser(add_help=True):
         type=int,
         metavar='N',
         help='number of total epochs to run')
+    parser.add_argument(
+        '--amp_level',
+        default=None,
+        help='amp level can set to be : O1/O2')
     parser.add_argument(
         '-j',
         '--workers',
-- 
GitLab
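
For reference, the core paddle.amp pattern that this diff adds to train_one_epoch looks like the minimal sketch below: the forward pass runs under auto_cast and the backward/update goes through a GradScaler. The toy linear model, optimizer settings, and random batch are illustrative placeholders chosen for this sketch, not values taken from the patch.

import paddle

# Toy stand-ins for the real MobileNetV3 model, loss, and data loader.
model = paddle.nn.Linear(10, 2)
criterion = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

image = paddle.randn([4, 10])        # dummy batch of 4 samples
target = paddle.randint(0, 2, [4])   # dummy integer labels

# Forward pass under auto_cast: eligible ops run in float16 at level O1.
with paddle.amp.auto_cast(level='O1'):
    output = model(image)
    loss = criterion(output, target)

# Scale the loss so small float16 gradients do not underflow, then let the
# scaler unscale the gradients and apply the optimizer step (the step is
# skipped if an overflow is detected).
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()

With the patch applied, this path is switched on in train.py by passing --amp_level O1 (or O2, which additionally wraps the model via paddle.amp.decorate) together with the script's usual arguments; leaving the flag unset keeps the original float32 training loop.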