diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc
index 3f6d41804dffa2f6c0d2058fe065b8e900803acb..3483a8a9661047419f1ce7cdbcd8ea54ea4c36f0 100644
--- a/mindspore/ccsrc/transform/graph_ir/convert.cc
+++ b/mindspore/ccsrc/transform/graph_ir/convert.cc
@@ -61,7 +61,6 @@ const char kNameReduceSum[] = "ReduceSum";
 const char kNameIsFinite[] = "isFinite";
 const char kNameReciprocal[] = "Reciprocal";
 const char kNameRsqrt[] = "Rsqrt";
-const char kNameRsqrtGrad[] = "RsqrtGrad";
 const char kNameSqrt[] = "Sqrt";
 const char kNameSquare[] = "Square";
 const char kNameSquaredDifference[] = "SquaredDifference";
@@ -83,6 +82,9 @@ const char kNameFlattenGrad[] = "FlattenGrad";
 const char kNameConvolution[] = "Convolution";
 const char kNameBiasAdd[] = "BiasAdd";
 const char kNameMaxPoolGrad[] = "MaxPoolGrad";
+const char kNameRsqrtGrad[] = "RsqrtGrad";
+const char kNameSqrtGrad[] = "SqrtGrad";
+const char kNameReciprocalGrad[] = "ReciprocalGrad";
 const char kNameAvgPoolGrad[] = "AvgPoolGrad";
 const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax";
 const char kNameApplyMomentum[] = "ApplyMomentum";
@@ -233,6 +235,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameAllgather), ADPT_DESC(HcomAllGather)},
     {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)},
     {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
+    {string(kNameSqrtGrad), ADPT_DESC(SqrtGrad)},
+    {string(kNameReciprocalGrad), ADPT_DESC(ReciprocalGrad)},
+    {string(kNameRsqrtGrad), ADPT_DESC(RsqrtGrad)},
     {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
     {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
     {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare.cc
index 1f803d4d813a197ab7ee157f8bf4e1386381a77b..48362c793bddeaf16a13aadc4c568d5946e23bcb 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_declare.cc
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare.cc
@@ -726,6 +726,21 @@ ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<
                          {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
 OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}};
 
+// RsqrtGrad
+INPUT_MAP(RsqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(RsqrtGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(RsqrtGrad) = {{0, OUTPUT_DESC(z)}};
+
+// SqrtGrad
+INPUT_MAP(SqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(SqrtGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(SqrtGrad) = {{0, OUTPUT_DESC(z)}};
+
+// ReciprocalGrad
+INPUT_MAP(ReciprocalGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(ReciprocalGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(ReciprocalGrad) = {{0, OUTPUT_DESC(z)}};
+
 // avgpoolgrad
 INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}};
 ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare.h
index 220cc945a7a87f9333998b9c8c6683ecd2d1588f..9d21d0a76cdd9d6b87a30d9aa273190c0721d19a 100755
--- a/mindspore/ccsrc/transform/graph_ir/op_declare.h
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare.h
@@ -439,6 +439,12 @@ DECLARE_OP_ADAPTER(MaxPool)
 DECLARE_OP_USE_OUTPUT(MaxPool)
 DECLARE_OP_ADAPTER(MaxPoolGrad)
 DECLARE_OP_USE_OUTPUT(MaxPoolGrad)
+DECLARE_OP_ADAPTER(SqrtGrad)
+DECLARE_OP_USE_OUTPUT(SqrtGrad)
+DECLARE_OP_ADAPTER(ReciprocalGrad)
+DECLARE_OP_USE_OUTPUT(ReciprocalGrad)
+DECLARE_OP_ADAPTER(RsqrtGrad)
+DECLARE_OP_USE_OUTPUT(RsqrtGrad)
 DECLARE_OP_ADAPTER(AvgPool)
 DECLARE_OP_USE_OUTPUT(AvgPool)
 DECLARE_OP_ADAPTER(AvgPoolGrad)
diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py
index c7d39c6aa018d45cba696035bab07d88a0fdd3cd..e562b644170bfb23e9f58a9557c795f02d20dba0 100755
--- a/mindspore/ops/_grad/grad_math_ops.py
+++ b/mindspore/ops/_grad/grad_math_ops.py
@@ -356,15 +356,10 @@ def get_bprop_square(self):
 @bprop_getters.register(P.Sqrt)
 def get_bprop_sqrt(self):
     """Grad definition for `Sqrt` operation."""
-    mul_func = P.Mul()
-    fill_func = P.Fill()
-    div_op = P.RealDiv()
-    sqrt = P.Sqrt()
-    dtype = P.DType()
+    sqrt_grad = G.SqrtGrad()
 
     def bprop(x, out, dout):
-        temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x))
-        dx = mul_func(dout, temp)
+        dx = sqrt_grad(out, dout)
         return (dx,)
 
     return bprop
@@ -373,10 +368,10 @@ def get_bprop_sqrt(self):
 @bprop_getters.register(P.Rsqrt)
 def get_bprop_rsqrt(self):
     """Grad definition for `Rsqrt` operation."""
+    rsqrt_grad = G.RsqrtGrad()
 
     def bprop(x, out, dout):
-        grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x)
-        dx = dout * grad
+        dx = rsqrt_grad(out, dout)
         return (dx,)
 
     return bprop
@@ -385,14 +380,10 @@ def get_bprop_rsqrt(self):
 @bprop_getters.register(P.Reciprocal)
 def get_bprop_reciprocal(self):
     """Grad definition for `Reciprocal` operation."""
-    neg = P.Neg()
-    mul = P.Mul()
-    square = P.Square()
-    reciprocal = P.Reciprocal()
+    reciprocal_grad = G.ReciprocalGrad()
 
     def bprop(x, out, dout):
-        g = neg(reciprocal(square(x)))
-        dx = mul(dout, g)
+        dx = reciprocal_grad(out, dout)
         return (dx,)
 
     return bprop
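The rewritten bprops work because each fused kernel expresses the derivative in terms of the forward output `out` instead of recomputing from the input `x`. A quick NumPy sanity sketch of the identities involved (not part of the patch; names here are illustrative only):

```python
import numpy as np

x = np.random.uniform(0.5, 1.5, 4).astype(np.float32)  # keep x away from 0
dy = np.random.rand(4).astype(np.float32)              # upstream gradient

# Sqrt: old bprop computed dout * (0.5 / sqrt(x)); with y = sqrt(x) this is dy / (2*y).
y = np.sqrt(x)
assert np.allclose(dy * 0.5 / np.sqrt(x), dy / (2 * y))

# Rsqrt: old bprop computed dout * (-0.5 / (sqrt(x) * x)); with y = x**-0.5 this is -0.5 * y**3 * dy.
y = 1.0 / np.sqrt(x)
assert np.allclose(dy * -0.5 / (np.sqrt(x) * x), -0.5 * y**3 * dy)

# Reciprocal: old bprop computed dout * -(1 / x**2); with y = 1/x this is -y**2 * dy.
y = 1.0 / x
assert np.allclose(dy * -(1.0 / x**2), -(y**2) * dy)
```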
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 75fd099f99a8918b951e564a7e4a84d0644e6d31..8f4cf8496d6014f5da92d743c3cd379ae491bf55 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -441,6 +441,7 @@ def get_bprop_softmax(self):
     sub = P.Sub()
     mul = P.Mul()
     axis = self.axis
+
     def bprop(x, out, dout):
         dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
         return (dx,)
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 1dcc4bf1532a280c37842754715ec3afdd17dda2..faa02746ebc51b39d9399795e4020b8f38ab7471 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -236,6 +236,9 @@ from .cum_sum import _cum_sum_tbe
 from .apply_rms_prop import _apply_rms_prop_tbe
 from .cumprod import _cumprop_tbe
 from .reduce_prod import _reduce_prod_tbe
+from .reciprocal_grad import _reciprocal_grad_tbe
+from .sqrt_grad import _sqrt_grad_tbe
+from .rsqrt_grad import _rsqrt_grad_tbe
 from .flatten_grad import _flatten_grad_tbe
 from .scatter_add import _scatter_add_tbe
 from .atan2 import _atan2_tbe
diff --git a/mindspore/ops/_op_impl/tbe/reciprocal.py b/mindspore/ops/_op_impl/tbe/reciprocal.py
index c620fb17a60c786e22bfe40af4039b37745c211d..eacfdd6bcef5f8f932d35e7adeacdf1041016ab2 100644
--- a/mindspore/ops/_op_impl/tbe/reciprocal.py
+++ b/mindspore/ops/_op_impl/tbe/reciprocal.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
-"""Add op"""
+"""Reciprocal op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
 reciprocal_op_info = TBERegOp("Reciprocal") \
@@ -32,5 +32,5 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
 
 @op_info_register(reciprocal_op_info)
 def _reciprocal_tbe():
-    """Add TBE register"""
+    """Reciprocal TBE register"""
     return
diff --git a/mindspore/ops/_op_impl/tbe/reciprocal_grad.py b/mindspore/ops/_op_impl/tbe/reciprocal_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..48c716986164e46f36d6e743a84146d6938eeca1
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/reciprocal_grad.py
@@ -0,0 +1,38 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ReciprocalGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+reciprocal_grad_op_info = TBERegOp("ReciprocalGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("reciprocal_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("reciprocal_grad") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "dy", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .op_pattern("broadcast") \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(reciprocal_grad_op_info)
+def _reciprocal_grad_tbe():
+    """ReciprocalGrad TBE register"""
+    return
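The `__init__.py` hunk above is what activates these new files: `@op_info_register` records the op info when the module is imported, so a TBE kernel is only visible to the framework once something imports it. A minimal sketch of the mechanism, assuming the standard package layout:

```python
# Importing the package runs tbe/__init__.py, which imports
# _reciprocal_grad_tbe and friends; each @op_info_register decorator
# registers its op info as a side effect of that import.
from mindspore.ops._op_impl import tbe  # noqa: F401
```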
diff --git a/mindspore/ops/_op_impl/tbe/rsqrt_grad.py b/mindspore/ops/_op_impl/tbe/rsqrt_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..914c7eca8b0e8d99f8bd21fc781e5aa6dacc27b8
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/rsqrt_grad.py
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""RsqrtGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+rsqrt_grad_op_info = TBERegOp("RsqrtGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("rsqrt_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("rsqrt_grad") \
+    .partial_flag(True) \
+    .op_pattern("broadcast") \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "dy", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
+    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(rsqrt_grad_op_info)
+def _rsqrt_grad_tbe():
+    """RsqrtGrad TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/sqrt_grad.py b/mindspore/ops/_op_impl/tbe/sqrt_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..a951bb0f8a004cd522fdf6812441cad2371df865
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/sqrt_grad.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SqrtGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+sqrt_grad_op_info = TBERegOp("SqrtGrad") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("sqrt_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("sqrt_grad") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "dy", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .get_op_info()
+
+
+@op_info_register(sqrt_grad_op_info)
+def _sqrt_grad_tbe():
+    """SqrtGrad TBE register"""
+    return
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 1927c50dcf43ff2519a19b6a67947a23cd51d88c..5663dadc76701aad2bd324ee282b0abe64498244 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -115,6 +115,74 @@ class AsinhGrad(PrimitiveWithInfer):
         return x
 
 
+class ReciprocalGrad(PrimitiveWithInfer):
+    """Performs grad of Reciprocal operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init ReciprocalGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
+class RsqrtGrad(PrimitiveWithInfer):
+    """Performs grad of Rsqrt operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init RsqrtGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
+        return x_dtype
+
+
+class SoftmaxGrad(PrimitiveWithInfer):
+    """Performs grad of Softmax operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init SoftmaxGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
+class SqrtGrad(PrimitiveWithInfer):
+    """Performs grad of Sqrt operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init SqrtGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
 class BatchNormGrad(PrimitiveWithInfer):
     """Performs grad of BatchNorm operation."""
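For reference, a minimal call sketch of the new primitives (a sketch only, assuming a PyNative setup on an Ascend backend where the TBE kernels above are available; tensor values are illustrative). Each primitive takes the forward output and the incoming gradient:

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.ops.operations import _grad_ops as G

context.set_context(device_target="Ascend")  # the TBE kernels above are Ascend-only

y = Tensor(np.array([1.0, 2.0, 4.0], np.float32))   # forward output of the corresponding op
dy = Tensor(np.array([1.0, 1.0, 1.0], np.float32))  # incoming gradient

dx = G.SqrtGrad()(y, dy)        # dy / (2 * y)
dx = G.RsqrtGrad()(y, dy)       # -0.5 * y**3 * dy
dx = G.ReciprocalGrad()(y, dy)  # -y**2 * dy
```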