提交 f4e4fd2c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3577 add ReciprocalGrad RsqrtGrad SqrtGrad

Merge pull request !3577 from fangzehua/many_grad_ops
......@@ -61,7 +61,6 @@ const char kNameReduceSum[] = "ReduceSum";
const char kNameIsFinite[] = "isFinite";
const char kNameReciprocal[] = "Reciprocal";
const char kNameRsqrt[] = "Rsqrt";
const char kNameRsqrtGrad[] = "RsqrtGrad";
const char kNameSqrt[] = "Sqrt";
const char kNameSquare[] = "Square";
const char kNameSquaredDifference[] = "SquaredDifference";
......@@ -83,6 +82,9 @@ const char kNameFlattenGrad[] = "FlattenGrad";
const char kNameConvolution[] = "Convolution";
const char kNameBiasAdd[] = "BiasAdd";
const char kNameMaxPoolGrad[] = "MaxPoolGrad";
const char kNameRsqrtGrad[] = "RsqrtGrad";
const char kNameSqrtGrad[] = "SqrtGrad";
const char kNameReciprocalGrad[] = "ReciprocalGrad";
const char kNameAvgPoolGrad[] = "AvgPoolGrad";
const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax";
const char kNameApplyMomentum[] = "ApplyMomentum";
......@@ -233,6 +235,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameAllgather), ADPT_DESC(HcomAllGather)},
{string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)},
{string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
{string(kNameSqrtGrad), ADPT_DESC(SqrtGrad)},
{string(kNameReciprocalGrad), ADPT_DESC(ReciprocalGrad)},
{string(kNameRsqrtGrad), ADPT_DESC(RsqrtGrad)},
{string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
{string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
{string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},
......
......@@ -726,6 +726,21 @@ ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}};
// RsqrtGrad
INPUT_MAP(RsqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(RsqrtGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(RsqrtGrad) = {{0, OUTPUT_DESC(z)}};
// SqrtGrad
INPUT_MAP(SqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(SqrtGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(SqrtGrad) = {{0, OUTPUT_DESC(z)}};
// ReciprocalGrad
INPUT_MAP(ReciprocalGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(ReciprocalGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(ReciprocalGrad) = {{0, OUTPUT_DESC(z)}};
// avgpoolgrad
INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}};
ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
......
......@@ -439,6 +439,12 @@ DECLARE_OP_ADAPTER(MaxPool)
DECLARE_OP_USE_OUTPUT(MaxPool)
DECLARE_OP_ADAPTER(MaxPoolGrad)
DECLARE_OP_USE_OUTPUT(MaxPoolGrad)
DECLARE_OP_ADAPTER(SqrtGrad)
DECLARE_OP_USE_OUTPUT(SqrtGrad)
DECLARE_OP_ADAPTER(ReciprocalGrad)
DECLARE_OP_USE_OUTPUT(ReciprocalGrad)
DECLARE_OP_ADAPTER(RsqrtGrad)
DECLARE_OP_USE_OUTPUT(RsqrtGrad)
DECLARE_OP_ADAPTER(AvgPool)
DECLARE_OP_USE_OUTPUT(AvgPool)
DECLARE_OP_ADAPTER(AvgPoolGrad)
......
......@@ -356,15 +356,10 @@ def get_bprop_square(self):
@bprop_getters.register(P.Sqrt)
def get_bprop_sqrt(self):
"""Grad definition for `Sqrt` operation."""
mul_func = P.Mul()
fill_func = P.Fill()
div_op = P.RealDiv()
sqrt = P.Sqrt()
dtype = P.DType()
sqrt_grad = G.SqrtGrad()
def bprop(x, out, dout):
temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x))
dx = mul_func(dout, temp)
dx = sqrt_grad(out, dout)
return (dx,)
return bprop
......@@ -373,10 +368,10 @@ def get_bprop_sqrt(self):
@bprop_getters.register(P.Rsqrt)
def get_bprop_rsqrt(self):
"""Grad definition for `Rsqrt` operation."""
rsqrt_grad = G.RsqrtGrad()
def bprop(x, out, dout):
grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x)
dx = dout * grad
dx = rsqrt_grad(out, dout)
return (dx,)
return bprop
......@@ -385,14 +380,10 @@ def get_bprop_rsqrt(self):
@bprop_getters.register(P.Reciprocal)
def get_bprop_reciprocal(self):
"""Grad definition for `Reciprocal` operation."""
neg = P.Neg()
mul = P.Mul()
square = P.Square()
reciprocal = P.Reciprocal()
reciprocal_grad = G.ReciprocalGrad()
def bprop(x, out, dout):
g = neg(reciprocal(square(x)))
dx = mul(dout, g)
dx = reciprocal_grad(out, dout)
return (dx,)
return bprop
......
......@@ -441,6 +441,7 @@ def get_bprop_softmax(self):
sub = P.Sub()
mul = P.Mul()
axis = self.axis
def bprop(x, out, dout):
dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
return (dx,)
......
......@@ -236,6 +236,9 @@ from .cum_sum import _cum_sum_tbe
from .apply_rms_prop import _apply_rms_prop_tbe
from .cumprod import _cumprop_tbe
from .reduce_prod import _reduce_prod_tbe
from .reciprocal_grad import _reciprocal_grad_tbe
from .sqrt_grad import _sqrt_grad_tbe
from .rsqrt_grad import _rsqrt_grad_tbe
from .flatten_grad import _flatten_grad_tbe
from .scatter_add import _scatter_add_tbe
from .atan2 import _atan2_tbe
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Add op"""
"""Reciprocal op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
reciprocal_op_info = TBERegOp("Reciprocal") \
......@@ -32,5 +32,5 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
@op_info_register(reciprocal_op_info)
def _reciprocal_tbe():
"""Add TBE register"""
"""Reciprocal TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReciprocalGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
reciprocal_grad_op_info = TBERegOp("ReciprocalGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("reciprocal_grad.so") \
.compute_cost(10) \
.kernel_name("reciprocal_grad") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "dy", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("broadcast") \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info()
@op_info_register(reciprocal_grad_op_info)
def _reciprocal_grad_tbe():
"""ReciprocalGrad TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RsqrtGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
rsqrt_grad_op_info = TBERegOp("RsqrtGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("rsqrt_grad.so") \
.compute_cost(10) \
.kernel_name("rsqrt_grad") \
.partial_flag(True) \
.op_pattern("broadcast") \
.input(0, "x", False, "required", "all") \
.input(1, "dy", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info()
@op_info_register(rsqrt_grad_op_info)
def _rsqrt_grad_tbe():
"""RsqrtGrad TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SqrtGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
sqrt_grad_op_info = TBERegOp("SqrtGrad") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("sqrt_grad.so") \
.compute_cost(10) \
.kernel_name("sqrt_grad") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "dy", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.get_op_info()
@op_info_register(sqrt_grad_op_info)
def _sqrt_grad_tbe():
"""SqrtGrad TBE register"""
return
......@@ -115,6 +115,74 @@ class AsinhGrad(PrimitiveWithInfer):
return x
class ReciprocalGrad(PrimitiveWithInfer):
"""Performs grad of Reciprocal operation."""
@prim_attr_register
def __init__(self):
"""init ReciprocalGrad"""
def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
class RsqrtGrad(PrimitiveWithInfer):
"""Performs grad of Rsqrt operation."""
@prim_attr_register
def __init__(self):
"""init RsqrtGrad"""
def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
return x_dtype
class SoftmaxGrad(PrimitiveWithInfer):
"""Performs grad of Softmax operation."""
@prim_attr_register
def __init__(self):
"""init SoftmaxGrad"""
def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
class SqrtGrad(PrimitiveWithInfer):
"""Performs grad of Sqrt operation."""
@prim_attr_register
def __init__(self):
"""init SqrtGrad"""
def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
class BatchNormGrad(PrimitiveWithInfer):
"""Performs grad of BatchNorm operation."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册