Unverified commit 6aa211fb, authored by zhangbo9674, committed by GitHub

add amp train for mv3 (#5445)

Parent 04671c18
@@ -26,7 +26,9 @@ def train_one_epoch(
         data_loader,
         device,
         epoch,
-        print_freq, ):
+        print_freq,
+        amp_level=None,
+        scaler=None):
     model.train()
     # training log
     train_reader_cost = 0.0
@@ -40,10 +42,20 @@ def train_one_epoch(
     for batch_idx, (image, target) in enumerate(data_loader):
         train_reader_cost += time.time() - reader_start
         train_start = time.time()
-        output = model(image)
-        loss = criterion(output, target)
-        loss.backward()
-        optimizer.step()
-        optimizer.clear_grad()
+        if amp_level is not None:
+            with paddle.amp.auto_cast(level=amp_level):
+                output = model(image)
+                loss = criterion(output, target)
+            scaled = scaler.scale(loss)
+            scaled.backward()
+            scaler.minimize(optimizer, scaled)
+            optimizer.clear_grad()
+        else:
+            output = model(image)
+            loss = criterion(output, target)
+            loss.backward()
+            optimizer.step()
+            optimizer.clear_grad()
         train_run_cost += time.time() - train_start
         acc = utils.accuracy(output, target, topk=(1, 5))
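
The AMP branch above follows Paddle's usual mixed-precision recipe: run the forward pass under auto_cast, scale the loss before backward so small float16 gradients do not flush to zero, then let the scaler unscale the gradients and apply (or skip, on inf/nan) the update. Below is a minimal self-contained sketch of the same pattern, with a toy linear model and random data standing in for the repo's MobileNetV3 pipeline (none of these names come from the commit):

import paddle

# Toy stand-ins; the real script builds MobileNetV3, a LR scheduler, etc.
model = paddle.nn.Linear(16, 2)
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                 parameters=model.parameters())
criterion = paddle.nn.CrossEntropyLoss()
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)  # same init value as the diff

for step in range(3):
    image = paddle.randn([8, 16])
    target = paddle.randint(0, 2, [8])       # int64 labels for CrossEntropyLoss
    with paddle.amp.auto_cast(level='O1'):   # whitelisted ops run in float16
        output = model(image)
        loss = criterion(output, target)
    scaled = scaler.scale(loss)              # scale up to avoid fp16 gradient underflow
    scaled.backward()
    scaler.minimize(optimizer, scaled)       # unscale, update, skip step on inf/nan
    optimizer.clear_grad()
    print(step, float(loss))

On a machine without a GPU, auto_cast and GradScaler degrade to (warned) no-ops, so the sketch still runs but without the float16 behavior.
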
@@ -73,13 +85,18 @@ def train_one_epoch(
         reader_start = time.time()


-def evaluate(model, criterion, data_loader, device, print_freq=100):
+def evaluate(model, criterion, data_loader, device, print_freq=100, amp_level=None):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = 'Test:'
     with paddle.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq,
                                                      header):
-            output = model(image)
-            loss = criterion(output, target)
+            if amp_level is not None:
+                with paddle.amp.auto_cast(level=amp_level):
+                    output = model(image)
+                    loss = criterion(output, target)
+            else:
+                output = model(image)
+                loss = criterion(output, target)
@@ -203,12 +220,19 @@ def main(args):
         opt_state_dict = paddle.load(os.path.join(args.resume, '.pdopt'))
         optimizer.load_state_dict(opt_state_dict)

+    scaler = None
+    if args.amp_level is not None:
+        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+        if args.amp_level == 'O2':
+            model = paddle.amp.decorate(models=model, level='O2')
+
     # multi cards
     if paddle.distributed.get_world_size() > 1:
         model = paddle.DataParallel(model)

     if args.test_only and paddle.distributed.get_rank() == 0:
-        top1 = evaluate(model, criterion, data_loader_test, device=device)
+        top1 = evaluate(model, criterion, data_loader_test, device=device, amp_level=args.amp_level)
         return top1

     print("Start training")
@@ -217,10 +241,10 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         train_one_epoch(model, criterion, optimizer, data_loader, device,
-                        epoch, args.print_freq)
+                        epoch, args.print_freq, args.amp_level, scaler)
         lr_scheduler.step()

         if paddle.distributed.get_rank() == 0:
-            top1 = evaluate(model, criterion, data_loader_test, device=device)
+            top1 = evaluate(model, criterion, data_loader_test, device=device, amp_level=args.amp_level)
         if args.output_dir:
             paddle.save(model.state_dict(),
                         os.path.join(args.output_dir,
@@ -260,6 +284,10 @@ def get_args_parser(add_help=True):
         type=int,
         metavar='N',
         help='number of total epochs to run')
+    parser.add_argument(
+        '--amp_level',
+        default=None,
+        help='amp level, can be set to O1 or O2')
     parser.add_argument(
         '-j',
         '--workers',
......
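
In this change, O1 only autocasts individual ops inside paddle.amp.auto_cast, while O2 additionally casts the model's parameters to float16, which is why the O2 path calls paddle.amp.decorate before the model is wrapped in paddle.DataParallel for multi-card runs: the decorated (cast) parameters are then the ones replicated across cards. With the new argument, an AMP run would be launched along the lines of python train.py --amp_level O1; the script name is assumed here, since the commit view does not show the file path.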