Commit 10053f3d authored by mindspore-ci-bot, committed by Gitee

!3098 IndexedSlices adapter of sparse optimizer

Merge pull request !3098 from wangnan39/sparse_optimizer_adapter_indexedslice
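The changes below replace the plain (indices, values, dense_shape) tuple that previously carried sparse gradients with the IndexedSlices class, across the sparse Adam, FTRL, LazyAdam and ProximalAdagrad optimizers, weight decay and gradient scaling, the distributed gradient reducer, and the related bprops. For orientation, here is a minimal pure-Python sketch of the access-pattern change; the class below is an illustrative stand-in, not MindSpore's actual mindspore.common.tensor.IndexedSlices.

import numpy as np

class IndexedSlicesSketch:
    """Stand-in for the IndexedSlices container introduced by this PR."""
    def __init__(self, indices, values, dense_shape):
        self._indices, self._values, self._dense_shape = indices, values, dense_shape
    def indices(self):
        return self._indices
    def values(self):
        return self._values
    def dense_shape(self):
        return self._dense_shape

def sparse_opt_args_old(grad_tuple):
    # Before: callers indexed the tuple, passing values then indices to the kernel.
    return grad_tuple[1], grad_tuple[0]

def sparse_opt_args_new(grad):
    # After: callers go through the IndexedSlices accessors instead.
    return grad.values(), grad.indices()

grad = IndexedSlicesSketch(np.array([0, 2]), np.ones((2, 3), np.float32), (4, 3))
assert sparse_opt_args_new(grad)[1].tolist() == [0, 2]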
@@ -108,24 +108,26 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
    validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
-@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
+@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "IndexedSlices",
                    "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                         moment1, moment2, ps_parameter):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
+    indices = gradient.indices()
+    values = gradient.values()
    if ps_parameter:
        op_shape = P.Shape()
        _ps_pull = P.Pull()
        _ps_push = P.Push("Adam", [0, 1, 2])
        shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
-                  op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0]))
+                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, _ps_pull(_ps_push((beta1_power, beta2_power, lr, beta1, beta2,
-                                                       eps, gradient[1], gradient[0]), shapes), params))
+                                                       eps, values, indices), shapes), params))
    else:
        success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                               eps, gradient[1], gradient[0]))
+                                               eps, values, indices))
    return success
@@ -149,17 +151,19 @@ def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, b
@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                              "Tensor", "Tuple", "Tensor", "Tensor", "Tensor")
+                              "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor")
def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                                   moment1, moment2):
    """Apply sparse adam optimizer by push and pull to the weight parameter when the gradient is sparse."""
    success = True
    op_shape = P.Shape()
+    values = gradient.values()
+    indices = gradient.indices()
    shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
              op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
-              op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0]))
+              op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
    success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
-                                           eps, gradient[1], gradient[0]), shapes), params))
+                                           eps, values, indices), shapes), params))
    return success
...
@@ -25,20 +25,22 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
_ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")
-@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
+@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices", "Tensor",
                    "Tensor", "Bool")
def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment,
                                ps_parameter):
    """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
    success = True
+    indices = gradient.indices()
+    values = gradient.values()
    if ps_parameter:
        op_shape = P.Shape()
        _ps_pull = P.Pull()
        _ps_push = P.Push("Ftrl", [0, 1, 2])
-        shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0]))
-        success = F.depend(success, _ps_pull(_ps_push((gradient[1], gradient[0]), shapes), weight))
+        shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
+        success = F.depend(success, _ps_pull(_ps_push((values, indices), shapes), weight))
    else:
-        success = F.depend(success, spars_opt(weight, moment, linear, gradient[1], gradient[0]))
+        success = F.depend(success, spars_opt(weight, moment, linear, values, indices))
    return success
@@ -58,14 +60,16 @@ def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gra
    return success
-@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple",
+@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices",
                              "Tensor", "Tensor")
def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
                                          weight, moment):
    success = True
    op_shape = P.Shape()
-    shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0]))
-    success = F.depend(success, pull(push((gradient[1], gradient[0]), shapes), weight))
+    values = gradient.values()
+    indices = gradient.indices()
+    shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
+    success = F.depend(success, pull(push((values, indices), shapes), weight))
    return success
...
@@ -27,14 +27,14 @@ from .optimizer import Optimizer
_lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
-@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
-                         "Tensor", "Tensor", "Tensor")
+@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor",
+                         "IndexedSlices", "Tensor", "Tensor", "Tensor")
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                         moment1, moment2):
    """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                           eps, gradient[1], gradient[0]))
+                                           eps, gradient.values(), gradient.indices()))
    return success
...
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
-from mindspore.common.tensor import Tensor
+from mindspore.common.tensor import Tensor, IndexedSlices
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
@@ -490,12 +490,14 @@ op_gather = P.GatherV2()
_apply_decay = C.MultitypeFuncGraph("apply_decay")
-@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
+@_apply_decay.register("Number", "Bool", "Tensor", "IndexedSlices")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
-        weight = op_gather(weight, gradient[0], 0)
-        return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2]
+        indices = gradient.indices()
+        values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values()))
+        shape = gradient.dense_shape()
+        return IndexedSlices(indices, values, shape)
    return gradient
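For intuition, the sparse weight-decay branch above only touches the rows that actually received a gradient. A rough NumPy sketch of the same arithmetic on toy data (illustrative only, not the MindSpore graph ops):

import numpy as np

def sparse_weight_decay(weight, indices, values, weight_decay):
    # Gather only the updated rows of the weight and add the decay term to the
    # corresponding gradient rows; indices and dense_shape stay unchanged.
    decayed_values = weight[indices] * weight_decay + values
    return indices, decayed_values, weight.shape

weight = np.ones((4, 3), dtype=np.float32)
indices = np.array([0, 2])
values = np.full((2, 3), 0.5, dtype=np.float32)
print(sparse_weight_decay(weight, indices, values, 0.01))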
@@ -518,9 +520,9 @@ def tensor_grad_scale(scale, grad):
    return grad * scale
-@_grad_scale.register("Number", "Tuple")
+@_grad_scale.register("Number", "IndexedSlices")
def tensor_grad_scale_with_sparse(scale, grad):
    """Get grad with scale."""
    if scale == 1.0:
        return grad
-    return grad[0], grad[1] * scale, grad[2]
+    return IndexedSlices(grad.indices(), grad.values() * scale, grad.dense_shape())
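The gradient-scaling and datatype-cast overloads follow the same pattern: operate on the values only, then rebuild an IndexedSlices from the original indices and dense_shape. A small NumPy sketch of that pattern (toy data, illustrative only):

import numpy as np

def scale_sparse_grad(indices, values, dense_shape, scale):
    # Only the values are scaled; indices and dense_shape pass through untouched.
    if scale == 1.0:
        return indices, values, dense_shape
    return indices, values * scale, dense_shape

print(scale_sparse_grad(np.array([1, 3]), np.ones((2, 4), np.float32), (8, 4), 0.5))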
@@ -23,11 +23,12 @@ from .optimizer import Optimizer
_proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
-@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tuple", "Tensor", "Tensor")
+@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
+                                 "Tensor")
def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
    """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
    success = True
-    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient[1], gradient[0]))
+    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
    return success
...
@@ -16,6 +16,7 @@
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
+from mindspore.common.tensor import IndexedSlices
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context
@@ -77,7 +78,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc
    return grad
-@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tuple", "Function")
+@reduce_opt.register("Number", "Bool", "Function", "Bool", "IndexedSlices", "Function")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce):
    """
    Apply allgather on gradient instead of allreduce for sparse feature.
@@ -88,21 +89,21 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, gr
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce_filter (bool): When it is true, allgather would apply.
-        grad (tuple): The indices, gradient tensor and tensor_shape before operation.
+        grad (IndexedSlices): The gradient before operation.
        allreduce (Primitive): The communication operator for gradients.
    Returns:
-        Tuple, include indices, the gradient tensor and tensor_shape after operation.
+        IndexedSlices, the gradient after operation.
    """
    if allreduce_filter:
-        indices = allgather(grad[0])
-        dout = allgather(grad[1])
+        indices = allgather(grad.indices())
+        dout = allgather(grad.values())
        if mean:
-            degree = F.scalar_cast(degree, F.dtype(grad[1]))
+            degree = F.scalar_cast(degree, F.dtype(grad.values()))
            cast_op = P.Cast()
            mul_op = P.Mul()
            dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
-        grad = (indices, dout, grad[2])
+        grad = IndexedSlices(indices, dout, grad.dense_shape())
    return grad
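As a rough picture of what the sparse branch computes, here is a NumPy sketch with the allgather simulated by concatenating per-device pieces (toy data, no real communication; duplicate indices are deliberately left for the sparse optimizer to handle):

import numpy as np

def sparse_allreduce_by_allgather(per_device_indices, per_device_values, dense_shape, mean=True):
    # Simulated allgather: every device ends up with the concatenation of all
    # devices' indices and values instead of a reduced dense tensor.
    indices = np.concatenate(per_device_indices)
    values = np.concatenate(per_device_values)
    if mean:
        values = values / len(per_device_values)  # degree = number of devices
    return indices, values, dense_shape

idx = [np.array([0, 2]), np.array([1, 2])]
val = [np.ones((2, 3), np.float32), 2 * np.ones((2, 3), np.float32)]
print(sparse_allreduce_by_allgather(idx, val, (4, 3)))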
@@ -123,18 +124,18 @@ def _tensors_get_datatype(grad):
    return F.dtype(grad)
-@_get_datatype.register("Tuple")
+@_get_datatype.register("IndexedSlices")
def _tensors_get_datatype_with_sparse(grad):
    """
    Acquire gradient datatype.
    Args:
-        grad (Tuple): The gradient tensor before operation.
+        grad (IndexedSlices): The gradient before operation.
    Returns:
        mstype, the datatype of gradient.
    """
-    return F.dtype(grad[1])
+    return F.dtype(grad.values())
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
@@ -155,20 +156,20 @@ def _tensors_cast_datatype(datatype, grad):
    return F.cast(grad, datatype)
-@_cast_datatype.register("TypeType", "Tuple")
+@_cast_datatype.register("TypeType", "IndexedSlices")
def _tensors_cast_datatype_with_sparse(datatype, grad):
    """
    Cast gradient to datatype.
    Args:
        datatype (mstype): the destination datatype of gradient.
-        grad (Tuple): The gradient tensor before operation.
+        grad (IndexedSlices): The gradient before operation.
    Returns:
-        Tuple, the gradient tuple after operation.
+        IndexedSlices, the gradient after operation.
    """
-    dout = F.cast(grad[1], datatype)
-    return (grad[0], dout, grad[2])
+    dout = F.cast(grad.values(), datatype)
+    return IndexedSlices(grad.indices(), dout, grad.dense_shape())
class DistributedGradReducer(Cell):
...
@@ -25,6 +25,7 @@ from .grad_base import bprop_getters
from ..primitive import constexpr
from ... import context
from ...common import dtype as mstype
+from ...common.tensor import IndexedSlices
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
@@ -206,7 +207,7 @@ def get_bprop_embedding_lookup(self):
        actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
        # Reshape the 'actual_dout' on device
        actual_dout = reshape_op(dout, actual_dout_shape_changed)
-        return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
+        return IndexedSlices(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
    return bprop_sparse
@@ -335,7 +336,7 @@ def get_bprop_sparse_gather_v2(self):
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
-            return (indices, values, x_shp), zeros_like(indices), zeros_like(axis)
+            return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
...
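The bprops above return an IndexedSlices instead of materializing a dense gradient. A NumPy sketch of the dense tensor such a sparse gradient stands for (illustrative, toy shapes):

import numpy as np

def densify(indices, values, dense_shape):
    # The dense gradient of a gather is a scatter-add of dout rows into the
    # shape of the gathered parameter; IndexedSlices keeps it factored instead.
    dense = np.zeros(dense_shape, dtype=values.dtype)
    np.add.at(dense, indices, values)
    return dense

x_shp = (5, 3)
indices = np.array([1, 1, 4])             # rows touched by the forward gather
dout = np.ones((3, 3), dtype=np.float32)  # upstream gradient per gathered row
print(densify(indices, dout, x_shp))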
@@ -17,6 +17,7 @@
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from .. import operations as P
+from ...common.tensor import IndexedSlices
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                   _GetTensorSlice, _MirrorOperator, ReduceOp,
@@ -46,9 +47,9 @@ def get_bprop_all_reduce(self):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
            else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
-                dx = (indices, grad, dout[2])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
+                dx = IndexedSlices(indices, grad, dout.dense_shape())
            return (dx,)
    else:
@@ -59,12 +60,12 @@ def get_bprop_all_reduce(self):
                z = cast(z, dtype(dx))
                dx = mul(dx, z)
            else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
                z = equal(x, out)
                z = cast(z, dtype(grad))
                grad = mul(grad, z)
-                dx = (indices, grad, dout[2])
+                dx = IndexedSlices(indices, grad, dout.dense_shape())
            return (dx,)
    return bprop
@@ -194,19 +195,19 @@ def get_bprop_mirror_operator(self):
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
-                dx = (indices, grad, dout[2])
+                dx = (indices, grad, dout.dense_shape())
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
            else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
-                dx = (indices, grad, dout[2])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
+                dx = (indices, grad, dout.dense_shape())
        return (dx,)
    return bprop
...
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test adam """
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Optimizer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
context.set_context(enable_sparse=True)
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))
update = next_m / (op_sqrt(next_v) + eps)
if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_v = F.depend(next_v, F.assign(param, next_param))
next_v = F.depend(next_v, F.assign(m, next_m))
next_v = F.depend(next_v, F.assign(v, next_v))
return next_v
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tuple", "Bool")
def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
return gradient[2][2]
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs."""
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
class AdamWeightDecaySparse(Optimizer):
"""
Implements the Adam algorithm with weight decay fix.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be class mindspore.Parameter.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`,
and might be in sparse format.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.AdamWeightDecay(params=net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
self.params = self.parameters
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
self.decay_flag = tuple(decay_filter(x) for x in self.params)
self.map = C.Map()
def construct(self, gradients):
lr = self.get_lr()
updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
return updated_velocity
def test_AdamWeightDecaySparse():
""" test_AdamWeightDecaySparse """
context.set_context(mode=context.GRAPH_MODE)
class Loss(nn.Cell):
def __init__(self):
super(Loss, self).__init__()
def construct(self, base, target):
return base
class NetWithSparseGatherV2(nn.Cell):
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
self.gatherv2 = P.SparseGatherV2()
self.axis = 0
def construct(self, indices):
return self.gatherv2(self.w1, indices, self.axis) * self.w2
inputs = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
net = NetWithSparseGatherV2()
net.set_train()
loss = Loss()
optimizer = AdamWeightDecaySparse(net.trainable_params())
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
@@ -19,8 +19,8 @@ import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
-from mindspore.common.tensor import Tensor
-from mindspore.ops import composite as C
+from mindspore.common.tensor import Tensor, IndexedSlices
+from mindspore.ops import composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore._checkparam import Validator as validator
@@ -65,7 +65,7 @@ def get_bprop_gather_v2(self):
    """Generate bprop for GatherV2"""
    def bprop(x, indices, axis, out, dout):
-        return (indices, dout, x), axis, out
+        return IndexedSlices(indices, dout, x), axis, out
    return bprop
@@ -78,7 +78,7 @@ def test_bprop_with_sparse_feature_allreduce():
            if shape is None:
                shape = [8, 8]
            self.all_reduce = AllReduce()
-            self.gatherv2 = VirtualGatherV2()
+            self.gatherv2 = P.GatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis
@@ -102,7 +102,7 @@ def test_bprop_with_sparse_feature_mirror():
            if shape is None:
                shape = [8, 8]
            self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
-            self.gatherv2 = VirtualGatherV2()
+            self.gatherv2 = P.GatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis
...