diff --git a/demo/autofinetune/img_cls.py b/demo/autofinetune/img_cls.py index be7dbebc2a4111464d6066e8ebc50896ba7a28e4..163a98bbe6067415658ea0715ac871a2ce9f020f 100644 --- a/demo/autofinetune/img_cls.py +++ b/demo/autofinetune/img_cls.py @@ -18,6 +18,11 @@ parser.add_argument( help="Whether use GPU for fine-tuning.") parser.add_argument( "--checkpoint_dir", type=str, default=None, help="Path to save log data.") +parser.add_argument( + "--module", + type=str, + default="mobilenet", + help="Module used as feature extractor.") # the name of hyperparameters to be searched should keep with hparam.py parser.add_argument( @@ -37,6 +42,15 @@ parser.add_argument( parser.add_argument( "--model_path", type=str, default="", help="load model path") +module_map = { + "resnet50": "resnet_v2_50_imagenet", + "resnet101": "resnet_v2_101_imagenet", + "resnet152": "resnet_v2_152_imagenet", + "mobilenet": "mobilenet_v2_imagenet", + "nasnet": "nasnet_imagenet", + "pnasnet": "pnasnet_imagenet" +} + def is_path_valid(path): if path == "": @@ -49,8 +63,9 @@ def is_path_valid(path): def finetune(args): - # Load Paddlehub resnet50 pretrained model - module = hub.Module(name="mobilenet_v2_imagenet") + + # Load Paddlehub pretrained model, default as mobilenet + module = hub.Module(name=args.module) input_dict, output_dict, program = module.context(trainable=True) # Download dataset and use ImageClassificationReader to read dataset @@ -61,6 +76,7 @@ def finetune(args): images_mean=module.get_pretrained_images_mean(), images_std=module.get_pretrained_images_std(), dataset=dataset) + # The last 2 layer of resnet_v2_101_imagenet network feature_map = output_dict["feature_map"] @@ -69,7 +85,6 @@ def finetune(args): # Select finetune strategy, setup config and finetune strategy = hub.DefaultFinetuneStrategy(learning_rate=args.learning_rate) - config = hub.RunConfig( use_cuda=True, num_epoch=args.epochs, @@ -112,4 +127,9 @@ def finetune(args): if __name__ == "__main__": args = parser.parse_args() + if not args.module in module_map: + hub.logger.error("module should in %s" % module_map.keys()) + exit(1) + args.module = module_map[args.module] + finetune(args)