From c3d166ab23f4522c35a6b89a78d35c0d00276257 Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Thu, 16 Apr 2020 03:54:31 +0000 Subject: [PATCH] enhance error message, fix device bug --- examples/image_classification/main.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/image_classification/main.py b/examples/image_classification/main.py index 5469915..76360df 100644 --- a/examples/image_classification/main.py +++ b/examples/image_classification/main.py @@ -76,6 +76,9 @@ def main(): device = set_device(FLAGS.device) fluid.enable_dygraph(device) if FLAGS.dynamic else None + model_list = [x for x in models.__dict__["__all__"]] + assert FLAGS.arch in model_list, "Expected FLAGS.arch in {}, but received {}".format( + model_list, FLAGS.arch) model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only and not FLAGS.resume) @@ -94,7 +97,13 @@ def main(): len(train_dataset) * 1. / FLAGS.batch_size / ParallelEnv().nranks), parameter_list=model.parameters()) - model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels) + model.prepare( + optim, + CrossEntropy(), + Accuracy(topk=(1, 5)), + inputs, + labels, + FLAGS.device) if FLAGS.eval_only: model.evaluate( @@ -152,7 +161,7 @@ if __name__ == '__main__': type=str, help="checkpoint path to resume") parser.add_argument( - "--eval-only", action='store_true', help="enable dygraph mode") + "--eval-only", action='store_true', help="only evaluate the model") parser.add_argument( "--lr-scheduler", default='piecewise', -- GitLab