From 90a90cf993ad24c866f9b84bcb10485644c4cdbd Mon Sep 17 00:00:00 2001
From: Yang Zhang
Date: Tue, 31 Dec 2019 14:09:56 +0800
Subject: [PATCH] Update to use new `Optimizer` interface

---
 mnist.py | 5 +++--
 model.py | 9 +++++++++
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/mnist.py b/mnist.py
index 787499a..617fd18 100644
--- a/mnist.py
+++ b/mnist.py
@@ -118,8 +118,6 @@ if __name__ == '__main__':
         guard = null_guard()
 
     with guard:
-        # sgd = SGDOptimizer(learning_rate=1e-3)
-        sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9)
         train_loader = fluid.io.xmap_readers(
             lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
                        np.array([x[1] for x in b]).reshape(-1, 1)],
@@ -131,6 +129,9 @@ if __name__ == '__main__':
             paddle.batch(paddle.dataset.mnist.test(),
                          batch_size=4, drop_last=True), 1, 1)
         model = MNIST()
+        sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9,
+                                parameter_list=model.parameters())
+        # sgd = SGDOptimizer(learning_rate=1e-3)
         model.prepare(sgd, 'cross_entropy')
 
         for e in range(2):
diff --git a/model.py b/model.py
index 5cf0a64..58f5446 100644
--- a/model.py
+++ b/model.py
@@ -110,6 +110,9 @@ class StaticGraphAdapter(object):
         self.mode = 'test'
         return self._run(inputs, None, device, device_ids)
 
+    def parameters(self, *args, **kwargs):
+        return None
+
     def save(self, path):
         prog = self._progs.get('train', None)
         if prog is None or self.model._optimizer is None:
@@ -287,6 +290,9 @@ class DynamicGraphAdapter(object):
         inputs = to_list(inputs)
         return self.model.forward(*[to_variable(x) for x in inputs])
 
+    def parameters(self, *args, **kwargs):
+        return super(Model, self.model).parameters(*args, **kwargs)
+
     def save(self, path):
         params = self.model.state_dict()
         fluid.save_dygraph(params, path)
@@ -346,3 +352,6 @@ class Model(fluid.dygraph.Layer):
     def prepare(self, optimizer, loss_functions):
         self._optimizer = optimizer
         self._loss_functions = to_list(loss_functions)
+
+    def parameters(self, *args, **kwargs):
+        return self._adapter.parameters(*args, **kwargs)
--
GitLab
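
A minimal usage sketch of the interface change above, assuming PaddlePaddle 1.x
dygraph mode and the MNIST `Model` subclass defined in mnist.py (names taken
from the patch, not a definitive API reference). The key point is the new
construction order: dygraph optimizers now take the trainable parameters via
`parameter_list` at construction time, so the optimizer must be built after
the model exists.

    import paddle.fluid as fluid
    from paddle.fluid.optimizer import MomentumOptimizer

    from mnist import MNIST  # the Model subclass defined in mnist.py

    with fluid.dygraph.guard():
        model = MNIST()
        # New interface: pass the model's parameters when constructing the
        # optimizer instead of relying on global parameter discovery.
        sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9,
                                parameter_list=model.parameters())
        model.prepare(sgd, 'cross_entropy')

This is also why the patch adds `Model.parameters()` and delegates it to the
adapters: in dygraph mode it forwards to the underlying `Layer.parameters()`,
while the static-graph adapter returns None, since static-graph optimizers do
not take a parameter list.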