未验证 提交 c1a52d0e 编写于 作者: Z Zeyu Chen 提交者: GitHub

Merge pull request #22 from Steffy-zxf/add-optimizer

Add more optimizer
...@@ -33,7 +33,7 @@ class RunConfig(object): ...@@ -33,7 +33,7 @@ class RunConfig(object):
use_cuda=False, use_cuda=False,
checkpoint_dir=None, checkpoint_dir=None,
num_epoch=10, num_epoch=10,
batch_size=None, batch_size=8,
enable_memory_optim=True, enable_memory_optim=True,
strategy=None): strategy=None):
""" Construct finetune Config """ """ Construct finetune Config """
......
...@@ -47,6 +47,30 @@ class DefaultStrategy(object): ...@@ -47,6 +47,30 @@ class DefaultStrategy(object):
if self._optimizer_name.lower() == "sgd": if self._optimizer_name.lower() == "sgd":
self.optimizer = fluid.optimizer.SGD( self.optimizer = fluid.optimizer.SGD(
learning_rate=self.learning_rate) learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "adagrad":
self.optimizer = fluid.optimizer.Adagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "adamax":
self.optimizer = fluid.optimizer.Adamax(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "decayedadagrad":
self.optimizer = fluid.optimizer.DecayedAdagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "ftrl":
self.optimizer = fluid.optimizer.Ftrl(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "larsmomentum":
self.optimizer = fluid.optimizer.LarsMomentum(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "momentum":
self.optimizer = fluid.optimizer.Momentum(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "decayedadagrad":
self.optimizer = fluid.optimizer.DecayedAdagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "rmsprop":
self.optimizer = fluid.optimizer.RMSPropOptimizer(
learning_rate=self.learning_rate)
else: else:
self.optimizer = fluid.optimizer.Adam( self.optimizer = fluid.optimizer.Adam(
learning_rate=self.learning_rate) learning_rate=self.learning_rate)
...@@ -132,6 +156,30 @@ class DefaultFinetuneStrategy(DefaultStrategy): ...@@ -132,6 +156,30 @@ class DefaultFinetuneStrategy(DefaultStrategy):
if self._optimizer_name.lower() == "sgd": if self._optimizer_name.lower() == "sgd":
self.optimizer = fluid.optimizer.SGD( self.optimizer = fluid.optimizer.SGD(
learning_rate=self.learning_rate) learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "adagrad":
self.optimizer = fluid.optimizer.Adagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "adamax":
self.optimizer = fluid.optimizer.Adamax(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "decayedadagrad":
self.optimizer = fluid.optimizer.DecayedAdagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "ftrl":
self.optimizer = fluid.optimizer.Ftrl(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "larsmomentum":
self.optimizer = fluid.optimizer.LarsMomentum(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "momentum":
self.optimizer = fluid.optimizer.Momentum(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "decayedadagrad":
self.optimizer = fluid.optimizer.DecayedAdagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "rmsprop":
self.optimizer = fluid.optimizer.RMSPropOptimizer(
learning_rate=self.learning_rate)
else: else:
self.optimizer = fluid.optimizer.Adam( self.optimizer = fluid.optimizer.Adam(
learning_rate=self.learning_rate) learning_rate=self.learning_rate)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册