Unverified commit c1a52d0e authored by Z Zeyu Chen, committed by GitHub

Merge pull request #22 from Steffy-zxf/add-optimizer

Add more optimizers
......@@ -33,7 +33,7 @@ class RunConfig(object):
use_cuda=False,
checkpoint_dir=None,
num_epoch=10,
batch_size=None,
batch_size=8,
enable_memory_optim=True,
strategy=None):
""" Construct finetune Config """
......
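This hunk changes the RunConfig default batch_size from None to 8. A minimal usage sketch (illustrative only, assuming RunConfig is exposed as hub.RunConfig and using a hypothetical checkpoint directory name):

    import paddlehub as hub

    # Rely on the new default: omitting batch_size now yields batch_size=8
    # instead of None.
    config = hub.RunConfig(
        use_cuda=False,
        num_epoch=10,
        checkpoint_dir="hub_finetune_ckpt")  # hypothetical directory name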
......@@ -47,6 +47,27 @@ class DefaultStrategy(object):
if self._optimizer_name.lower() == "sgd":
self.optimizer = fluid.optimizer.SGD(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "adagrad":
self.optimizer = fluid.optimizer.Adagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "adamax":
self.optimizer = fluid.optimizer.Adamax(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "decayedadagrad":
self.optimizer = fluid.optimizer.DecayedAdagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "ftrl":
self.optimizer = fluid.optimizer.Ftrl(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "larsmomentum":
self.optimizer = fluid.optimizer.LarsMomentum(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "momentum":
self.optimizer = fluid.optimizer.Momentum(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "decayedadagrad":
self.optimizer = fluid.optimizer.DecayedAdagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "rmsprop":
self.optimizer = fluid.optimizer.RMSPropOptimizer(
learning_rate=self.learning_rate)
else:
self.optimizer = fluid.optimizer.Adam(
learning_rate=self.learning_rate)
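The branch added to DefaultStrategy dispatches on a case-insensitive optimizer name and falls back to Adam when the name is not recognized. A hedged usage sketch, assuming the class lives in paddlehub.finetune.strategy and keeps its learning_rate/optimizer_name constructor arguments:

    from paddlehub.finetune.strategy import DefaultStrategy

    # "rmsprop" now resolves to fluid.optimizer.RMSPropOptimizer; a misspelled
    # name would silently fall back to fluid.optimizer.Adam.
    strategy = DefaultStrategy(learning_rate=1e-3, optimizer_name="rmsprop")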
......@@ -132,6 +153,27 @@ class DefaultFinetuneStrategy(DefaultStrategy):
if self._optimizer_name.lower() == "sgd":
self.optimizer = fluid.optimizer.SGD(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "adagrad":
self.optimizer = fluid.optimizer.Adagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "adamax":
self.optimizer = fluid.optimizer.Adamax(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "decayedadagrad":
self.optimizer = fluid.optimizer.DecayedAdagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "ftrl":
self.optimizer = fluid.optimizer.Ftrl(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "larsmomentum":
self.optimizer = fluid.optimizer.LarsMomentum(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "momentum":
self.optimizer = fluid.optimizer.Momentum(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "decayedadagrad":
self.optimizer = fluid.optimizer.DecayedAdagrad(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "rmsprop":
self.optimizer = fluid.optimizer.RMSPropOptimizer(
learning_rate=self.learning_rate)
else:
self.optimizer = fluid.optimizer.Adam(
learning_rate=self.learning_rate)
......
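DefaultFinetuneStrategy gains the same name-to-optimizer dispatch. Conceptually the elif chain is a lookup table; the sketch below (illustrative only, not part of this commit) expresses it that way with fluid optimizer classes, and passes an explicit momentum value because fluid's Momentum and LarsMomentum optimizers also take one:

    import paddle.fluid as fluid

    _OPTIMIZERS = {
        "sgd": fluid.optimizer.SGD,
        "adagrad": fluid.optimizer.Adagrad,
        "adamax": fluid.optimizer.Adamax,
        "decayedadagrad": fluid.optimizer.DecayedAdagrad,
        "ftrl": fluid.optimizer.Ftrl,
        "larsmomentum": fluid.optimizer.LarsMomentum,
        "momentum": fluid.optimizer.Momentum,
        "rmsprop": fluid.optimizer.RMSPropOptimizer,
    }

    def build_optimizer(name, learning_rate, momentum=0.9):
        # Unknown names fall back to Adam, matching the committed behaviour.
        name = name.lower()
        if name in ("momentum", "larsmomentum"):
            # These fluid optimizers take a momentum argument as well.
            return _OPTIMIZERS[name](learning_rate=learning_rate,
                                     momentum=momentum)
        cls = _OPTIMIZERS.get(name, fluid.optimizer.Adam)
        return cls(learning_rate=learning_rate)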