提交 90a90cf9 编写于 作者: Y Yang Zhang

Update to use new `Optimizer` interface

上级 7a64fd68
...@@ -118,8 +118,6 @@ if __name__ == '__main__': ...@@ -118,8 +118,6 @@ if __name__ == '__main__':
guard = null_guard() guard = null_guard()
with guard: with guard:
# sgd = SGDOptimizer(learning_rate=1e-3)
sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9)
train_loader = fluid.io.xmap_readers( train_loader = fluid.io.xmap_readers(
lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28), lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
np.array([x[1] for x in b]).reshape(-1, 1)], np.array([x[1] for x in b]).reshape(-1, 1)],
...@@ -131,6 +129,9 @@ if __name__ == '__main__': ...@@ -131,6 +129,9 @@ if __name__ == '__main__':
paddle.batch(paddle.dataset.mnist.test(), paddle.batch(paddle.dataset.mnist.test(),
batch_size=4, drop_last=True), 1, 1) batch_size=4, drop_last=True), 1, 1)
model = MNIST() model = MNIST()
sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9,
parameter_list=model.parameters())
# sgd = SGDOptimizer(learning_rate=1e-3)
model.prepare(sgd, 'cross_entropy') model.prepare(sgd, 'cross_entropy')
for e in range(2): for e in range(2):
......
...@@ -110,6 +110,9 @@ class StaticGraphAdapter(object): ...@@ -110,6 +110,9 @@ class StaticGraphAdapter(object):
self.mode = 'test' self.mode = 'test'
return self._run(inputs, None, device, device_ids) return self._run(inputs, None, device, device_ids)
def parameters(self, *args, **kwargs):
return None
def save(self, path): def save(self, path):
prog = self._progs.get('train', None) prog = self._progs.get('train', None)
if prog is None or self.model._optimizer is None: if prog is None or self.model._optimizer is None:
...@@ -287,6 +290,9 @@ class DynamicGraphAdapter(object): ...@@ -287,6 +290,9 @@ class DynamicGraphAdapter(object):
inputs = to_list(inputs) inputs = to_list(inputs)
return self.model.forward(*[to_variable(x) for x in inputs]) return self.model.forward(*[to_variable(x) for x in inputs])
def parameters(self, *args, **kwargs):
return super(Model, self.model).parameters(*args, **kwargs)
def save(self, path): def save(self, path):
params = self.model.state_dict() params = self.model.state_dict()
fluid.save_dygraph(params, path) fluid.save_dygraph(params, path)
...@@ -346,3 +352,6 @@ class Model(fluid.dygraph.Layer): ...@@ -346,3 +352,6 @@ class Model(fluid.dygraph.Layer):
def prepare(self, optimizer, loss_functions): def prepare(self, optimizer, loss_functions):
self._optimizer = optimizer self._optimizer = optimizer
self._loss_functions = to_list(loss_functions) self._loss_functions = to_list(loss_functions)
def parameters(self, *args, **kwargs):
return self._adapter.parameters(*args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册