提交 c3d166ab 编写于 作者: L LielinJiang

enhance error message, fix device bug

上级 e3b4d985
......@@ -76,6 +76,9 @@ def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
model_list = [x for x in models.__dict__["__all__"]]
assert FLAGS.arch in model_list, "Expected FLAGS.arch in {}, but received {}".format(
model_list, FLAGS.arch)
model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only and
not FLAGS.resume)
......@@ -94,7 +97,13 @@ def main():
len(train_dataset) * 1. / FLAGS.batch_size / ParallelEnv().nranks),
parameter_list=model.parameters())
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels)
model.prepare(
optim,
CrossEntropy(),
Accuracy(topk=(1, 5)),
inputs,
labels,
FLAGS.device)
if FLAGS.eval_only:
model.evaluate(
......@@ -152,7 +161,7 @@ if __name__ == '__main__':
type=str,
help="checkpoint path to resume")
parser.add_argument(
"--eval-only", action='store_true', help="enable dygraph mode")
"--eval-only", action='store_true', help="only evaluate the model")
parser.add_argument(
"--lr-scheduler",
default='piecewise',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册