diff --git a/examples/image_classification/main.py b/examples/image_classification/main.py index 546991528631909d5f75caec4df96c63053e7fdb..76360df91cd64a66e2e288c90a37ac667cdc3eea 100644 --- a/examples/image_classification/main.py +++ b/examples/image_classification/main.py @@ -76,6 +76,9 @@ def main(): device = set_device(FLAGS.device) fluid.enable_dygraph(device) if FLAGS.dynamic else None + model_list = [x for x in models.__dict__["__all__"]] + assert FLAGS.arch in model_list, "Expected FLAGS.arch in {}, but received {}".format( + model_list, FLAGS.arch) model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only and not FLAGS.resume) @@ -94,7 +97,13 @@ def main(): len(train_dataset) * 1. / FLAGS.batch_size / ParallelEnv().nranks), parameter_list=model.parameters()) - model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels) + model.prepare( + optim, + CrossEntropy(), + Accuracy(topk=(1, 5)), + inputs, + labels, + FLAGS.device) if FLAGS.eval_only: model.evaluate( @@ -152,7 +161,7 @@ if __name__ == '__main__': type=str, help="checkpoint path to resume") parser.add_argument( - "--eval-only", action='store_true', help="enable dygraph mode") + "--eval-only", action='store_true', help="only evaluate the model") parser.add_argument( "--lr-scheduler", default='piecewise',