diff --git a/demo/object_detection/train_faster_rcnn.py b/demo/object_detection/train_faster_rcnn.py
index 883182a8195f3ead6a5be93bc30262f90a793c93..e64600ec79e13088512c94fd8a4a720030ff9164 100644
--- a/demo/object_detection/train_faster_rcnn.py
+++ b/demo/object_detection/train_faster_rcnn.py
@@ -63,7 +63,7 @@ def finetune(args):
         enable_memory_optim=False,
         checkpoint_dir=args.checkpoint_dir,
         strategy=hub.finetune.strategy.DefaultFinetuneStrategy(
-            learning_rate=0.00025, optimizer_name="adam"))
+            learning_rate=0.00025, optimizer_name="momentum", momentum=0.9))
 
     task = hub.FasterRCNNTask(
         data_reader=data_reader,
diff --git a/demo/object_detection/train_ssd.py b/demo/object_detection/train_ssd.py
index 0bdec2937be9552f79266c43eebb74bc29d66ef5..e4c77b9761c822cace27e52e5bd539f2148fa9b1 100644
--- a/demo/object_detection/train_ssd.py
+++ b/demo/object_detection/train_ssd.py
@@ -45,7 +45,7 @@ def finetune(args):
         enable_memory_optim=False,
         checkpoint_dir=args.checkpoint_dir,
         strategy=hub.finetune.strategy.DefaultFinetuneStrategy(
-            learning_rate=0.00025, optimizer_name="adam"))
+            learning_rate=0.00025, optimizer_name="momentum", momentum=0.9))
 
     task = hub.SSDTask(
         data_reader=data_reader,
diff --git a/demo/object_detection/train_yolo.py b/demo/object_detection/train_yolo.py
index 16d7dcc75ddcf23a66310843eca2441ff71ea647..f9e0711bc57e17a90f7162db9a2f2d86c7c3409d 100644
--- a/demo/object_detection/train_yolo.py
+++ b/demo/object_detection/train_yolo.py
@@ -45,7 +45,7 @@ def finetune(args):
         enable_memory_optim=False,
         checkpoint_dir=args.checkpoint_dir,
         strategy=hub.finetune.strategy.DefaultFinetuneStrategy(
-            learning_rate=0.00025, optimizer_name="adam"))
+            learning_rate=0.00025, optimizer_name="momentum", momentum=0.9))
 
     task = hub.YOLOTask(
         data_reader=data_reader,
diff --git a/paddlehub/finetune/strategy.py b/paddlehub/finetune/strategy.py
index 9073f9414e4b44078c59dfa5cb251cc031456a9f..dba1902823e045a52fad4a83e33de076d0473248 100644
--- a/paddlehub/finetune/strategy.py
+++ b/paddlehub/finetune/strategy.py
@@ -133,39 +133,39 @@ def set_gradual_unfreeze(depth_params_dict, unfreeze_depths):
 
 
 class DefaultStrategy(object):
-    def __init__(self, learning_rate=1e-4, optimizer_name="adam"):
+    def __init__(self, learning_rate=1e-4, optimizer_name="adam", **kwargs):
         self.learning_rate = learning_rate
         self._optimizer_name = optimizer_name
         if self._optimizer_name.lower() == "sgd":
             self.optimizer = fluid.optimizer.SGD(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
         elif self._optimizer_name.lower() == "adagrad":
             self.optimizer = fluid.optimizer.Adagrad(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
         elif self._optimizer_name.lower() == "adamax":
             self.optimizer = fluid.optimizer.Adamax(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
         elif self._optimizer_name.lower() == "decayedadagrad":
             self.optimizer = fluid.optimizer.DecayedAdagrad(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
         elif self._optimizer_name.lower() == "ftrl":
             self.optimizer = fluid.optimizer.Ftrl(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
         elif self._optimizer_name.lower() == "larsmomentum":
             self.optimizer = fluid.optimizer.LarsMomentum(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
         elif self._optimizer_name.lower() == "momentum":
             self.optimizer = fluid.optimizer.Momentum(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
         elif self._optimizer_name.lower() == "decayedadagrad":
             self.optimizer = fluid.optimizer.DecayedAdagrad(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
         elif self._optimizer_name.lower() == "rmsprop":
             self.optimizer = fluid.optimizer.RMSPropOptimizer(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
         else:
             self.optimizer = fluid.optimizer.Adam(
-                learning_rate=self.learning_rate)
+                learning_rate=self.learning_rate, **kwargs)
 
     def execute(self, loss, data_reader, config, dev_count):
         if self.optimizer is not None:
@@ -186,10 +186,13 @@ class CombinedStrategy(DefaultStrategy):
                  learning_rate=1e-4,
                  scheduler=None,
                  regularization=None,
-                 clip=None):
+                 clip=None,
+                 **kwargs):
         super(CombinedStrategy, self).__init__(
-            optimizer_name=optimizer_name, learning_rate=learning_rate)
-
+            optimizer_name=optimizer_name,
+            learning_rate=learning_rate,
+            **kwargs)
+        self.kwargs = kwargs
         # init set
         self.scheduler = {
             "warmup": 0.0,
@@ -379,7 +382,9 @@ class CombinedStrategy(DefaultStrategy):
 
         # set optimizer
         super(CombinedStrategy, self).__init__(
-            optimizer_name=self._optimizer_name, learning_rate=scheduled_lr)
+            optimizer_name=self._optimizer_name,
+            learning_rate=scheduled_lr,
+            **self.kwargs)
 
         # discriminative learning rate
         # based on layer
@@ -568,7 +573,8 @@ class AdamWeightDecayStrategy(CombinedStrategy):
                  lr_scheduler="linear_decay",
                  warmup_proportion=0.1,
                  weight_decay=0.01,
-                 optimizer_name="adam"):
+                 optimizer_name="adam",
+                 **kwargs):
         scheduler = {"warmup": warmup_proportion}
         if lr_scheduler == "noam_decay":
             scheduler["noam_decay"] = True
@@ -587,14 +593,16 @@ class AdamWeightDecayStrategy(CombinedStrategy):
             learning_rate=learning_rate,
             scheduler=scheduler,
             regularization=regularization,
-            clip=clip)
+            clip=clip,
+            **kwargs)
 
 
 class L2SPFinetuneStrategy(CombinedStrategy):
     def __init__(self,
                  learning_rate=1e-4,
                  optimizer_name="adam",
-                 regularization_coeff=1e-3):
+                 regularization_coeff=1e-3,
+                 **kwargs):
         scheduler = {}
         regularization = {"L2SP": regularization_coeff}
         clip = {}
@@ -603,14 +611,16 @@ class L2SPFinetuneStrategy(CombinedStrategy):
             learning_rate=learning_rate,
             scheduler=scheduler,
             regularization=regularization,
-            clip=clip)
+            clip=clip,
+            **kwargs)
 
 
 class DefaultFinetuneStrategy(CombinedStrategy):
     def __init__(self,
                  learning_rate=1e-4,
                  optimizer_name="adam",
-                 regularization_coeff=1e-3):
+                 regularization_coeff=1e-3,
+                 **kwargs):
         scheduler = {}
         regularization = {"L2": regularization_coeff}
         clip = {}
@@ -620,7 +630,8 @@ class DefaultFinetuneStrategy(CombinedStrategy):
             learning_rate=learning_rate,
             scheduler=scheduler,
             regularization=regularization,
-            clip=clip)
+            clip=clip,
+            **kwargs)
 
 
 class ULMFiTStrategy(CombinedStrategy):
@@ -632,7 +643,8 @@ class ULMFiTStrategy(CombinedStrategy):
                  dis_blocks=3,
                  factor=2.6,
                  frz_blocks=3,
-                 params_layer=None):
+                 params_layer=None,
+                 **kwargs):
 
         scheduler = {
             "slanted_triangle": {
@@ -656,4 +668,5 @@ class ULMFiTStrategy(CombinedStrategy):
             learning_rate=learning_rate,
             scheduler=scheduler,
             regularization=regularization,
-            clip=clip)
+            clip=clip,
+            **kwargs)
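
A minimal usage sketch (not part of the diff itself), assuming PaddleHub 1.x with the patched strategies: extra keyword arguments passed to a strategy constructor are now forwarded through CombinedStrategy/DefaultStrategy to the underlying paddle.fluid optimizer. Parameter values below are illustrative only.

    import paddlehub as hub

    # Momentum optimizer with an explicit momentum value, as in the demo scripts above.
    strategy = hub.finetune.strategy.DefaultFinetuneStrategy(
        learning_rate=0.00025, optimizer_name="momentum", momentum=0.9)

    # The same mechanism forwards Adam-specific arguments such as beta1/beta2.
    adam_strategy = hub.finetune.strategy.AdamWeightDecayStrategy(
        learning_rate=5e-5, optimizer_name="adam", beta1=0.9, beta2=0.98)

Any keyword not consumed by the strategy itself is passed verbatim to the fluid optimizer constructor, so misspelled or unsupported arguments will surface as a TypeError from the optimizer rather than being silently ignored.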