提交 c65c49eb 编写于 作者: W wuzewu 提交者: zhangxuefei

update strategy, use adam as default optimizer

上级 1510df97
...@@ -44,12 +44,12 @@ class DefaultStrategy(object): ...@@ -44,12 +44,12 @@ class DefaultStrategy(object):
self._optimizer_name = optimizer_name self._optimizer_name = optimizer_name
def execute(self, loss): def execute(self, loss):
if self.optimizer.lower() == "adam": if self._optimizer_name.lower() == "sgd":
self.optimizer = fluid.optimizer.Adam(
learning_rate=self.learning_rate)
elif self.optimizer.lower() == "sgd":
self.optimizer = fluid.optimizer.SGD( self.optimizer = fluid.optimizer.SGD(
learning_rate=self.learning_rate) learning_rate=self.learning_rate)
else:
self.optimizer = fluid.optimizer.Adam(
learning_rate=self.learning_rate)
if self.optimizer is not None: if self.optimizer is not None:
self.optimizer.minimize(loss) self.optimizer.minimize(loss)
...@@ -129,12 +129,12 @@ class DefaultFinetuneStrategy(DefaultStrategy): ...@@ -129,12 +129,12 @@ class DefaultFinetuneStrategy(DefaultStrategy):
self.regularization_coeff = regularization_coeff self.regularization_coeff = regularization_coeff
def execute(self, loss): def execute(self, loss):
if self._optimizer_name.lower() == "adam": if self._optimizer_name.lower() == "sgd":
self.optimizer = fluid.optimizer.Adam(
learning_rate=self.learning_rate)
elif self._optimizer_name.lower() == "sgd":
self.optimizer = fluid.optimizer.SGD( self.optimizer = fluid.optimizer.SGD(
learning_rate=self.learning_rate) learning_rate=self.learning_rate)
else:
self.optimizer = fluid.optimizer.Adam(
learning_rate=self.learning_rate)
# get pretrained parameters # get pretrained parameters
program = loss.block.program program = loss.block.program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册