Commit bbc64b88 authored by mindspore-ci-bot, committed by Gitee

!1528 support vm for BesselI0e and BesselI1e

Merge pull request !1528 from jiangjinsheng/vm_bessel
@@ -24,6 +24,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
from .grad_base import bprop_getters
from ..primitive import constexpr
from ..composite.multitype_ops import _constexpr_utils as const_utils
shape_op = P.Shape()
reduce_sum = P.ReduceSum()
@@ -897,3 +898,39 @@ def get_bprop_atan2(self):
return binop_grad_common(x, y, bc_dx, bc_dy)
return bprop

@bprop_getters.register(P.BesselI0e)
def get_bprop_bessel_i0e(self):
    """Generate bprop for BesselI0e"""
    sign = P.Sign()
    bessel_i1e = P.BesselI1e()

    def bprop(x, out, dout):
        # out == BesselI0e(x), and d/dx BesselI0e(x) = BesselI1e(x) - sign(x) * BesselI0e(x).
        dx = dout * (bessel_i1e(x) - sign(x) * out)
        return (dx,)
    return bprop
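
The identity this bprop relies on is d/dx BesselI0e(x) = BesselI1e(x) - sign(x) * BesselI0e(x). It can be spot-checked against SciPy's i0e/i1e with a finite difference (a standalone sketch assuming SciPy is installed; not part of this patch):

import numpy as np
from scipy.special import i0e, i1e

x = np.array([-3.0, -1.0, 0.5, 2.0])            # avoid x == 0, where sign(x) == 0
analytic = i1e(x) - np.sign(x) * i0e(x)         # the bprop's closed form
h = 1e-6
numeric = (i0e(x + h) - i0e(x - h)) / (2.0 * h) # central difference
assert np.allclose(analytic, numeric, atol=1e-5)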

@bprop_getters.register(P.BesselI1e)
def get_bprop_bessel_i1e(self):
    """Generate bprop for BesselI1e"""
    sign = P.Sign()
    bessel_i0e = P.BesselI0e()
    less = P.Less()
    select = P.Select()
    reciprocal = P.Reciprocal()
    cast = P.Cast()
    dtype = P.DType()
    abs_op = P.Abs()

    def bprop(x, out, dout):
        zeros = zeros_like(x)
        np_eps = const_utils.get_np_eps(dtype(x))
        eps = cast(np_eps, dtype(x))
        # Swap near-zero entries for eps so reciprocal(x_safe) stays finite.
        x_is_valid = less(eps, abs_op(x))
        x_safe = select(x_is_valid, x, eps + zeros)
        # d/dx BesselI1e(x) = BesselI0e(x) - BesselI1e(x) * (sign(x) + 1/x); the
        # removable singularity at x = 0 is patched with its limit 0.5, and the
        # result is chained with the incoming gradient dout.
        tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe))
        dx = dout * select(x_is_valid, tmp, 0.5 + zeros)
        return (dx,)
    return bprop
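
Here the closed form is d/dx BesselI1e(x) = BesselI0e(x) - BesselI1e(x) * (sign(x) + 1/x), with limit 0.5 as x approaches 0, which is what the select falls back to for tiny inputs. The same kind of SciPy spot-check (again a sketch assuming SciPy is available):

import numpy as np
from scipy.special import i0e, i1e

x = np.array([-1.5, 0.5, 1.0, 2.0])
analytic = i0e(x) - i1e(x) * (np.sign(x) + 1.0 / x)
h = 1e-6
numeric = (i1e(x + h) - i1e(x - h)) / (2.0 * h)
assert np.allclose(analytic, numeric, atol=1e-5)

# The fallback value near zero: the derivative tends to 0.5.
print((i1e(h) - i1e(-h)) / (2.0 * h))   # ~0.5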
@@ -207,6 +207,8 @@ from .reduce_prod import _reduce_prod_tbe
from .flatten_grad import _flatten_grad_tbe
from .scatter_add import _scatter_add_tbe
from .atan2 import _atan2_tbe
from .bessel_i0e import _bessel_i0e_tbe
from .bessel_i1e import _bessel_i1e_tbe
from .batch_to_space_nd import _batch_to_space_nd_tbe
from .space_to_batch_nd import _space_to_batch_nd_tbe
from .bitwise_and import bitwise_and_op_info
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselI0e op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
bessel_i0e_op_info = TBERegOp("BesselI0e") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("bessel_i0e.so") \
.compute_cost(10) \
.kernel_name("bessel_i0e") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register(bessel_i0e_op_info)
def _bessel_i0e_tbe():
    """BesselI0e TBE register"""
    return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselI1e op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
bessel_i1e_op_info = TBERegOp("BesselI1e") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("bessel_i1e.so") \
.compute_cost(10) \
.kernel_name("bessel_i1e") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register(bessel_i1e_op_info)
def _bessel_i1e_tbe():
    """BesselI1e TBE register"""
    return
@@ -35,15 +35,10 @@ strided_slice_grad_d_op_info = TBERegOp("StridedSliceGrad") \
.input(0, "dy", False, "required", "all") \
.output(0, "output", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
-    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
-    .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
@@ -631,3 +631,10 @@ def scalar_in_sequence(x, y):
if x in y:
return True
return False

@constexpr
def get_np_eps(input_dtype):
    """Return the machine epsilon of `input_dtype` as a Python float."""
    nptype = mstype.dtype_to_nptype(input_dtype)
    eps = np.finfo(nptype).eps
    return float(eps)
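
For reference, the values this helper produces for the float types registered above (a quick NumPy check; not part of the patch):

import numpy as np
print(np.finfo(np.float16).eps)   # 0.000977 (2**-10)
print(np.finfo(np.float32).eps)   # 1.1920929e-07 (2**-23)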
@@ -49,7 +49,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2,
NPUAllocFloatStatus, NPUClearFloatStatus,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum,
-                     Sin, Sqrt, Rsqrt,
+                     Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll)
from .random_ops import (RandomChoiceWithMask)
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
@@ -274,7 +274,9 @@ __all__ = [
"SquareSumAll",
"BitwiseAnd",
"BitwiseOr",
"BitwiseXor"
"BitwiseXor",
"BesselI0e",
"BesselI1e",
]
__all__.extend(_quant_ops.__all__)
......
@@ -636,7 +636,7 @@ class CumSum(PrimitiveWithInfer):
Inputs:
- **input** (Tensor) - The input tensor to accumulate.
-        - **axis** (int) - The axis to accumulate the tensor's value.
+        - **axis** (int) - The axis to accumulate the tensor's value. Only constant value is allowed.
Outputs:
Tensor, the shape of the output tensor is consistent with the input tensor's.
@@ -2323,3 +2323,61 @@ class BitwiseXor(_BitwiseBinaryOp):
>>> bitwise_xor(input_x1, input_x2)
[0, 1, 0, 0, -2, 3, 2]
"""

class BesselI0e(PrimitiveWithInfer):
    """
    Computes BesselI0e of input element-wise.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, has the same shape as `input_x`.

    Examples:
        >>> bessel_i0e = P.BesselI0e()
        >>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
        >>> output = bessel_i0e(input_x)
        [0.7979961, 0.5144438, 0.75117415, 0.9157829]
    """

    @prim_attr_register
    def __init__(self):
        """init BesselI0e"""

    def infer_shape(self, x):
        return x

    def infer_dtype(self, x):
        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
        return x


class BesselI1e(PrimitiveWithInfer):
    """
    Computes BesselI1e of input element-wise.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, has the same shape as `input_x`.

    Examples:
        >>> bessel_i1e = P.BesselI1e()
        >>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
        >>> output = bessel_i1e(input_x)
        [0.09507662, 0.19699717, 0.11505538, 0.04116856]
    """

    @prim_attr_register
    def __init__(self):
        """init BesselI1e"""

    def infer_shape(self, x):
        return x

    def infer_dtype(self, x):
        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
        return x
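
A minimal end-to-end sketch of the two new primitives inside a Cell (hypothetical example code, assuming a configured MindSpore install; not part of this commit):

import numpy as np
import mindspore
from mindspore import Tensor, nn
from mindspore.ops import operations as P

class BesselNet(nn.Cell):
    """Applies both exponentially scaled Bessel ops to its input."""
    def __init__(self):
        super(BesselNet, self).__init__()
        self.bessel_i0e = P.BesselI0e()
        self.bessel_i1e = P.BesselI1e()

    def construct(self, x):
        return self.bessel_i0e(x) + self.bessel_i1e(x)

net = BesselNet()
x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
print(net(x))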
@@ -1705,8 +1705,8 @@ class ApplyRMSProp(PrimitiveWithInfer):
- **var** (Tensor) - Weights to be update.
- **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
- **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
-        - **grad** (Tensor) - Gradients, must have the same type as `var`.
        - **learning_rate** (Union[Number, Tensor]) - Learning rate.
+        - **grad** (Tensor) - Gradients, must have the same type as `var`.
- **decay** (float) - Decay rate.
- **momentum** (float) - Momentum.
- **epsilon** (float) - Ridge term.
......
@@ -672,6 +672,14 @@ test_case_math_ops = [
'desc_const': [1],
'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))],
'desc_bprop': []}),
('BesselI0e', {
'block': P.BesselI0e(),
'desc_inputs': [[2, 3]],
'desc_bprop': [[2, 3]]}),
('BesselI1e', {
'block': P.BesselI1e(),
'desc_inputs': [[2, 3]],
'desc_bprop': [[2, 3]]}),
]
test_case_nn_ops = [
......