Unverified commit bc182afb authored by littletomatodonkey, committed by GitHub

fix device (#5450)

Parent a381d625
@@ -50,7 +50,7 @@ MobileNetV3 is a new NAS-based lightweight network proposed in 2019. In order to
 | Model | top1/5 acc (reference) | top1/5 acc (reproduced) | Download |
 |:---------:|:------:|:----------:|:----------:|
-| Mo | -/- | 0.601/0.826 | [Pretrained model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_pretrained.pdparams) \| [Inference model (coming soon!)]() \| [log](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/train_mobilenet_v3_small.log) |
+| Mo | -/- | 0.601/0.826 | [Pretrained model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_pretrained.pdparams) \| [Inference model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_infer.tar) \| [log](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/train_mobilenet_v3_small.log) |
 <a name="3"></a>
......
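As a quick aside, here is a minimal sketch of how the pretrained weights linked in the table above could be loaded. This is not part of the commit: the use of `paddle.vision.models.mobilenet_v3_small` as the matching architecture, the local filename, and `num_classes=1000` are all assumptions.

```python
import paddle
from paddle.vision.models import mobilenet_v3_small  # assumed matching architecture

# Build the network and load the downloaded .pdparams checkpoint.
model = mobilenet_v3_small(num_classes=1000)
state_dict = paddle.load("mobilenet_v3_small_pretrained.pdparams")
model.set_state_dict(state_dict)
model.eval()
```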
@@ -19,16 +19,14 @@ import numpy as np
 from reprod_log import ReprodLogger
 
-def train_one_epoch(
-        model,
-        criterion,
-        optimizer,
-        data_loader,
-        device,
-        epoch,
-        print_freq,
-        amp_level=None,
-        scaler=None):
+def train_one_epoch(model,
+                    criterion,
+                    optimizer,
+                    data_loader,
+                    epoch,
+                    print_freq,
+                    amp_level=None,
+                    scaler=None):
     model.train()
     # training log
     train_reader_cost = 0.0
@@ -85,7 +83,7 @@ def train_one_epoch(
         reader_start = time.time()
 
 
-def evaluate(model, criterion, data_loader, device, print_freq=100, amp_level=None):
+def evaluate(model, criterion, data_loader, print_freq=100, amp_level=None):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter=" ")
     header = 'Test:'
@@ -160,7 +158,10 @@ def main(args):
     print(args)
 
-    device = paddle.set_device(args.device)
+    try:
+        paddle.set_device(args.device)
+    except:
+        print("device set error, use default device...")
 
     # multi cards
     if paddle.distributed.get_world_size() > 1:
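This try/except is the heart of the fix: an invalid `--device` value no longer crashes startup but falls back to Paddle's current default device. A standalone sketch of the same pattern follows; the helper name is illustrative, and it catches `Exception` where the commit itself uses a bare `except`.

```python
import paddle

def set_device_safely(device_name):
    # Try to select the requested device; on failure (e.g. asking for
    # 'gpu' on a CPU-only build) keep the current default device.
    try:
        paddle.set_device(device_name)
    except Exception:
        print("device set error, use default device...")
    return paddle.get_device()

print(set_device_safely("gpu"))  # 'gpu:0' if available, else falls back to 'cpu'
```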
@@ -219,20 +220,20 @@ def main(args):
         model.set_state_dict(layer_state_dict)
         opt_state_dict = paddle.load(os.path.join(args.resume, '.pdopt'))
         optimizer.load_state_dict(opt_state_dict)
 
     scaler = None
     if args.amp_level is not None:
         scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
         if args.amp_level == 'O2':
             model = paddle.amp.decorate(models=model, level='O2')
 
     # multi cards
     if paddle.distributed.get_world_size() > 1:
         model = paddle.DataParallel(model)
 
     if args.test_only and paddle.distributed.get_rank() == 0:
-        top1 = evaluate(model, criterion, data_loader_test, device=device, amp_level=args.amp_level)
+        top1 = evaluate(
+            model, criterion, data_loader_test, amp_level=args.amp_level)
         return top1
 
     print("Start training")
@@ -240,11 +241,12 @@ def main(args):
     best_top1 = 0.0
 
     for epoch in range(args.start_epoch, args.epochs):
-        train_one_epoch(model, criterion, optimizer, data_loader, device,
-                        epoch, args.print_freq, args.amp_level, scaler)
+        train_one_epoch(model, criterion, optimizer, data_loader, epoch,
+                        args.print_freq, args.amp_level, scaler)
         lr_scheduler.step()
         if paddle.distributed.get_rank() == 0:
-            top1 = evaluate(model, criterion, data_loader_test, device=device, amp_level=args.amp_level)
+            top1 = evaluate(
+                model, criterion, data_loader_test, amp_level=args.amp_level)
             if args.output_dir:
                 paddle.save(model.state_dict(),
                             os.path.join(args.output_dir,
@@ -285,9 +287,7 @@ def get_args_parser(add_help=True):
         metavar='N',
         help='number of total epochs to run')
     parser.add_argument(
-        '--amp_level',
-        default=None,
-        help='amp level can set to be : O1/O2')
+        '--amp_level', default=None, help='amp level can set to be : O1/O2')
     parser.add_argument(
         '-j',
         '--workers',
......
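With `device` removed from both signatures, call sites shrink accordingly; tensors now land on whatever global device `paddle.set_device` selected. For reference, a hypothetical driver loop under the updated signatures (variable names mirror the diff above):

```python
# Hypothetical training driver matching the new signatures.
for epoch in range(args.start_epoch, args.epochs):
    train_one_epoch(model, criterion, optimizer, data_loader, epoch,
                    args.print_freq, args.amp_level, scaler)
    lr_scheduler.step()
    if paddle.distributed.get_rank() == 0:
        top1 = evaluate(
            model, criterion, data_loader_test, amp_level=args.amp_level)
```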