提交 b052ecf4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2804 Add TBE operators SparseApplyFtrlV2\SparseApplyAdagradV2 for VM.

Merge pull request !2804 from liuxiao93/ops-for-VM
......@@ -70,11 +70,13 @@ static std::map<string, string> tbe_func_adapter_map = {
{"strided_slice", "strided_slice_d"},
{"strided_slice_grad", "strided_slice_grad_d"},
{"sparse_apply_ftrl", "sparse_apply_ftrl_d"},
{"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"},
{"apply_ada_max", "apply_ada_max_d"},
{"apply_adadelta", "apply_adadelta_d"},
{"apply_adagrad", "apply_adagrad_d"},
{"apply_adagrad_v2", "apply_adagradv2_d"},
{"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
{"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"},
{"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
{"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
{"apply_add_sign", "apply_add_sign_d"},
......
......@@ -38,6 +38,8 @@ from .apply_add_sign import _apply_add_sign_tbe
from .apply_power_sign import _apply_power_sign_tbe
from .apply_gradient_descent import _apply_gradient_descent_tbe
from .apply_proximal_gradient_descent import _apply_proximal_gradient_descent_tbe
from .sparse_apply_ftrl_v2 import _sparse_apply_ftrl_v2_tbe
from .sparse_apply_adagrad_v2 import _sparse_apply_adagrad_v2_tbe
from .approximate_equal import _approximate_equal_tbe
from .adam_apply_one import _adam_apply_one_tbe
from .assign import _assign_tbe
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseApplyAdagradV2D op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
sparse_apply_adagrad_v2_d_op_info = TBERegOp("SparseApplyAdagradV2") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("sparse_apply_adagrad_v2_d.so") \
.compute_cost(10) \
.kernel_name("sparse_apply_adagrad_v2_d") \
.partial_flag(True) \
.attr("lr", "required", "float", "all") \
.attr("epsilon", "required", "float", "all") \
.attr("use_locking", "optional", "bool", "all") \
.attr("update_slots", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "accum", False, "required", "all") \
.input(2, "grad", False, "required", "all") \
.input(3, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(sparse_apply_adagrad_v2_d_op_info)
def _sparse_apply_adagrad_v2_tbe():
"""SparseApplyAdagradV2D TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseApplyFtrlV2D op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
sparse_apply_ftrl_v2_d_op_info = TBERegOp("SparseApplyFtrlV2") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("sparse_apply_ftrl_v2_d.so") \
.compute_cost(10) \
.kernel_name("sparse_apply_ftrl_v2_d") \
.partial_flag(True) \
.attr("lr", "required", "float", "all") \
.attr("l1", "required", "float", "all") \
.attr("l2", "required", "float", "all") \
.attr("l2_shrinkage", "required", "float", "all") \
.attr("lr_power", "required", "float", "all") \
.attr("use_locking", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \
.input(1, "accum", False, "required", "all") \
.input(2, "linear", False, "required", "all") \
.input(3, "grad", False, "required", "all") \
.input(4, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.output(2, "linear", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(sparse_apply_ftrl_v2_d_op_info)
def _sparse_apply_ftrl_v2_tbe():
"""SparseApplyFtrlV2D TBE register"""
return
......@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
SoftmaxCrossEntropyWithLogits, ROIAlign,
SparseSoftmaxCrossEntropyWithLogits, Tanh,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
ApplyProximalAdagrad, SparseApplyProximalAdagrad,
ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2,
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
......@@ -284,6 +284,7 @@ __all__ = [
"Abs",
"BinaryCrossEntropy",
"SparseApplyAdagrad",
"SparseApplyAdagradV2",
"SpaceToDepth",
"DepthToSpace",
"Conv2DBackpropInput",
......@@ -294,6 +295,7 @@ __all__ = [
"ApplyFtrl",
"SpaceToBatch",
"SparseApplyFtrl",
"SparseApplyFtrlV2",
"ApplyProximalAdagrad",
"SparseApplyProximalAdagrad",
"ApplyAdaMax",
......
......@@ -3600,6 +3600,88 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
return var_type, accum_type
class SparseApplyAdagradV2(PrimitiveWithInfer):
r"""
Update relevant entries according to the adagrad scheme.
.. math::
accum += grad * grad
.. math::
var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}
Args:
lr (float): Learning rate.
epsilon (float): A small value added for numerical stability.
use_locking (bool): If `True`, updating of the var and accum tensors will be protected. Default: False.
update_slots (bool): If `True`, the computation logic will be different to `False`. Default: True.
Inputs:
- **var** (Parameter) - Variable to be updated. The type must be float32.
- **accum** (Parameter) - Accum to be updated. The shape must be the same as `var`'s shape,
the type must be float32.
- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape except first dimension,
the type must be float32.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
The shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
Outputs:
Tuple of 2 Tensor, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(lr=1e-8, epsilon=1e-6)
>>> self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
>>>
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
>>> indices = Tensor([0, 1, 2], mstype.int32)
>>> result = net(grad, indices)
"""
__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
)
@prim_attr_register
def __init__(self, lr, epsilon, use_locking=False, update_slots=True):
self.lr = validator.check_value_type("lr", lr, [float], self.name)
self.epsilon = validator.check_value_type("epsilon", epsilon, [float], self.name)
self.use_locking = validator.check_value_type("update_slots", update_slots, [bool], self.name)
self.update_slots = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape
def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
return var_type, accum_type
class ApplyProximalAdagrad(PrimitiveWithInfer):
r"""
Update relevant entries according to the proximal adagrad algorithm.
......@@ -3664,7 +3746,8 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output'])
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'],
outputs=['var', 'accum'])
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape):
......@@ -4371,6 +4454,98 @@ class SparseApplyFtrl(PrimitiveWithInfer):
return var_dtype, accum_dtype, linear_dtype
class SparseApplyFtrlV2(PrimitiveWithInfer):
"""
Update relevant entries according to the FTRL-proximal scheme.
Args:
lr (float): The learning rate value, must be positive.
l1 (float): l1 regularization strength, must be greater than or equal to zero.
l2 (float): l2 regularization strength, must be greater than or equal to zero.
l2_shrinkage (float): L2 shrinkage regularization.
lr_power (float): Learning rate power controls how the learning rate decreases during training,
must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
use_locking (bool): If `True`, updating of the var and accum tensors will be protected. Default: False.
Inputs:
- **var** (Parameter) - The variable to be updated. The data type must be float32.
- **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`.
- **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
The shape of `indices` must be the same as `grad` in first dimension. The type must be int32.
Outputs:
Tuple of 3 Tensor, the updated parameters.
- **var** (Tensor): Tensor, has the same shape and type as `var`.
- **accum** (Tensor): Tensor, has the same shape and type as `accum`.
- **linear** (Tensor): Tensor, has the same shape and type as `linear`.
Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class SparseApplyFtrlV2Net(nn.Cell):
>>> def __init__(self):
>>> super(SparseApplyFtrlV2Net, self).__init__()
>>> self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0,
l2_shrinkage=0.0, lr_power=-0.5)
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>> self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
>>>
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices)
>>> return out
>>>
>>> net = SparseApplyFtrlV2Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> indices = Tensor(np.ones([3]), mindspore.int32)
>>> output = net(grad, indices)
"""
__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
)
@prim_attr_register
def __init__(self, lr, l1, l2, l2_shrinkage, lr_power, use_locking=False):
validator.check_value_type("lr", lr, [float], self.name)
validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape, linear_shape
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
return var_dtype, accum_dtype, linear_dtype
class ConfusionMulGrad(PrimitiveWithInfer):
"""
`output0` is the result of which input0 dot multily input1.
......
......@@ -306,6 +306,19 @@ class SparseApplyFtrlNet(nn.Cell):
return out
class SparseApplyFtrlV2Net(nn.Cell):
def __init__(self):
super(SparseApplyFtrlV2Net, self).__init__()
self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2(lr=0.001, l1=0.0, l2=0.0, l2_shrinkage=0.0, lr_power=-0.5)
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
def construct(self, grad, indices):
out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices)
return out
class SparseApplyProximalAdagradNet(nn.Cell):
def __init__(self):
super(SparseApplyProximalAdagradNet, self).__init__()
......@@ -467,6 +480,18 @@ class SparseApplyAdagradNet(nn.Cell):
return out
class SparseApplyAdagradV2Net(nn.Cell):
def __init__(self):
super(SparseApplyAdagradV2Net, self).__init__()
self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(lr=0.01, epsilon=0.001)
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
def construct(self, grad, indices):
out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices)
return out
class ApplyRMSNet(nn.Cell):
def __init__(self):
super(ApplyRMSNet, self).__init__()
......@@ -1376,10 +1401,18 @@ test_case_nn_ops = [
'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
'desc_bprop': [[3, 3], [3, 3]],
'skip': ['backward']}),
('SparseApplyAdagradV2', {
'block': SparseApplyAdagradV2Net(),
'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
'skip': ['backward']}),
('SparseApplyFtrl', {
'block': SparseApplyFtrlNet(),
'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
'skip': ['backward']}),
('SparseApplyFtrlV2', {
'block': SparseApplyFtrlV2Net(),
'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
'skip': ['backward']}),
('ApplyProximalAdagrad', {
'block': ApplyProximalAdagradNet(),
'desc_inputs': [[3, 3]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册