Commit d9f12c49 authored by pkuliuliu

support graph mode in DPOptimizer

Parent 425cc952
@@ -87,8 +87,7 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,

 if __name__ == "__main__":
-    # This configure just can run in pynative mode.
-    context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)
+    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
     network = LeNet5()
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
...
@@ -71,7 +71,7 @@ class DPOptimizerClassFactory:

     def __init__(self, micro_batches=2):
         self._mech_factory = NoiseMechanismsFactory()
-        self.mech = None
+        self._mech = None
         self._micro_batches = check_int_positive('micro_batches', micro_batches)

     def set_mechanisms(self, policy, *args, **kwargs):
@@ -81,9 +81,9 @@ class DPOptimizerClassFactory:
         Args:
             policy (str): Choose mechanism type.
         """
-        self.mech = self._mech_factory.create(policy, *args, **kwargs)
+        self._mech = self._mech_factory.create(policy, *args, **kwargs)

-    def create(self, policy, *args, **kwargs):
+    def create(self, policy):
         """
         Create DP optimizer.
@@ -93,25 +93,29 @@ class DPOptimizerClassFactory:
         Returns:
             Optimizer, A optimizer with DP.
         """
-        if policy == 'SGD':
-            cls = self._get_dp_optimizer_class(nn.SGD, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        if policy == 'Momentum':
-            cls = self._get_dp_optimizer_class(nn.Momentum, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        if policy == 'Adam':
-            cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        msg = "The {} is not implement, please choose ['SGD', 'Momentum', 'Adam']".format(policy)
-        LOGGER.error(TAG, msg)
-        raise NameError(msg)
+        dp_opt_class = None
+        policy_ = policy.lower()
+        if policy_ == 'sgd':
+            dp_opt_class = self._get_dp_optimizer_class(nn.SGD)
+        elif policy_ == 'momentum':
+            dp_opt_class = self._get_dp_optimizer_class(nn.Momentum)
+        elif policy_ == 'adam':
+            dp_opt_class = self._get_dp_optimizer_class(nn.Adam)
+        else:
+            msg = "The {} optimizer is not implement, please choose ['SGD', 'Momentum', 'Adam']" \
+                .format(policy)
+            LOGGER.error(TAG, msg)
+            raise NameError(msg)
+        return dp_opt_class

-    def _get_dp_optimizer_class(self, cls, mech, micro_batches):
+    def _get_dp_optimizer_class(self, opt_class):
         """
         Wrap original mindspore optimizer with `self._mech`.
         """
+        mech = self._mech
+        micro_batches = self._micro_batches

-        class DPOptimizer(cls):
+        class DPOptimizer(opt_class):
             """
             Initialize the DPOptimizerClass.
@@ -124,7 +128,7 @@ class DPOptimizerClassFactory:
                 self._mech = mech
                 self._tuple_add = _TupleAdd()
                 self._hyper_map = C.HyperMap()
-                self._micro_float = Tensor(micro_batches, mstype.float32)
+                self._micro_batches = Tensor(micro_batches, mstype.float32)
                 self._mech_param_updater = None
                 if self._mech is not None and self._mech._decay_policy is not None:
@@ -139,14 +143,20 @@ class DPOptimizerClassFactory:
                 """
                 construct a compute flow.
                 """
-                grad_noise = self._hyper_map(self._mech, gradients)
-                grads = self._tuple_add(gradients, grad_noise)
-                grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
+                # generate noise
+                grad_noise_tuple = ()
+                for grad_item in gradients:
+                    grad_noise = self._mech(grad_item)
+                    grad_noise_tuple = grad_noise_tuple + (grad_noise,)
+                # add noise
+                gradients = self._tuple_add(gradients, grad_noise_tuple)
+                # div by self._micro_batches
+                gradients = self._hyper_map(F.partial(_grad_scale, self._micro_batches), gradients)
                 # update mech parameters
                 if self._mech_param_updater is not None:
                     multiplier = self._mech_param_updater()
-                    grads = F.depend(grads, multiplier)
-                gradients = super(DPOptimizer, self).construct(grads)
+                    gradients = F.depend(gradients, multiplier)
+                gradients = super(DPOptimizer, self).construct(gradients)
                 return gradients

         return DPOptimizer
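With graph mode now supported, the wrapped optimizer is still obtained through DPOptimizerClassFactory: attach a noise mechanism, then call create() to get the class and instantiate it like any MindSpore optimizer. A minimal usage sketch follows; the 'Gaussian' mechanism name, the stand-in network, the hyper-parameter values and the import path are assumptions for illustration, not part of this commit.

import mindspore.nn as nn
from mindspore import context
from mindarmour.privacy.diff_privacy import DPOptimizerClassFactory

context.set_context(mode=context.GRAPH_MODE)  # graph mode is now allowed for the DP optimizer
network = nn.Dense(32, 10)                    # stand-in for the LeNet5 of the example script

factory = DPOptimizerClassFactory(micro_batches=2)
factory.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
# create() now takes only the policy name; optimizer arguments go to the returned class.
net_opt = factory.create('Momentum')(params=network.trainable_params(),
                                     learning_rate=0.01, momentum=0.9)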
@@ -142,10 +142,6 @@ class DPModel(Model):
             raise ValueError(msg)
         if noise_mech is None:
             if "DPOptimizer" in opt_name:
-                if context.get_context('mode') != context.PYNATIVE_MODE:
-                    msg = 'DPOptimizer just support pynative mode currently.'
-                    LOGGER.error(TAG, msg)
-                    raise ValueError(msg)
                 if 'Ada' in opt._mech.__class__.__name__ and clip_mech is not None:
                     msg = "When DPOptimizer's mech method is adaptive, clip_mech must be None."
                     LOGGER.error(TAG, msg)
...
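With the pynative-only guard removed above, a DPOptimizer built this way can also be handed to DPModel and trained under GRAPH_MODE. A sketch continuing the snippet above, with assumed hyper-parameter values; noise_mech stays None because the noise is already injected inside the DPOptimizer, and dataset and metric setup are omitted.

from mindarmour.privacy.diff_privacy import DPModel

net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
model = DPModel(micro_batches=2,
                norm_bound=1.0,
                noise_mech=None,   # noise is added inside the DPOptimizer
                clip_mech=None,
                network=network,
                loss_fn=net_loss,
                optimizer=net_opt)
# model.train(epochs, train_dataset) then runs like an ordinary mindspore Model.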