diff --git a/example/mnist_demo/lenet5_config.py b/example/mnist_demo/lenet5_config.py index f1a27456d47d6f5a8bacacb23d0e14d21a59b2b8..896d7d0fc31686d70f860b96edcacc9739e7a97c 100644 --- a/example/mnist_demo/lenet5_config.py +++ b/example/mnist_demo/lenet5_config.py @@ -22,7 +22,7 @@ mnist_cfg = edict({ 'num_classes': 10, # the number of classes of model's output 'lr': 0.01, # the learning rate of model's optimizer 'momentum': 0.9, # the momentum value of model's optimizer - 'epoch_size': 5, # training epochs + 'epoch_size': 10, # training epochs 'batch_size': 256, # batch size for training 'image_height': 32, # the height of training samples 'image_width': 32, # the width of training samples diff --git a/example/mnist_demo/lenet5_dp.py b/example/mnist_demo/lenet5_dp.py index 6468cd38b1e2c47ca72dbb05f0ef31bcaf5ffba7..65aa63ce4b73a0fd0c7dc254963880d502ab17d0 100644 --- a/example/mnist_demo/lenet5_dp.py +++ b/example/mnist_demo/lenet5_dp.py @@ -155,7 +155,7 @@ if __name__ == "__main__": dataset_sink_mode=cfg.dataset_sink_mode) LOGGER.info(TAG, "============== Starting Testing ==============") - ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-5_234.ckpt' + ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt' param_dict = load_checkpoint(ckpt_file_name) load_param_into_net(network, param_dict) ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),