提交 be377fb1 编写于 作者: Z zhenghuanhuan

1IKCU

fix [MA][diff_privacy][Doc] the tutorials of diff_privacy has problem
上级 30f6c260
...@@ -92,15 +92,15 @@ if __name__ == "__main__": ...@@ -92,15 +92,15 @@ if __name__ == "__main__":
parser.add_argument('--data_path', type=str, default="./MNIST_unzip", parser.add_argument('--data_path', type=str, default="./MNIST_unzip",
help='path where the dataset is saved') help='path where the dataset is saved')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
parser.add_argument('--micro_batches', type=float, default=None, parser.add_argument('--micro_batches', type=int, default=None,
help='optional, if use differential privacy, need to set micro_batches') help='optional, if use differential privacy, need to set micro_batches')
parser.add_argument('--l2_norm_bound', type=float, default=1, parser.add_argument('--l2_norm_bound', type=float, default=0.1,
help='optional, if use differential privacy, need to set l2_norm_bound') help='optional, if use differential privacy, need to set l2_norm_bound')
parser.add_argument('--initial_noise_multiplier', type=float, default=0.001, parser.add_argument('--initial_noise_multiplier', type=float, default=0.001,
help='optional, if use differential privacy, need to set initial_noise_multiplier') help='optional, if use differential privacy, need to set initial_noise_multiplier')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target, enable_mem_reuse=False) context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
network = LeNet5() network = LeNet5()
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
...@@ -61,6 +61,5 @@ def mnist_train(epoch_size, batch_size, lr, momentum): ...@@ -61,6 +61,5 @@ def mnist_train(epoch_size, batch_size, lr, momentum):
if __name__ == '__main__': if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="CPU", context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
enable_mem_reuse=False)
mnist_train(10, 32, 0.01, 0.9) mnist_train(10, 32, 0.01, 0.9)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册