提交 1d66467d 编写于 作者: J jinyaohui

opt add ps logic

上级 8300802b
......@@ -71,7 +71,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))
update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay_tensor, param_fp32) + update
......@@ -110,26 +109,45 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
"Tensor", "Tensor", "Tensor")
"Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2):
moment1, moment2, ps_parameter):
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
success = True
success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient[1], gradient[0]))
if ps_parameter:
op_shape = P.Shape()
_ps_pull = P.Pull()
_ps_push = P.Push("Adam", [0, 1, 2])
shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0]))
success = F.depend(success, _ps_pull(_ps_push((beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient[1], gradient[0]), shapes), params))
else:
success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient[1], gradient[0]))
return success
@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor")
"Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2):
moment1, moment2, ps_parameter):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient))
if ps_parameter:
op_shape = P.Shape()
_ps_pull = P.Pull()
_ps_push = P.Push("Adam", [0, 1, 2])
success = F.depend(success, _ps_pull(_ps_push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
(op_shape(params), op_shape(moment1), op_shape(moment2))),
params))
else:
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient))
return success
@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tuple", "Tensor", "Tensor", "Tensor")
def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
......@@ -156,6 +174,7 @@ def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, bet
(op_shape(params), op_shape(moment1), op_shape(moment2))), params))
return success
class Adam(Optimizer):
r"""
Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
......@@ -293,13 +312,14 @@ class Adam(Optimizer):
if self.is_group_lr:
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2)
lr, gradients, params, moment1, moment2, self.ps_parameters)
else:
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2)
gradients, params, moment1, moment2, self.ps_parameters)
return success
class PSAdam(Optimizer):
'''The same usage as Adam optimizer except the parameters are set PS mode.'''
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
......@@ -346,6 +366,7 @@ class PSAdam(Optimizer):
gradients, params, moment1, moment2)
return success
class AdamWeightDecay(Optimizer):
"""
Implements Adam algorithm weight decay fix.
......
......@@ -26,22 +26,38 @@ _ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")
@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
"Tensor")
def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
"Tensor", "Bool")
def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment,
ps_parameter):
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
success = True
success = F.depend(success, spars_opt(weight, moment, linear, gradient[1], gradient[0]))
if ps_parameter:
op_shape = P.Shape()
_ps_pull = P.Pull()
_ps_push = P.Push("Ftrl", [0, 1, 2])
shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0]))
success = F.depend(success, _ps_pull(_ps_push((gradient[1], gradient[0]), shapes), weight))
else:
success = F.depend(success, spars_opt(weight, moment, linear, gradient[1], gradient[0]))
return success
@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
"Tensor")
def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
"Tensor", "Bool")
def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment, ps_parameter):
"""Apply ftrl optimizer to the weight parameter."""
success = True
success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
if ps_parameter:
op_shape = P.Shape()
_ps_pull = P.Pull()
_ps_push = P.Push("Ftrl", [0, 1, 2])
success = F.depend(success, _ps_pull(_ps_push((gradient, learning_rate, l1, l2, lr_power),
(op_shape(weight), op_shape(moment), op_shape(linear))), weight))
else:
success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
return success
@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple",
"Tensor", "Tensor")
def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
......@@ -63,6 +79,7 @@ def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2,
(op_shape(weight), op_shape(moment), op_shape(linear))), weight))
return success
def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None):
"""Check param."""
validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
......@@ -150,9 +167,10 @@ class FTRL(Optimizer):
grads = self.scale_grad(grads)
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
linear, grads, params, moments)
linear, grads, params, moments, self.ps_parameters)
return success
class PSFTRL(Optimizer):
def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
use_locking=False, loss_scale=1.0, weight_decay=0.0):
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""momentum"""
from mindspore.ops import functional as F, composite as C
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops import _selected_ops
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
......@@ -25,11 +25,18 @@ from .optimizer import Optimizer
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
if ps_parameter:
op_shape = P.Shape()
_ps_pull = P.Pull()
_ps_push = P.Push("Momentum", [])
shapes = (op_shape(learning_rate), op_shape(gradient), op_shape(momentum))
success = F.depend(success, _ps_pull(_ps_push((learning_rate, gradient, momentum), shapes), weight))
else:
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
return success
......@@ -127,7 +134,9 @@ class Momentum(Optimizer):
gradients = self.scale_grad(gradients)
lr = self.get_lr()
if self.is_group_lr:
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
self.ps_parameters)
else:
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments,
self.ps_parameters)
return success
......@@ -152,6 +152,8 @@ class Optimizer(Cell):
self.weight_decay = weight_decay * loss_scale
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
self.reciprocal_scale = 1.0 / loss_scale
self.exec_weight_decay = any(self.decay_flags)
self.param_length = len(self.parameters)
......
......@@ -511,6 +511,7 @@ class Push(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, optim_type='ApplyMomentum', only_shape_indices=None):
"""init Push"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key'])
def infer_shape(self, inputs, shapes):
......@@ -534,6 +535,7 @@ class Pull(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init Pull"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['key', 'weight'], outputs=['output'])
def infer_shape(self, key_shape, weight_shape):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册