diff --git a/mnist.py b/mnist.py
index 787499a7d708c442d55df2f899cbaa7f6581e93b..617fd18586b332c05e293b798d7b7dca13322080 100644
--- a/mnist.py
+++ b/mnist.py
@@ -118,8 +118,6 @@ if __name__ == '__main__':
     guard = null_guard()
 
     with guard:
-        # sgd = SGDOptimizer(learning_rate=1e-3)
-        sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9)
         train_loader = fluid.io.xmap_readers(
             lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
                        np.array([x[1] for x in b]).reshape(-1, 1)],
@@ -131,6 +129,9 @@ if __name__ == '__main__':
             paddle.batch(paddle.dataset.mnist.test(), batch_size=4, drop_last=True),
             1, 1)
         model = MNIST()
+        sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9,
+                                parameter_list=model.parameters())
+        # sgd = SGDOptimizer(learning_rate=1e-3)
         model.prepare(sgd, 'cross_entropy')
 
         for e in range(2):
diff --git a/model.py b/model.py
index 5cf0a647bf1592970b84238c1514933c218bd47c..58f5446fe35f5a901c30eb9e81c4d7775a913373 100644
--- a/model.py
+++ b/model.py
@@ -110,6 +110,9 @@ class StaticGraphAdapter(object):
         self.mode = 'test'
         return self._run(inputs, None, device, device_ids)
 
+    def parameters(self, *args, **kwargs):
+        return None
+
     def save(self, path):
         prog = self._progs.get('train', None)
         if prog is None or self.model._optimizer is None:
@@ -287,6 +290,9 @@ class DynamicGraphAdapter(object):
         inputs = to_list(inputs)
         return self.model.forward(*[to_variable(x) for x in inputs])
 
+    def parameters(self, *args, **kwargs):
+        return super(Model, self.model).parameters(*args, **kwargs)
+
     def save(self, path):
         params = self.model.state_dict()
         fluid.save_dygraph(params, path)
@@ -346,3 +352,6 @@ class Model(fluid.dygraph.Layer):
     def prepare(self, optimizer, loss_functions):
         self._optimizer = optimizer
         self._loss_functions = to_list(loss_functions)
+
+    def parameters(self, *args, **kwargs):
+        return self._adapter.parameters(*args, **kwargs)
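
For context, a minimal sketch of how the new delegation is meant to be exercised in dygraph mode. This is not part of the patch; it assumes PaddlePaddle 1.x fluid APIs and the MNIST/Model classes defined in this repo:

    # Illustrative only: assumes the MNIST class from this patch is importable
    # and that we are running under a dygraph guard.
    import paddle.fluid as fluid
    from paddle.fluid.optimizer import MomentumOptimizer

    with fluid.dygraph.guard():
        model = MNIST()
        # Model.parameters() forwards to the active adapter: in dygraph mode
        # DynamicGraphAdapter.parameters() calls the underlying Layer
        # implementation, so the optimizer can be constructed with an
        # explicit parameter_list; in static graph mode it returns None.
        sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9,
                                parameter_list=model.parameters())
        model.prepare(sgd, 'cross_entropy')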