Commit d9f12c49 authored by pkuliuliu

support graph mode in DPOptimizer

Parent 425cc952
@@ -87,8 +87,7 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
 if __name__ == "__main__":
-    # This configure just can run in pynative mode.
-    context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)
+    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
     network = LeNet5()
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
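The example script now builds its network under graph mode. A minimal sketch of how the execution mode is selected through mindspore.context (the device_target value below is illustrative):

    from mindspore import context

    # GRAPH_MODE compiles the network into a static graph before execution;
    # PYNATIVE_MODE runs operators eagerly and is easier to debug.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # To fall back to eager execution for debugging:
    # context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")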
@@ -71,7 +71,7 @@ class DPOptimizerClassFactory:
     def __init__(self, micro_batches=2):
         self._mech_factory = NoiseMechanismsFactory()
-        self.mech = None
+        self._mech = None
         self._micro_batches = check_int_positive('micro_batches', micro_batches)

     def set_mechanisms(self, policy, *args, **kwargs):
@@ -81,9 +81,9 @@ class DPOptimizerClassFactory:
         Args:
             policy (str): Choose mechanism type.
         """
-        self.mech = self._mech_factory.create(policy, *args, **kwargs)
+        self._mech = self._mech_factory.create(policy, *args, **kwargs)

-    def create(self, policy, *args, **kwargs):
+    def create(self, policy):
         """
         Create DP optimizer.
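set_mechanisms forwards its arguments to NoiseMechanismsFactory.create and keeps the resulting mechanism in self._mech for later wrapping. A rough usage sketch; the 'Gaussian' policy name and the keyword arguments are illustrative values passed straight through to the factory:

    factory = DPOptimizerClassFactory(micro_batches=2)
    # Illustrative noise mechanism; the arguments go to NoiseMechanismsFactory.create.
    factory.set_mechanisms('Gaussian',
                           norm_bound=1.0,
                           initial_noise_multiplier=1.5)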
@@ -93,25 +93,29 @@ class DPOptimizerClassFactory:
         Returns:
             Optimizer, A optimizer with DP.
         """
-        if policy == 'SGD':
-            cls = self._get_dp_optimizer_class(nn.SGD, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        if policy == 'Momentum':
-            cls = self._get_dp_optimizer_class(nn.Momentum, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        if policy == 'Adam':
-            cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        msg = "The {} is not implement, please choose ['SGD', 'Momentum', 'Adam']".format(policy)
-        LOGGER.error(TAG, msg)
-        raise NameError(msg)
-
-    def _get_dp_optimizer_class(self, cls, mech, micro_batches):
+        dp_opt_class = None
+        policy_ = policy.lower()
+        if policy_ == 'sgd':
+            dp_opt_class = self._get_dp_optimizer_class(nn.SGD)
+        elif policy_ == 'momentum':
+            dp_opt_class = self._get_dp_optimizer_class(nn.Momentum)
+        elif policy_ == 'adam':
+            dp_opt_class = self._get_dp_optimizer_class(nn.Adam)
+        else:
+            msg = "The {} optimizer is not implement, please choose ['SGD', 'Momentum', 'Adam']" \
+                .format(policy)
+            LOGGER.error(TAG, msg)
+            raise NameError(msg)
+        return dp_opt_class
+
+    def _get_dp_optimizer_class(self, opt_class):
+        """
+        Wrap original mindspore optimizer with `self._mech`.
+        """
+        mech = self._mech
+        micro_batches = self._micro_batches

-        class DPOptimizer(cls):
+        class DPOptimizer(opt_class):
             """
             Initialize the DPOptimizerClass.
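create now dispatches on the lower-cased policy name and returns the wrapped optimizer class, so callers instantiate it with the usual MindSpore optimizer arguments instead of passing them through the factory. Continuing the sketch above (learning rate and momentum values are illustrative):

    # Obtain a DP-wrapped Momentum class; 'momentum' or 'MOMENTUM' now work as well.
    DPMomentum = factory.create('Momentum')
    dp_opt = DPMomentum(params=network.trainable_params(),
                        learning_rate=0.01,
                        momentum=0.9)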
@@ -124,7 +128,7 @@ class DPOptimizerClassFactory:
                 self._mech = mech
                 self._tuple_add = _TupleAdd()
                 self._hyper_map = C.HyperMap()
-                self._micro_float = Tensor(micro_batches, mstype.float32)
+                self._micro_batches = Tensor(micro_batches, mstype.float32)
                 self._mech_param_updater = None
                 if self._mech is not None and self._mech._decay_policy is not None:
@@ -139,14 +143,20 @@ class DPOptimizerClassFactory:
                 """
                 construct a compute flow.
                 """
-                grad_noise = self._hyper_map(self._mech, gradients)
-                grads = self._tuple_add(gradients, grad_noise)
-                grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
+                # generate noise
+                grad_noise_tuple = ()
+                for grad_item in gradients:
+                    grad_noise = self._mech(grad_item)
+                    grad_noise_tuple = grad_noise_tuple + (grad_noise,)
+                # add noise
+                gradients = self._tuple_add(gradients, grad_noise_tuple)
+                # div by self._micro_batches
+                gradients = self._hyper_map(F.partial(_grad_scale, self._micro_batches), gradients)
+                # update mech parameters
                 if self._mech_param_updater is not None:
                     multiplier = self._mech_param_updater()
-                    grads = F.depend(grads, multiplier)
-                gradients = super(DPOptimizer, self).construct(grads)
+                    gradients = F.depend(gradients, multiplier)
+                gradients = super(DPOptimizer, self).construct(gradients)
                 return gradients

         return DPOptimizer
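The rewritten construct replaces the HyperMap call over the stateful noise mechanism with a plain Python loop over the gradient tuple, which MindSpore unrolls when compiling the graph; presumably this is what makes the optimizer usable in graph mode, as the commit title says. Conceptually, each gradient tensor receives its own noise, the noisy gradients are divided by the number of micro-batches, and only then is the wrapped optimizer's own construct invoked. A framework-free sketch of that flow, using NumPy purely for illustration:

    import numpy as np

    def dp_scale_gradients(gradients, noise_fn, micro_batches):
        """Illustrative only: add noise per gradient, then scale by 1/micro_batches."""
        noisy = tuple(g + noise_fn(g) for g in gradients)   # generate and add noise
        return tuple(g / micro_batches for g in noisy)      # div by micro_batches

    # Toy example: Gaussian noise with a fixed standard deviation.
    grads = (np.ones(3), np.zeros(2))
    noise_fn = lambda g: np.random.normal(0.0, 0.1, size=g.shape)
    averaged = dp_scale_gradients(grads, noise_fn, micro_batches=2)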
@@ -142,10 +142,6 @@ class DPModel(Model):
             raise ValueError(msg)
         if noise_mech is None:
             if "DPOptimizer" in opt_name:
-                if context.get_context('mode') != context.PYNATIVE_MODE:
-                    msg = 'DPOptimizer just support pynative mode currently.'
-                    LOGGER.error(TAG, msg)
-                    raise ValueError(msg)
                 if 'Ada' in opt._mech.__class__.__name__ and clip_mech is not None:
                     msg = "When DPOptimizer's mech method is adaptive, clip_mech must be None."
                     LOGGER.error(TAG, msg)
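With the graph-mode guard removed from DPModel, a DPOptimizer-based model is no longer restricted to pynative execution. A rough sketch of the intended end-to-end usage, assuming the DPModel constructor takes micro_batches, norm_bound, noise_mech and clip_mech plus the usual Model keyword arguments, as in the MindArmour examples (network, net_loss and dp_opt come from the sketches above; ds_train is a hypothetical MNIST dataset):

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    model = DPModel(micro_batches=2,
                    norm_bound=1.0,
                    noise_mech=None,   # noise is injected by the DP optimizer itself
                    clip_mech=None,
                    network=network,
                    loss_fn=net_loss,
                    optimizer=dp_opt)
    model.train(10, ds_train)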