Commit 78e55ba0 authored by pkuliuliu

Fix Example Error with Wrong Policy of Noise Mech

Parent: cb2825e3
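In short: the examples paired the fixed 'Gaussian' noise mechanism with decay_policy='Exp', a setting that only makes sense for an adaptive mechanism such as 'AdaGaussian'. The commit passes decay_policy=None for 'Gaussian', retunes the example hyperparameters (epoch_size, micro_batches, initial_noise_multiplier), updates the checkpoint names that follow from the shorter run, and makes _TrainOneStepCell skip the adaptive-clipping bookkeeping when no clip mechanism is attached. A hedged sketch of the mismatch (a hypothetical validator written for illustration, not MindArmour's actual code; the policy sets are assumptions):

    # Hypothetical check illustrating the commit title; policy sets assumed.
    VALID_POLICIES = {'Gaussian': {None}, 'AdaGaussian': {'Time', 'Step', 'Exp'}}

    def check_policy(mechanism, decay_policy):
        if decay_policy not in VALID_POLICIES.get(mechanism, {None}):
            raise ValueError('%s does not accept decay_policy=%r'
                             % (mechanism, decay_policy))

    check_policy('Gaussian', None)    # the corrected example: fine
    check_policy('Gaussian', 'Exp')   # the old example: raises ValueError

First, the 'AdaGaussian' example config: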
@@ -22,7 +22,7 @@ mnist_cfg = edict({
     'num_classes': 10,  # the number of classes of model's output
     'lr': 0.01,  # the learning rate of model's optimizer
     'momentum': 0.9,  # the momentum value of model's optimizer
-    'epoch_size': 10,  # training epochs
+    'epoch_size': 5,  # training epochs
     'batch_size': 256,  # batch size for training
     'image_height': 32,  # the height of training samples
     'image_width': 32,  # the width of training samples
@@ -31,9 +31,9 @@ mnist_cfg = edict({
     'device_target': 'Ascend',  # device used
     'data_path': './MNIST_unzip',  # the path of training and testing data set
     'dataset_sink_mode': False,  # whether deliver all training data to device one time
-    'micro_batches': 16,  # the number of small batches split from an original batch
+    'micro_batches': 32,  # the number of small batches split from an original batch
     'norm_bound': 1.0,  # the clip bound of the gradients of model's training parameters
-    'initial_noise_multiplier': 1.0,  # the initial multiplication coefficient of the noise added to training
+    'initial_noise_multiplier': 0.05,  # the initial multiplication coefficient of the noise added to training
                                        # parameters' gradients
     'noise_mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
     'optimizer': 'Momentum'  # the base optimizer used for Differential privacy training
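Two of these values interact. batch_size must split evenly into micro_batches (256 / 32 = 8 samples per micro-batch, versus 16 before), and in DP-SGD the standard deviation of the added Gaussian noise is the noise multiplier times the clipping bound, so dropping the multiplier from 1.0 to 0.05 makes the noise far gentler relative to norm_bound = 1.0. A quick sanity check in plain Python (numpy used for illustration only):

    import numpy as np

    batch_size, micro_batches = 256, 32
    assert batch_size % micro_batches == 0           # 8 samples per micro-batch

    norm_bound, initial_noise_multiplier = 1.0, 0.05
    # DP-SGD-style noise on a clipped gradient: stddev = multiplier * bound.
    noise = np.random.normal(0.0, initial_noise_multiplier * norm_bound,
                             size=(10,))             # stddev 0.05, was 1.0

The second config, which pairs fixed 'Gaussian' noise with an adaptive clip mechanism, gets the same hyperparameter changes: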
@@ -22,7 +22,7 @@ mnist_cfg = edict({
     'num_classes': 10,  # the number of classes of model's output
     'lr': 0.01,  # the learning rate of model's optimizer
     'momentum': 0.9,  # the momentum value of model's optimizer
-    'epoch_size': 10,  # training epochs
+    'epoch_size': 5,  # training epochs
     'batch_size': 256,  # batch size for training
     'image_height': 32,  # the height of training samples
     'image_width': 32,  # the width of training samples
@@ -31,9 +31,9 @@ mnist_cfg = edict({
     'device_target': 'Ascend',  # device used
     'data_path': './MNIST_unzip',  # the path of training and testing data set
     'dataset_sink_mode': False,  # whether deliver all training data to device one time
-    'micro_batches': 16,  # the number of small batches split from an original batch
+    'micro_batches': 32,  # the number of small batches split from an original batch
     'norm_bound': 1.0,  # the clip bound of the gradients of model's training parameters
-    'initial_noise_multiplier': 1.0,  # the initial multiplication coefficient of the noise added to training
+    'initial_noise_multiplier': 0.05,  # the initial multiplication coefficient of the noise added to training
                                        # parameters' gradients
     'noise_mechanisms': 'Gaussian',  # the method of adding noise in gradients while training
     'clip_mechanisms': 'Gaussian',  # the method of adaptive clipping gradients while training
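In this variant norm_bound is not static: the clip mechanism updates it each step from beta, the observed fraction of micro-batches whose gradient norm already falls under the bound (beta is computed in _TrainOneStepCell; see the last part of this diff). A sketch of a 'Linear' update in the spirit of adaptive clipping; the parameter names and defaults are illustrative, not MindArmour's exact API:

    def ada_clip_linear(beta, norm_bound, learning_rate=0.001,
                        target_unclipped_quantile=0.9):
        # Too few micro-batches under the bound (beta low): loosen it;
        # too many (beta high): tighten it toward the target quantile.
        return norm_bound - learning_rate * (beta - target_unclipped_quantile)

The evaluation sections of two example scripts then only need their checkpoint names updated to match the shorter run: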
@@ -155,7 +155,7 @@ if __name__ == "__main__":
                 dataset_sink_mode=cfg.dataset_sink_mode)
     LOGGER.info(TAG, "============== Starting Testing ==============")
-    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
+    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-5_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
     ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
@@ -141,7 +141,7 @@ if __name__ == "__main__":
                 dataset_sink_mode=cfg.dataset_sink_mode)
     LOGGER.info(TAG, "============== Starting Testing ==============")
-    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
+    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-5_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
     ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
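The renames follow mechanically from the epoch change: MindSpore's ModelCheckpoint names files <prefix>-<epoch>_<step>.ckpt, and MNIST's 60000 training images at batch size 256 give 234 full steps per epoch:

    steps_per_epoch = 60000 // 256                   # = 234
    epoch_size = 5                                   # was 10
    print('checkpoint_lenet-%d_%d.ckpt' % (epoch_size, steps_per_epoch))
    # -> checkpoint_lenet-5_234.ckpt

The third script additionally carries the policy fix named in the commit title: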
@@ -111,7 +111,7 @@ if __name__ == "__main__":
     dp_opt.set_mechanisms(cfg.noise_mechanisms,
                           norm_bound=cfg.norm_bound,
                           initial_noise_multiplier=cfg.initial_noise_multiplier,
-                          decay_policy='Exp')
+                          decay_policy=None)
     # Create a factory class of clip mechanisms, this method is to adaptive clip
     # gradients while training, decay_policy support 'Linear' and 'Geometric',
     # learning_rate is the learning rate to update clip_norm,
@@ -147,7 +147,7 @@ if __name__ == "__main__":
                 dataset_sink_mode=cfg.dataset_sink_mode)
     LOGGER.info(TAG, "============== Starting Testing ==============")
-    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
+    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-5_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
     ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size)
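The distinction the old call missed: 'Gaussian' adds noise of constant variance and has no schedule to decay, while 'AdaGaussian' shrinks its multiplier step by step under a decay policy. A conceptual sketch of the difference (the decay rate and the exact schedule formula are assumptions for illustration):

    import numpy as np

    def noise_stddev(step, mechanism, initial_noise_multiplier=0.05,
                     norm_bound=1.0, decay_policy=None, decay_rate=6e-4):
        mult = initial_noise_multiplier
        if mechanism == 'AdaGaussian' and decay_policy == 'Exp':
            mult *= np.exp(-decay_rate * step)   # noise fades as training settles
        # 'Gaussian': mult stays fixed, so decay_policy must be None.
        return mult * norm_bound

Finally, the substantive change in MindArmour itself: _TrainOneStepCell now computes beta only when a clip mechanism is actually attached, and averages it inside the same guarded branch: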
@@ -656,14 +656,16 @@ class _TrainOneStepCell(Cell):
         record_grad = self.grad(self.network, weights)(record_datas[0],
                                                        record_labels[0], sens)
         beta = self._zero
-        square_sum = self._zero
-        for grad in record_grad:
-            square_sum = self._add(square_sum,
-                                   self._reduce_sum(self._square_all(grad)))
-        norm_grad = self._sqrt(square_sum)
-        beta = self._add(beta,
-                         self._cast(self._less(norm_grad, self._norm_bound),
-                                    mstype.float32))
+        # calcu beta
+        if self._clip_mech is not None:
+            square_sum = self._zero
+            for grad in record_grad:
+                square_sum = self._add(square_sum,
+                                       self._reduce_sum(self._square_all(grad)))
+            norm_grad = self._sqrt(square_sum)
+            beta = self._add(beta,
+                             self._cast(self._less(norm_grad, self._norm_bound),
+                                        mstype.float32))
         record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
                                                 self._norm_bound)
@@ -675,14 +677,16 @@ class _TrainOneStepCell(Cell):
             record_grad = self.grad(self.network, weights)(record_datas[i],
                                                            record_labels[i],
                                                            sens)
-            square_sum = self._zero
-            for grad in record_grad:
-                square_sum = self._add(square_sum,
-                                       self._reduce_sum(self._square_all(grad)))
-            norm_grad = self._sqrt(square_sum)
-            beta = self._add(beta,
-                             self._cast(self._less(norm_grad, self._norm_bound),
-                                        mstype.float32))
+            # calcu beta
+            if self._clip_mech is not None:
+                square_sum = self._zero
+                for grad in record_grad:
+                    square_sum = self._add(square_sum,
+                                           self._reduce_sum(self._square_all(grad)))
+                norm_grad = self._sqrt(square_sum)
+                beta = self._add(beta,
+                                 self._cast(self._less(norm_grad, self._norm_bound),
+                                            mstype.float32))
             record_grad = self._clip_by_global_norm(record_grad,
                                                     GRADIENT_CLIP_TYPE,
@@ -690,7 +694,6 @@ class _TrainOneStepCell(Cell):
             grads = self._tuple_add(grads, record_grad)
             total_loss = P.TensorAdd()(total_loss, loss)
         loss = self._div(total_loss, self._micro_float)
-        beta = self._div(beta, self._micro_batches)
         if self._noise_mech is not None:
             grad_noise_tuple = ()
@@ -710,8 +713,9 @@ class _TrainOneStepCell(Cell):
         grads = self.grad_reducer(grads)
         if self._clip_mech is not None:
+            beta = self._div(beta, self._micro_batches)
             next_norm_bound = self._clip_mech(beta, self._norm_bound)
             self._norm_bound = self._assign(self._norm_bound, next_norm_bound)
-            loss = F.depend(loss, next_norm_bound)
+            loss = F.depend(loss, self._norm_bound)
         return F.depend(loss, self.optimizer(grads))
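Read together, the four hunks say: beta counts, per training step, how many micro-batches have a global gradient norm below norm_bound; its division by micro_batches now lives inside the same if self._clip_mech guard that consumes it; and loss is made to depend on the updated self._norm_bound rather than the pre-assignment next_norm_bound. The equivalent computation in plain numpy (an illustrative restatement, not MindSpore graph code):

    import numpy as np

    def average_beta(micro_grads, norm_bound):
        # micro_grads: one list of gradient arrays per micro-batch.
        beta = 0.0
        for grads in micro_grads:
            square_sum = sum(float(np.sum(g * g)) for g in grads)  # _square_all + _reduce_sum
            beta += float(np.sqrt(square_sum) < norm_bound)        # _less + _cast
        return beta / len(micro_grads)                             # the relocated division

Skipping this entirely when self._clip_mech is None saves one gradient-norm computation per micro-batch per step in the noise-only configurations.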