diff --git a/example/mnist_demo/lenet5_dp_model_train.py b/example/mnist_demo/lenet5_dp_model_train.py
index 61a359a205a33f19fd4d5004499ec68c254d5447..089c23f0c3b7617cc1246093f0989111e129d5eb 100644
--- a/example/mnist_demo/lenet5_dp_model_train.py
+++ b/example/mnist_demo/lenet5_dp_model_train.py
@@ -92,15 +92,15 @@ if __name__ == "__main__":
     parser.add_argument('--data_path', type=str, default="./MNIST_unzip",
                         help='path where the dataset is saved')
     parser.add_argument('--dataset_sink_mode', type=bool, default=False,
                         help='dataset_sink_mode is False or True')
-    parser.add_argument('--micro_batches', type=float, default=None,
+    parser.add_argument('--micro_batches', type=int, default=None,
                         help='optional, if use differential privacy, need to set micro_batches')
-    parser.add_argument('--l2_norm_bound', type=float, default=1,
+    parser.add_argument('--l2_norm_bound', type=float, default=0.1,
                         help='optional, if use differential privacy, need to set l2_norm_bound')
     parser.add_argument('--initial_noise_multiplier', type=float, default=0.001,
                         help='optional, if use differential privacy, need to set initial_noise_multiplier')
     args = parser.parse_args()
 
-    context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target, enable_mem_reuse=False)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
     network = LeNet5()
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
diff --git a/example/mnist_demo/mnist_train.py b/example/mnist_demo/mnist_train.py
index 81daa8b8cde31984a81d9cd639c9fda7f780266d..7e595b4e57c19b31f8ee844638490f7fffe211ba 100644
--- a/example/mnist_demo/mnist_train.py
+++ b/example/mnist_demo/mnist_train.py
@@ -61,6 +61,5 @@ def mnist_train(epoch_size, batch_size, lr, momentum):
 
 
 if __name__ == '__main__':
-    context.set_context(mode=context.GRAPH_MODE, device_target="CPU",
-                        enable_mem_reuse=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
     mnist_train(10, 32, 0.01, 0.9)