diff --git a/02.recognize_digits/train.py b/02.recognize_digits/train.py index d9947e4eb0159fda5539333b83afd5d44d62305f..8b91432e4a64524ccf9b0aee102b8e0cc17f8110 100644 --- a/02.recognize_digits/train.py +++ b/02.recognize_digits/train.py @@ -62,6 +62,10 @@ def train_program(): return [avg_cost, acc] +def optimizer_program(): + return fluid.optimizer.Adam(learning_rate=0.001) + + def main(): train_reader = paddle.batch( paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500), @@ -71,10 +75,9 @@ def main(): use_cuda = os.getenv('WITH_GPU', '0') != '0' place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - optimizer = fluid.optimizer.Adam(learning_rate=0.001) trainer = fluid.Trainer( - train_func=train_program, place=place, optimizer=optimizer) + train_func=train_program, place=place, optimizer_func=optimizer_program) # Save the parameter into a directory. The Inferencer can load the parameters from it to do infer params_dirname = "recognize_digits_network.inference.model"