From 374b7b8583fdf9f8a1ec5457cf7e7288c13db46a Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Fri, 17 Jul 2020 12:52:16 +0800 Subject: [PATCH] Add TBE op SquaredDifference for VM. --- mindspore/ops/_grad/grad_math_ops.py | 45 ++++++++ mindspore/ops/_op_impl/tbe/__init__.py | 3 + .../ops/_op_impl/tbe/squared_difference.py | 39 +++++++ mindspore/ops/_op_impl/tbe/xdivy.py | 38 +++++++ mindspore/ops/_op_impl/tbe/xlogy.py | 38 +++++++ mindspore/ops/operations/__init__.py | 5 +- mindspore/ops/operations/math_ops.py | 100 ++++++++++++++++++ mindspore/ops/operations/nn_ops.py | 54 +++++----- tests/ut/python/ops/test_ops.py | 12 +++ 9 files changed, 306 insertions(+), 28 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/squared_difference.py create mode 100644 mindspore/ops/_op_impl/tbe/xdivy.py create mode 100644 mindspore/ops/_op_impl/tbe/xlogy.py diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index c7d39c6aa..4925c82c8 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -252,6 +252,21 @@ def get_bprop_div_no_nan(self): return bprop +@bprop_getters.register(P.Xdivy) +def get_bprop_xdivy(self): + """Grad definition for `Xdivy` operation.""" + div_op = P.Xdivy() + + def bprop(x, y, out, dout): + x_dtype = F.dtype(x) + not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype) + bc_x = div_op(not_zero_x, y) * dout + bc_y = div_op(-x, F.square(y)) * dout + return binop_grad_common(x, y, bc_x, bc_y) + + return bprop + + @bprop_getters.register(P.Floor) def get_bprop_floor(self): """Grad definition for `floor` operation.""" @@ -353,6 +368,36 @@ def get_bprop_square(self): return bprop +@bprop_getters.register(P.SquaredDifference) +def get_bprop_squared_difference(self): + """Grad definition for `SquaredDifference` operation.""" + neg = P.Neg() + + def bprop(x, y, out, dout): + x_grad = 2 * dout * (x - y) + bc_x = x_grad + bc_y = neg(x_grad) + return binop_grad_common(x, y, bc_x, bc_y) + + return bprop + + +@bprop_getters.register(P.Xlogy) +def get_bprop_xlogy(self): + """Grad definition for `Xlogy` operation.""" + log_op = P.Xlogy() + div_op = P.Xdivy() + + def bprop(x, y, out, dout): + x_dtype = F.dtype(x) + not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype) + bc_x = log_op(not_zero_x, y) * dout + bc_y = div_op(x, y) * dout + return binop_grad_common(x, y, bc_x, bc_y) + + return bprop + + @bprop_getters.register(P.Sqrt) def get_bprop_sqrt(self): """Grad definition for `Sqrt` operation.""" diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 1dcc4bf15..b0b352f61 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -108,6 +108,8 @@ from .elu import _elu_tbe from .elu_grad import _elu_grad_tbe from .div import _div_tbe from .log import _log_tbe +from .xdivy import _xdivy_tbe +from .xlogy import _xlogy_tbe from .floor_div import _floor_div_tbe from .zeros_like import _zeros_like_tbe from .neg import _neg_tbe @@ -133,6 +135,7 @@ from .softplus import _softplus_tbe from .softplus_grad import _softplus_grad_tbe from .softmax_grad_ext import _softmax_grad_ext_tbe from .square import _square_tbe +from .squared_difference import _squared_difference_tbe from .sqrt import _sqrt_tbe from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad diff --git a/mindspore/ops/_op_impl/tbe/squared_difference.py 
b/mindspore/ops/_op_impl/tbe/squared_difference.py new file mode 100644 index 000000000..f567b9196 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/squared_difference.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SquaredDifference op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +squared_difference_op_info = TBERegOp("SquaredDifference") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("squared_difference.so") \ + .compute_cost(10) \ + .kernel_name("squared_difference") \ + .partial_flag(True) \ + .op_pattern("broadcast") \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(squared_difference_op_info) +def _squared_difference_tbe(): + """SquaredDifference TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/xdivy.py b/mindspore/ops/_op_impl/tbe/xdivy.py new file mode 100644 index 000000000..1624576c2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/xdivy.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Xdivy op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +xdivy_op_info = TBERegOp("Xdivy") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("xdivy.so") \ + .compute_cost(10) \ + .kernel_name("xdivy") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(xdivy_op_info) +def _xdivy_tbe(): + """Xdivy TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/xlogy.py b/mindspore/ops/_op_impl/tbe/xlogy.py new file mode 100644 index 000000000..7a997f216 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/xlogy.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Xlogy op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +xlogy_op_info = TBERegOp("Xlogy") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("xlogy.so") \ + .compute_cost(10) \ + .kernel_name("xlogy") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(xlogy_op_info) +def _xlogy_tbe(): + """Xlogy TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 49d4ad1f0..d4a1b1ff6 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -51,7 +51,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A Minimum, Mul, Neg, NMSWithMask, NotEqual, NPUAllocFloatStatus, NPUClearFloatStatus, NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, - Reciprocal, CumSum, HistogramFixedWidth, + Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) @@ -107,6 +107,9 @@ __all__ = [ 'Rsqrt', 'Sqrt', 'Square', + 'SquaredDifference', + 'Xdivy', + 'Xlogy', 'Conv2D', 'Flatten', 'MaxPoolWithArgmax', diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index e2f5df8c3..86b0eb576 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1121,6 +1121,40 @@ class Mul(_MathBinaryOp): return 
None +class SquaredDifference(_MathBinaryOp): + """ + Subtracts the second input tensor from the first input tensor element-wise and returns the square of the difference. + + The inputs must be two tensors or one tensor and one scalar. + When the inputs are two tensors, + both dtypes cannot be bool, and their shapes can be broadcast. + When the inputs are one tensor and one scalar, + the scalar could only be a constant. + + Inputs: + - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or + a bool or a tensor whose data type is float16, float32, int32 or bool. + - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or + a bool when the first input is a tensor or a tensor whose data type is + float16, float32, int32 or bool. + + Outputs: + Tensor, the shape is the same as the shape after broadcasting, + and the data type is the one with high precision or high digits among the two inputs. + + Examples: + >>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32) + >>> input_y = Tensor(np.array([2.0, 4.0, 6.0]), mindspore.float32) + >>> squared_difference = P.SquaredDifference() + >>> squared_difference(input_x, input_y) + [1.0, 4.0, 9.0] + """ + + def infer_dtype(self, x_dtype, y_dtype): + valid_type = [mstype.float16, mstype.float32, mstype.int32] + return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, valid_type, self.name) + + class Square(PrimitiveWithInfer): """ Returns square of a tensor element-wise. @@ -1962,6 +1996,72 @@ class Ceil(PrimitiveWithInfer): return x_dtype + +class Xdivy(_MathBinaryOp): + """ + Divides the first input tensor by the second input tensor element-wise. Returns zero when `x` is zero. + + The inputs must be two tensors or one tensor and one scalar. + When the inputs are two tensors, + both dtypes cannot be bool, and their shapes can be broadcast. + When the inputs are one tensor and one scalar, + the scalar could only be a constant. + + Inputs: + - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or + a bool or a tensor whose data type is float16, float32 or bool. + - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or + a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool. + + Outputs: + Tensor, the shape is the same as the shape after broadcasting, + and the data type is the one with high precision or high digits among the two inputs. + + Examples: + >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.float32) + >>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32) + >>> xdivy = P.Xdivy() + >>> xdivy(input_x, input_y) + [1.0, 2.0, -0.5] + """ + + def infer_dtype(self, x_dtype, y_dtype): + return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, [mstype.float16, mstype.float32], self.name) + + +class Xlogy(_MathBinaryOp): + """ + Computes the first input tensor multiplied by the logarithm of the second input tensor element-wise. + Returns zero when `x` is zero. + + The inputs must be two tensors or one tensor and one scalar. + When the inputs are two tensors, + both dtypes cannot be bool, and their shapes can be broadcast. + When the inputs are one tensor and one scalar, + the scalar could only be a constant. + + Inputs: + - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or + a bool or a tensor whose data type is float16, float32 or bool. 
+ - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or + a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool. + The value must be positive. + + Outputs: + Tensor, the shape is the same as the shape after broadcasting, + and the data type is the one with high precision or high digits among the two inputs. + + Examples: + >>> input_x = Tensor(np.array([-5, 0, 4]), mindspore.float32) + >>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32) + >>> xlogy = P.Xlogy() + >>> xlogy(input_x, input_y) + [-3.465736, 0.0, 2.7725887] + """ + + def infer_dtype(self, x_dtype, y_dtype): + return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, [mstype.float16, mstype.float32], self.name) + + class Acosh(PrimitiveWithInfer): """ Compute inverse hyperbolic cosine of x element-wise. diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 924060eec..dd868eaf1 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3205,11 +3205,11 @@ class FusedSparseFtrl(PrimitiveWithInfer): use_locking (bool): Use locks for update operation if True . Default: False. Inputs: - - **var** (Parameter): The variable to be updated. The data type must be float32. - - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`. - - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`. - - **grad** (Tensor): A tensor of the same type as `var`, for the gradient. - - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape + - **var** (Parameter) - The variable to be updated. The data type must be float32. + - **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`. + - **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`. + - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. + - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The shape of `indices` must be the same as `grad` in first dimension. The type must be int32. Outputs: @@ -3300,9 +3300,9 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): Inputs: - **var** (Parameter) - Variable tensor to be updated. The data type must be float32. - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`. - - **lr** (Tensor): The learning rate value. The data type must be float32. - - **l1** (Tensor): l1 regularization strength. The data type must be float32. - - **l2** (Tensor): l2 regularization strength. The data type must be float32. + - **lr** (Tensor) - The learning rate value. The data type must be float32. + - **l1** (Tensor) - l1 regularization strength. The data type must be float32. + - **l2** (Tensor) - l2 regularization strength. The data type must be float32. - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32. - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type must be int32. @@ -4670,16 +4670,16 @@ class ApplyFtrl(PrimitiveWithInfer): use_locking (bool): Use locks for update operation if True . Default: False. Inputs: - - **var** (Tensor): The variable to be updated. - - **accum** (Tensor): The accum to be updated, must be same type and shape as `var`. - - **linear** (Tensor): The linear to be updated, must be same type and shape as `var`. - - **grad** (Tensor): Gradient. 
- - **lr** (Union[Number, Tensor]): The learning rate value, must be positive. Default: 0.001. - - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. + - **var** (Tensor) - The variable to be updated. + - **accum** (Tensor) - The accum to be updated, must be same type and shape as `var`. + - **linear** (Tensor) - The linear to be updated, must be same type and shape as `var`. + - **grad** (Tensor) - Gradient. + - **lr** (Union[Number, Tensor]) - The learning rate value, must be positive. Default: 0.001. + - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be greater than or equal to zero. Default: 0.0. - - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. + - **l2** (Union[Number, Tensor]) - l2 regularization strength, must be greater than or equal to zero. Default: 0.0. - - **lr_power** (Union[Number, Tensor]): Learning rate power controls how the learning rate decreases + - **lr_power** (Union[Number, Tensor]) - Learning rate power controls how the learning rate decreases during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero. Default: -0.5. @@ -4760,17 +4760,17 @@ class SparseApplyFtrl(PrimitiveWithInfer): use_locking (bool): Use locks for update operation if True . Default: False. Inputs: - - **var** (Parameter): The variable to be updated. The data type must be float32. - - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`. - - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`. - - **grad** (Tensor): A tensor of the same type as `var`, for the gradient. - - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. + - **var** (Parameter) - The variable to be updated. The data type must be float32. + - **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`. + - **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`. + - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. + - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The shape of `indices` must be the same as `grad` in first dimension. The type must be int32. Outputs: - - **var** (Tensor): Tensor, has the same shape and type as `var`. - - **accum** (Tensor): Tensor, has the same shape and type as `accum`. - - **linear** (Tensor): Tensor, has the same shape and type as `linear`. + - **var** (Tensor) - Tensor, has the same shape and type as `var`. + - **accum** (Tensor) - Tensor, has the same shape and type as `accum`. + - **linear** (Tensor) - Tensor, has the same shape and type as `linear`. Examples: >>> import mindspore @@ -4858,9 +4858,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): Outputs: Tuple of 3 Tensor, the updated parameters. - - **var** (Tensor): Tensor, has the same shape and type as `var`. - - **accum** (Tensor): Tensor, has the same shape and type as `accum`. - - **linear** (Tensor): Tensor, has the same shape and type as `linear`. + - **var** (Tensor) - Tensor, has the same shape and type as `var`. + - **accum** (Tensor) - Tensor, has the same shape and type as `accum`. + - **linear** (Tensor) - Tensor, has the same shape and type as `linear`. 
Examples: >>> import mindspore diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 1f9415c61..f22366e13 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1013,6 +1013,18 @@ test_case_math_ops = [ 'desc_const': [(0, 3, 1, 2)], 'desc_inputs': [], 'skip': ['backward']}), + ('Xdivy', { + 'block': P.Xdivy(), + 'desc_inputs': [[4, 5], [2, 3, 4, 5]], + 'desc_bprop': [[2, 3, 4, 5]]}), + ('Xlogy', { + 'block': P.Xlogy(), + 'desc_inputs': [[4, 5], [2, 3, 4, 5]], + 'desc_bprop': [[2, 3, 4, 5]]}), + ('SquaredDifference', { + 'block': P.SquaredDifference(), + 'desc_inputs': [[4, 5], [2, 3, 4, 5]], + 'desc_bprop': [[2, 3, 4, 5]]}), ('Square', { 'block': P.Square(), 'desc_inputs': [[4]], -- GitLab
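Reviewer note (not part of the patch): the UT entries above only feed shapes through the gradient pipeline, so a minimal forward smoke test of the three new primitives against NumPy references is sketched below. It assumes an Ascend target so the TBE kernels registered in this patch are selected; device target and execution mode are assumptions, not requirements of the change itself.

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

# Assumption: Ascend back end, PyNative mode so the primitives can be
# called directly, the same way the docstring examples do.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

x = Tensor(np.array([1.0, 2.0, 0.0]), mindspore.float32)
y = Tensor(np.array([2.0, 4.0, 6.0]), mindspore.float32)
x_np, y_np = x.asnumpy(), y.asnumpy()

# SquaredDifference(x, y) = (x - y) ** 2, element-wise with broadcasting.
out = P.SquaredDifference()(x, y)
assert np.allclose(out.asnumpy(), (x_np - y_np) ** 2)

# Xdivy(x, y) = x / y, except 0 where x == 0.
out = P.Xdivy()(x, y)
assert np.allclose(out.asnumpy(), np.where(x_np == 0, 0.0, x_np / y_np))

# Xlogy(x, y) = x * log(y), except 0 where x == 0; y must be positive.
out = P.Xlogy()(x, y)
assert np.allclose(out.asnumpy(), np.where(x_np == 0, 0.0, x_np * np.log(y_np)))

The bprop definitions added in grad_math_ops.py can be exercised the same way by wrapping each primitive in a small nn.Cell and differentiating it with GradOperation from mindspore.ops.composite, which is how the test_ops.py entries are driven through the gradient pipeline.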