提交 45a7cc62 编写于 作者: Z zhangxuefei

update img_cls.py

上级 80c18efc
...@@ -18,6 +18,11 @@ parser.add_argument( ...@@ -18,6 +18,11 @@ parser.add_argument(
help="Whether use GPU for fine-tuning.") help="Whether use GPU for fine-tuning.")
parser.add_argument( parser.add_argument(
"--checkpoint_dir", type=str, default=None, help="Path to save log data.") "--checkpoint_dir", type=str, default=None, help="Path to save log data.")
parser.add_argument(
"--module",
type=str,
default="mobilenet",
help="Module used as feature extractor.")
# the name of hyperparameters to be searched should keep with hparam.py # the name of hyperparameters to be searched should keep with hparam.py
parser.add_argument( parser.add_argument(
...@@ -37,6 +42,15 @@ parser.add_argument( ...@@ -37,6 +42,15 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
"--model_path", type=str, default="", help="load model path") "--model_path", type=str, default="", help="load model path")
module_map = {
"resnet50": "resnet_v2_50_imagenet",
"resnet101": "resnet_v2_101_imagenet",
"resnet152": "resnet_v2_152_imagenet",
"mobilenet": "mobilenet_v2_imagenet",
"nasnet": "nasnet_imagenet",
"pnasnet": "pnasnet_imagenet"
}
def is_path_valid(path): def is_path_valid(path):
if path == "": if path == "":
...@@ -49,8 +63,9 @@ def is_path_valid(path): ...@@ -49,8 +63,9 @@ def is_path_valid(path):
def finetune(args): def finetune(args):
# Load Paddlehub resnet50 pretrained model
module = hub.Module(name="mobilenet_v2_imagenet") # Load Paddlehub pretrained model, default as mobilenet
module = hub.Module(name=args.module)
input_dict, output_dict, program = module.context(trainable=True) input_dict, output_dict, program = module.context(trainable=True)
# Download dataset and use ImageClassificationReader to read dataset # Download dataset and use ImageClassificationReader to read dataset
...@@ -61,6 +76,7 @@ def finetune(args): ...@@ -61,6 +76,7 @@ def finetune(args):
images_mean=module.get_pretrained_images_mean(), images_mean=module.get_pretrained_images_mean(),
images_std=module.get_pretrained_images_std(), images_std=module.get_pretrained_images_std(),
dataset=dataset) dataset=dataset)
# The last 2 layer of resnet_v2_101_imagenet network # The last 2 layer of resnet_v2_101_imagenet network
feature_map = output_dict["feature_map"] feature_map = output_dict["feature_map"]
...@@ -69,7 +85,6 @@ def finetune(args): ...@@ -69,7 +85,6 @@ def finetune(args):
# Select finetune strategy, setup config and finetune # Select finetune strategy, setup config and finetune
strategy = hub.DefaultFinetuneStrategy(learning_rate=args.learning_rate) strategy = hub.DefaultFinetuneStrategy(learning_rate=args.learning_rate)
config = hub.RunConfig( config = hub.RunConfig(
use_cuda=True, use_cuda=True,
num_epoch=args.epochs, num_epoch=args.epochs,
...@@ -112,4 +127,9 @@ def finetune(args): ...@@ -112,4 +127,9 @@ def finetune(args):
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if not args.module in module_map:
hub.logger.error("module should in %s" % module_map.keys())
exit(1)
args.module = module_map[args.module]
finetune(args) finetune(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册