Commit eeba0461 authored by mindspore-ci-bot, committed by Gitee

!3050 Add TBE ops Tan/TruncateDiv/TruncateMod for VM.

Merge pull request !3050 from liuxiao93/Tan/TruncateDiv/TruncateMOd
...@@ -306,6 +306,34 @@ def get_bprop_floormod(self):
    return bprop

@bprop_getters.register(P.TruncateDiv)
def get_bprop_truncate_div(self):
    """Grad definition for `TruncateDiv` operation."""
    div_op = P.TruncateDiv()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_op(dout, y)
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
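Aside: the gradient above treats the truncation as locally constant, so bc_x approximates d(x/y)/dx = 1/y and bc_y approximates d(x/y)/dy = -x/y^2 (with out = trunc(x/y) standing in for x/y). A minimal NumPy sketch of the same formulas, purely illustrative and not part of the patch:

import numpy as np

# NumPy stand-in mirroring the TruncateDiv bprop above (illustration only).
x = np.array([2.0, 4.0, -1.0])
y = np.array([3.0, 3.0, 3.0])
dout = np.ones_like(x)
out = np.trunc(x / y)      # forward output of TruncateDiv
bc_x = dout / y            # mirrors div_op(dout, y)
bc_y = -(bc_x * out)       # mirrors neg(mul_op(bc_x, out))
print(bc_x, bc_y)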

@bprop_getters.register(P.TruncateMod)
def get_bprop_truncate_mod(self):
    """Grad definition for `TruncateMod` operation."""
    div_op = P.TruncateDiv()

    def bprop(x, y, out, dout):
        bc_x = dout
        bc_y = -dout * div_op(x, y)
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
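The TruncateMod gradient follows from writing the remainder as r = x - trunc(x/y) * y and treating the truncated quotient as locally constant, giving dr/dx = 1 and dr/dy = -trunc(x/y). A NumPy sketch of those formulas (not part of the patch):

import numpy as np

# NumPy stand-in mirroring the TruncateMod bprop above (illustration only).
x = np.array([2.0, 4.0, -1.0])
y = np.array([3.0, 3.0, 3.0])
dout = np.ones_like(x)
bc_x = dout                        # dr/dx = 1
bc_y = -dout * np.trunc(x / y)     # dr/dy = -trunc(x / y)
print(bc_x, bc_y)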

@bprop_getters.register(P.Mod)
def get_bprop_mod(self):
    """Grad definition for `Mod` operation."""
...@@ -1027,6 +1055,22 @@ def get_bprop_atan(self):
    return bprop

@bprop_getters.register(P.Tan)
def get_bprop_tan(self):
    """Grad definition for `Tan` operation."""
    reciprocal = P.Reciprocal()
    square = P.Square()
    cos = P.Cos()

    def bprop(x, out, dout):
        cosx = cos(x)
        secx2 = square(reciprocal(cosx))
        dx = secx2 * dout
        return (dx,)

    return bprop
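The formula uses d tan(x)/dx = 1 / cos(x)^2, i.e. sec^2(x). A quick NumPy finite-difference check of that identity, only as a sanity sketch and not part of the patch:

import numpy as np

# Check d tan(x)/dx = 1 / cos(x)^2 against a central finite difference.
x = np.array([-1.0, 0.0, 1.0])
eps = 1e-6
analytic = 1.0 / np.cos(x) ** 2
numeric = (np.tan(x + eps) - np.tan(x - eps)) / (2 * eps)
assert np.allclose(analytic, numeric, rtol=1e-4)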

@bprop_getters.register(P.BesselI1e)
def get_bprop_bessel_i1e(self):
    """Generate bprop for BesselI1e"""
...
...@@ -132,6 +132,8 @@ from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d
from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad
from .apply_proximal_adagrad import _apply_proximal_adagrad
from .transpose_d import _transpose_d_tbe
from .truncate_div import _truncate_div_tbe
from .truncate_mod import _truncate_mod_tbe
from .unsorted_segment_sum import _unsorted_segment_sum_tbe
from .unsorted_segment_prod import _unsorted_segment_prod_tbe
from .logsoftmax_grad import _logsoftmax_grad_tbe
...@@ -222,6 +224,7 @@ from .binary_cross_entropy import _binary_cross_entropy_tbe
from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe
from .sin import _sin_tbe
from .cos import _cos_tbe
from .tan import _tan_tbe
from .cum_sum import _cum_sum_tbe
from .apply_rms_prop import _apply_rms_prop_tbe
from .cumprod import _cumprop_tbe
...
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tan op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
tan_op_info = TBERegOp("Tan") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("tan.so") \
    .compute_cost(10) \
    .kernel_name("tan") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None) \
    .get_op_info()


@op_info_register(tan_op_info)
def _tan_tbe():
    """Tan TBE register"""
    return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TruncateDiv op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
truncate_div_op_info = TBERegOp("TruncateDiv") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("truncate_div.so") \
    .compute_cost(10) \
    .kernel_name("truncate_div") \
    .partial_flag(True) \
    .op_pattern("broadcast") \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
    .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \
    .get_op_info()


@op_info_register(truncate_div_op_info)
def _truncate_div_tbe():
    """TruncateDiv TBE register"""
    return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TruncateMod op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
truncate_mod_op_info = TBERegOp("TruncateMod") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("truncate_mod.so") \
    .compute_cost(10) \
    .kernel_name("truncate_mod") \
    .partial_flag(True) \
    .op_pattern("broadcast") \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
    .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \
    .get_op_info()


@op_info_register(truncate_mod_op_info)
def _truncate_mod_tbe():
    """TruncateMod TBE register"""
    return
...@@ -52,8 +52,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
NPUAllocFloatStatus, NPUClearFloatStatus,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum, HistogramFixedWidth,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
from .random_ops import (RandomChoiceWithMask, Normal)
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
...@@ -267,6 +267,8 @@ __all__ = [
'SigmoidCrossEntropyWithLogits',
'FloorDiv',
'FloorMod',
'TruncateDiv',
'TruncateMod',
'Ceil',
'Acosh',
'Asinh',
...@@ -323,6 +325,7 @@ __all__ = [
"BesselI1e",
"Atan",
"Atanh",
"Tan",
"BasicLSTMCell", "BasicLSTMCell",
"BroadcastTo", "BroadcastTo",
"DataFormatDimMap", "DataFormatDimMap",
......
...@@ -1744,6 +1744,65 @@ class FloorDiv(_MathBinaryOp):
    """

class TruncateDiv(_MathBinaryOp):
    """
    Divides the first input tensor by the second input tensor element-wise. For integer types,
    negative numbers round their fractional quotient towards zero.

    The inputs must be two tensors or one tensor and one scalar.
    When the inputs are two tensors, their dtypes cannot both be bool, and their shapes can be broadcast.
    When the inputs are one tensor and one scalar, the scalar can only be a constant.

    Inputs:
        - **input_x** (Union[Tensor, Number, bool]) - The first input is a number, a bool,
          or a tensor whose data type is number or bool.
        - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or a bool
          when the first input is a tensor, or a tensor whose data type is number or bool.

    Outputs:
        Tensor, whose shape is the same as the shape after broadcasting, and whose data type
        is the one with the higher precision or more digits among the two inputs.

    Examples:
        >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32)
        >>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32)
        >>> truncate_div = P.TruncateDiv()
        >>> truncate_div(input_x, input_y)
        [0, 1, 0]
    """

class TruncateMod(_MathBinaryOp):
    """
    Returns the element-wise remainder of division, with the quotient truncated towards zero.

    The inputs must be two tensors or one tensor and one scalar.
    When the inputs are two tensors, their dtypes cannot both be bool, and their shapes can be broadcast.
    When the inputs are one tensor and one scalar, the scalar can only be a constant.

    Inputs:
        - **input_x** (Union[Tensor, Number, bool]) - The first input is a number, a bool,
          or a tensor whose data type is number or bool.
        - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or a bool
          when the first input is a tensor, or a tensor whose data type is number or bool.

    Outputs:
        Tensor, whose shape is the same as the shape after broadcasting, and whose data type
        is the one with the higher precision or more digits among the two inputs.

    Examples:
        >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32)
        >>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32)
        >>> truncate_mod = P.TruncateMod()
        >>> truncate_mod(input_x, input_y)
        [2, 1, -1]
    """

class Mod(_MathBinaryOp):
    """
    Computes the remainder of dividing the first input tensor by the second input tensor element-wise.
...@@ -2870,6 +2929,34 @@ class Round(PrimitiveWithInfer):
        return x_type

class Tan(PrimitiveWithInfer):
    """
    Computes the tangent of `input_x` element-wise.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, has the same shape as `input_x`.

    Examples:
        >>> tan = P.Tan()
        >>> input_x = Tensor(np.array([-1.0, 0.0, 1.0]), mindspore.float32)
        >>> output = tan(input_x)
    """

    @prim_attr_register
    def __init__(self):
        """init Tan"""

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
        return x_type
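For reference, the expected values for the docstring example above can be checked against NumPy (a reference computation, not the TBE kernel):

import numpy as np

x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
print(np.tan(x))   # approximately [-1.5574077  0.  1.5574077]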

class Atan(PrimitiveWithInfer):
    """
    Computes the trigonometric inverse tangent of x element-wise.
...
...@@ -766,6 +766,10 @@ test_case_math_ops = [
        'block': P.Asinh(),
        'desc_inputs': [[3, 4, 5]],
        'desc_bprop': [[3, 4, 5]]}),
    ('Tan', {
        'block': P.Tan(),
        'desc_inputs': [[2, 3]],
        'desc_bprop': [[2, 3]]}),
    ('Reciprocal', {
        'block': P.Reciprocal(),
        'desc_inputs': [[2, 3, 3, 5]],
...@@ -852,6 +856,14 @@ test_case_math_ops = [
        'block': P.FloorMod(),
        'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]],
        'desc_bprop': [[2, 3, 4, 5]]}),
    ('TruncateDiv', {
        'block': P.TruncateDiv(),
        'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]],
        'desc_bprop': [[2, 3, 4, 5]]}),
    ('TruncateMod', {
        'block': P.TruncateMod(),
        'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]],
        'desc_bprop': [[2, 3, 4, 5]]}),
    ('identity', {
        'block': ops.functional.identity,
        'desc_inputs': [[2, 2]],
...