diff --git a/examples/handwritten_number_recognition/mnist.py b/examples/handwritten_number_recognition/mnist.py index 340784920696e9b808ef2876d486e9bc4b47acb9..36db96ca3715e877c443480bbda598a7545dc1b7 100644 --- a/examples/handwritten_number_recognition/mnist.py +++ b/examples/handwritten_number_recognition/mnist.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -60,7 +60,7 @@ def main(): val_dataset, epochs=FLAGS.epoch, batch_size=FLAGS.batch_size, - save_dir='mnist_checkpoint') + save_dir=FLAGS.output_dir) if __name__ == '__main__': @@ -81,7 +81,10 @@ if __name__ == '__main__': parser.add_argument( "-b", "--batch_size", default=128, type=int, help="batch size") parser.add_argument( - "--output-dir", type=str, default='output', help="checkpoint save dir") + "--output-dir", + type=str, + default='mnist_checkpoint', + help="checkpoint save dir") parser.add_argument( "-r", "--resume",