Commit e1eb5273 authored by mindspore-ci-bot, committed by Gitee

!1682 Add TBE ops SparseApplyProximalAdagradD, SparseApplyFtrlD and ApplyProximalAdagrad for VM.

Merge pull request !1682 from liuxiao/SparseApplyProximalAdagradD-SparseApplyFtrlD-ApplyProximalAdagradD
......@@ -65,6 +65,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"dropout_do_mask", "drop_out_do_mask"},
{"strided_slice", "strided_slice_d"},
{"strided_slice_grad", "strided_slice_grad_d"},
{"sparse_apply_ftrl", "sparse_apply_ftrl_d"},
{"transpose", "transpose_d"},
{"fill", "fill_d"},
{"unsorted_segment_sum", "unsorted_segment_sum_d"},
......
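For orientation: the C++ map above translates framework op names to TBE kernel names, and this PR adds the `sparse_apply_ftrl` entry. A minimal Python sketch of the same lookup, for exposition only (the real table is the C++ `tbe_func_adapter_map`; names here are illustrative):

tbe_func_adapter = {
    "strided_slice": "strided_slice_d",
    "sparse_apply_ftrl": "sparse_apply_ftrl_d",  # entry added by this PR
    "transpose": "transpose_d",
}

def adapt_tbe_name(op_name):
    # Fall back to the framework name when no TBE-specific alias exists.
    return tbe_func_adapter.get(op_name, op_name)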
......@@ -112,6 +112,9 @@ from .softplus_grad import _softplus_grad_tbe
from .softmax_grad_ext import _softmax_grad_ext_tbe
from .square import _square_tbe
from .sqrt import _sqrt_tbe
from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d
from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad
from .apply_proximal_adagrad import _apply_proximal_adagrad
from .transpose_d import _transpose_d_tbe
from .unsorted_segment_sum import _unsorted_segment_sum_tbe
from .logsoftmax_grad import _logsoftmax_grad_tbe
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ApplyProximalAdagrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_proximal_adagrad.so") \
.compute_cost(10) \
.kernel_name("apply_proximal_adagrad") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \
.input(1, "accum", False, "required", "all") \
.input(2, "lr", False, "required", "all") \
.input(3, "l1", False, "required", "all") \
.input(4, "l2", False, "required", "all") \
.input(5, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()
@op_info_register(apply_proximal_adagrad_op_info)
def _apply_proximal_adagrad():
"""ApplyProximalAdagrad TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseApplyFtrl op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("sparse_apply_ftrl.so") \
.compute_cost(10) \
.kernel_name("sparse_apply_ftrl") \
.partial_flag(True) \
.attr("lr", "required", "float", "all") \
.attr("l1", "required", "float", "all") \
.attr("l2", "required", "float", "all") \
.attr("lr_power", "required", "float", "all") \
.attr("use_locking", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \
.input(1, "accum", False, "required", "all") \
.input(2, "linear", False, "required", "all") \
.input(3, "grad", False, "required", "all") \
.input(4, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.output(2, "linear", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(sparse_apply_ftrl_d_op_info)
def _sparse_apply_ftrl_d():
"""SparseApplyFtrl TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseApplyProximalAdagrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("sparse_apply_proximal_adagrad.so") \
.compute_cost(10) \
.kernel_name("sparse_apply_proximal_adagrad") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \
.input(1, "accum", False, "required", "all") \
.input(2, "lr", False, "required", "all") \
.input(3, "l1", False, "required", "all") \
.input(4, "l2", False, "required", "all") \
.input(5, "grad", False, "required", "all") \
.input(6, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ) \
.get_op_info()
@op_info_register(sparse_apply_proximal_adagrad_op_info)
def _sparse_apply_proximal_adagrad():
"""SparseApplyProximalAdagrad TBE register"""
return
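All three registration files above follow the same pattern: build a TBERegOp description (kernel name, attrs, inputs/outputs, supported dtype/format combinations) and bind it to a stub function with `@op_info_register`. A minimal sketch of how such a metadata-registering decorator can work — an assumption for illustration, not MindSpore's actual implementation:

_OP_INFO_REGISTRY = {}  # hypothetical registry keyed by stub name

def op_info_register_sketch(op_info):
    def decorator(stub):
        # Store the op description so the kernel backend can look it up later;
        # the stub itself carries no logic.
        _OP_INFO_REGISTRY[stub.__name__] = op_info
        return stub
    return decorator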
......@@ -68,7 +68,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
SmoothL1Loss, Softmax, Softplus,
SoftmaxCrossEntropyWithLogits, ROIAlign,
SparseSoftmaxCrossEntropyWithLogits, Tanh,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
ApplyProximalAdagrad, SparseApplyProximalAdagrad,
ApplyRMSProp, ApplyCenteredRMSProp)
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
from . import _quant_ops
......@@ -265,6 +266,9 @@ __all__ = [
"Round",
"ApplyFtrl",
"SpaceToBatch",
"SparseApplyFtrl",
"ApplyProximalAdagrad",
"SparseApplyProximalAdagrad",
"BatchToSpace",
"Atan2",
"ApplyRMSProp",
......
......@@ -2807,6 +2807,126 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
return var_type
class ApplyProximalAdagrad(PrimitiveWithInfer):
r"""
Update relevant entries according to the proximal adagrad algorithm.
.. math::
accum += grad * grad
.. math::
prox_v = var - lr * grad * \frac{1}{\sqrt{accum}}
.. math::
var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0)
Args:
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
Inputs:
- **var** (Tensor) - Variable to be updated.
- **accum** (Tensor) - Accumulator to be updated. The shape must be the same as `var`'s shape.
- **lr** (Union[Number, Tensor]) - The learning rate value, must be positive. It should be
a scalar tensor or number.
- **l1** (Union[Number, Tensor]) - l1 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
- **l2** (Union[Number, Tensor]) - l2 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape.
Outputs:
Tensor, has the same shape and type as `var`.
Examples:
>>> var = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> accum = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> grad = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> lr = 0.01
>>> l1 = 0.0
>>> l2 = 0.0
>>> apply_proximal_ada_grad = P.ApplyProximalAdagrad()
>>> output = apply_proximal_ada_grad(var, accum, lr, l1, l2, grad)
"""
@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output'])
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name)
return var_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype}
validator.check_scalar_or_tensor_type_same(scalar_args, valid_types, self.name)
return var_dtype
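The three formulas in the docstring translate directly into a dense update. A minimal NumPy sketch of that update, for illustration only (not the TBE kernel):

import numpy as np

def apply_proximal_adagrad_ref(var, accum, lr, l1, l2, grad):
    # accum += grad * grad
    accum = accum + grad * grad
    # prox_v = var - lr * grad / sqrt(accum)
    prox_v = var - lr * grad / np.sqrt(accum)
    # var = sign(prox_v) / (1 + lr * l2) * max(|prox_v| - lr * l1, 0)
    var = np.sign(prox_v) / (1 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0)
    return var, accum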
class SparseApplyProximalAdagrad(PrimitiveWithInfer):
r"""
Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad,
it takes an additional `indices` tensor that selects the rows of `var` and `accum` to update.
.. math::
accum += grad * grad
.. math::
prox_v = var - lr * grad * \frac{1}{\sqrt{accum}}
.. math::
var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0)
Args:
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
Inputs:
- **var** (Tensor) - Variable tensor to be updated.
- **accum** (Tensor) - Accumulator to be updated. The shape must be the same as `var`'s shape.
- **lr** (Union[Number, Tensor]) - The learning rate value, must be positive. It should be
a scalar tensor or number.
- **l1** (Union[Number, Tensor]) - l1 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
- **l2** (Union[Number, Tensor]) - l2 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
Outputs:
Tensor, has the same shape and type as `var`.
Examples:
>>> var = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> accum = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> grad = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> indices = Tensor(np.ones((3,), np.int32))
>>> lr = 0.01
>>> l1 = 0.0
>>> l2 = 0.0
>>> sparse_apply_proximal_ada_grad = P.SparseApplyProximalAdagrad()
>>> output = sparse_apply_proximal_ada_grad(var, accum, lr, l1, l2, grad, indices)
"""
@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
outputs=['output'])
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
return var_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype}
validator.check_scalar_or_tensor_type_same(scalar_args, [mstype.float32], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
return var_dtype
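The sparse variant applies the same formulas, but only to the rows of `var` and `accum` selected by `indices`. A NumPy sketch under that reading, for illustration only:

import numpy as np

def sparse_apply_proximal_adagrad_ref(var, accum, lr, l1, l2, grad, indices):
    # grad's first dimension lines up with indices; each row updates one slot.
    for i, idx in enumerate(indices):
        accum[idx] = accum[idx] + grad[i] * grad[i]
        prox_v = var[idx] - lr * grad[i] / np.sqrt(accum[idx])
        var[idx] = np.sign(prox_v) / (1 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0)
    return var, accum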
class LARSUpdate(PrimitiveWithInfer):
"""
Conduct lars (layer-wise adaptive rate scaling) update on the square sum of gradient.
......@@ -2963,6 +3083,85 @@ class ApplyFtrl(PrimitiveWithInfer):
return var_type
class SparseApplyFtrl(PrimitiveWithInfer):
"""
Update relevant entries according to the FTRL-proximal scheme.
Args:
lr (float): The learning rate value, must be positive.
l1 (float): l1 regularization strength, must be greater than or equal to zero.
l2 (float): l2 regularization strength, must be greater than or equal to zero.
lr_power (float): Learning rate power controls how the learning rate decreases during training,
must be less than or equal to zero. A fixed learning rate is used if `lr_power` is zero.
use_locking (bool): If True, use locks for the update operation. Default: False.
Inputs:
- **var** (Tensor): The variable to be updated.
- **accum** (Tensor): The accumulator to be updated, must be the same type and shape as `var`.
- **linear** (Tensor): The linear term to be updated, must be the same type and shape as `var`.
- **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
- **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`.
The shape of `indices` must match `grad` in the first dimension. The type must be int32.
Outputs:
- **var** (Tensor): Tensor, has the same shape and type as `var`.
- **accum** (Tensor): Tensor, has the same shape and type as `accum`.
- **linear** (Tensor): Tensor, has the same shape and type as `linear`.
Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class SparseApplyFtrlNet(nn.Cell):
>>> def __init__(self):
>>> super(SparseApplyFtrlNet, self).__init__()
>>> self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
>>> self.var = Parameter(Tensor(np.random.random((3, 3)).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.random((3, 3)).astype(np.float32)), name="accum")
>>> self.linear = Parameter(Tensor(np.random.random((3, 3)).astype(np.float32)), name="linear")
>>>
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
>>> return out
>>>
>>> net = SparseApplyFtrlNet()
>>> grad = Tensor(np.random.random((3, 3)).astype(np.float32))
>>> indices = Tensor(np.ones([3]), mindspore.int32)
>>> output = net(grad, indices)
"""
@prim_attr_register
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
validator.check_value_type("lr", lr, [float], self.name)
validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number("lr", lr, 0.0, Rel.GT, self.name)
self.l1 = validator.check_number("l1", l1, 0.0, Rel.GE, self.name)
self.l2 = validator.check_number("l2", l2, 0.0, Rel.GE, self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape, linear_shape
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
return var_dtype, accum_dtype, linear_dtype
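The docstring names the hyper-parameters but not the update itself. A sketch of the standard FTRL-proximal rule applied per selected row — assumed to match this op, shown for illustration only:

import numpy as np

def sparse_apply_ftrl_ref(var, accum, linear, grad, indices, lr, l1, l2, lr_power):
    for i, idx in enumerate(indices):
        g = grad[i]
        accum_new = accum[idx] + g * g
        # sigma scales var's contribution to linear as accum grows.
        sigma = (accum_new ** -lr_power - accum[idx] ** -lr_power) / lr
        linear[idx] = linear[idx] + g - sigma * var[idx]
        quad = accum_new ** -lr_power / lr + 2.0 * l2
        # Shrink toward zero; weights inside the l1 ball snap to exactly zero.
        var[idx] = np.where(np.abs(linear[idx]) > l1,
                            (np.sign(linear[idx]) * l1 - linear[idx]) / quad, 0.0)
        accum[idx] = accum_new
    return var, accum, linear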
class ConfusionMulGrad(PrimitiveWithInfer):
"""
`output0` is the result of the element-wise multiplication of `input0` and `input1`.
......@@ -3124,7 +3323,7 @@ class CTCLoss(PrimitiveWithInfer):
>>> labels_indices = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int64)
>>> labels_values = Tensor(np.array([2, 2]), mindspore.int32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> ctc_loss = P.CTCloss()
>>> ctc_loss = P.CTCLoss()
>>> output = ctc_loss(inputs, labels_indices, labels_values, sequence_length)
"""
......
......@@ -225,6 +225,46 @@ class ApplyFtrlNet(nn.Cell):
out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power)
return out
class SparseApplyFtrlNet(nn.Cell):
def __init__(self):
super(SparseApplyFtrlNet, self).__init__()
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5)
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
def construct(self, grad, indices):
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
return out
class SparseApplyProximalAdagradNet(nn.Cell):
def __init__(self):
super(SparseApplyProximalAdagradNet, self).__init__()
self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
self.lr = 0.01
self.l1 = 0.0
self.l2 = 0.0
def construct(self, var, accum, grad, indices):
out = self.sparse_apply_proximal_adagrad(var, accum, self.lr, self.l1, self.l2, grad, indices)
return out
class ApplyProximalAdagradNet(nn.Cell):
def __init__(self):
super(ApplyProximalAdagradNet, self).__init__()
self.apply_proximal_adagrad = P.ApplyProximalAdagrad()
self.lr = 0.01
self.l1 = 0.0
self.l2 = 0.0
def construct(self, var, accum, grad):
out = self.apply_proximal_adagrad(var, accum, self.lr, self.l1, self.l2, grad)
return out
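For context, the nets above can also be exercised by hand; a minimal sketch (random inputs are illustrative, and the test harness below drives these nets via `desc_inputs`):

import numpy as np
from mindspore import Tensor

# May require context.set_context(mode=context.PYNATIVE_MODE) to run eagerly.
net = ApplyProximalAdagradNet()
var = Tensor(np.random.rand(3, 3).astype(np.float32))
accum = Tensor(np.random.rand(3, 3).astype(np.float32))
grad = Tensor(np.random.rand(3, 3).astype(np.float32))
output = net(var, accum, grad)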
class ApplyRMSNet(nn.Cell):
def __init__(self):
super(ApplyRMSNet, self).__init__()
......@@ -982,6 +1022,18 @@ test_case_nn_ops = [
'block': P.SparseApplyAdagrad(0.5),
'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
'skip': ['backward']}),
('SparseApplyFtrl', {
'block': SparseApplyFtrlNet(),
'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
'skip': ['backward']}),
('ApplyProximalAdagrad', {
'block': ApplyProximalAdagradNet(),
'desc_inputs': [[3, 3], [3, 3], [3, 3]],
'skip': ['backward']}),
('SparseApplyProximalAdagrad', {
'block': SparseApplyProximalAdagradNet(),
'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
'skip': ['backward']}),
('Flatten_1', {
'block': NetForFlatten(),
'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))],
......