提交 2d3cd827 编写于 作者: J jiangjinsheng

vm for erfc

上级 93e7c97a
...@@ -361,6 +361,23 @@ def get_bprop_erf(self): ...@@ -361,6 +361,23 @@ def get_bprop_erf(self):
return bprop return bprop
@bprop_getters.register(P.Erfc)
def get_bprop_erfc(self):
"""Grad definition for `Erfc` operation."""
exp = P.Exp()
square = P.Square()
sqrt = P.Sqrt()
cast = P.Cast()
dtype = P.DType()
def bprop(x, out, dout):
half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
x_square = square(x)
dx = dout * (-half_root_pi * exp(-x_square))
return (dx,)
return bprop
@bprop_getters.register(P.Pow) @bprop_getters.register(P.Pow)
def get_bprop_pow(self): def get_bprop_pow(self):
"""Grad definition for `Pow` operation.""" """Grad definition for `Pow` operation."""
......
...@@ -152,6 +152,7 @@ from .fused_mul_add_n import _fused_mul_add_n_tbe ...@@ -152,6 +152,7 @@ from .fused_mul_add_n import _fused_mul_add_n_tbe
from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
from .fill import _fill_op_tbe from .fill import _fill_op_tbe
from .erf import _erf_op_tbe from .erf import _erf_op_tbe
from .erfc import _erfc_op_tbe
from .depthwise_conv2d import _depthwise_conv2d_tbe from .depthwise_conv2d import _depthwise_conv2d_tbe
from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe
from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Erfc op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
erfc_op_info = TBERegOp("Erfc") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("erfc.so") \
.compute_cost(10) \
.kernel_name("erfc") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(erfc_op_info)
def _erfc_op_tbe():
"""Erfc TBE register"""
return
...@@ -39,7 +39,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge ...@@ -39,7 +39,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
Cos, Div, Equal, EqualCount, Exp, Erf, Floor, FloorDiv, FloorMod, Acosh, Cos, Div, Equal, EqualCount, Exp, Erf, Erfc, Floor, FloorDiv, FloorMod, Acosh,
Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
LogicalNot, LogicalOr, MatMul, Maximum, LogicalNot, LogicalOr, MatMul, Maximum,
Minimum, Mul, Neg, NMSWithMask, NotEqual, Minimum, Mul, Neg, NMSWithMask, NotEqual,
......
...@@ -1067,6 +1067,36 @@ class Erf(PrimitiveWithInfer): ...@@ -1067,6 +1067,36 @@ class Erf(PrimitiveWithInfer):
return x_type return x_type
class Erfc(PrimitiveWithInfer):
r"""
Computes the complementary error function of `input_x` element-wise.
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, has the same shape and dtype as the `input_x`.
Examples:
>>> input_x = Tensor(np.array([-1, 0, 1, 2, 3]), mindspore.float32)
>>> erfc = P.Erfc()
>>> erfc(input_x)
[1.8427168, 0., 0.1572832, 0.00469124, 0.00002235]
"""
@prim_attr_register
def __init__(self):
"""init Erfc"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
return x_type
class Minimum(_MathBinaryOp): class Minimum(_MathBinaryOp):
""" """
Computes the element-wise minimum of input tensors. Computes the element-wise minimum of input tensors.
......
...@@ -372,6 +372,15 @@ class Log1pNet(nn.Cell): ...@@ -372,6 +372,15 @@ class Log1pNet(nn.Cell):
return self.log1p(x) return self.log1p(x)
class ErfcNet(nn.Cell):
def __init__(self):
super(ErfcNet, self).__init__()
self.erfc = P.Erfc()
def construct(self, x):
return self.erfc(x)
test_case_math_ops = [ test_case_math_ops = [
('MatMulGrad', { ('MatMulGrad', {
'block': GradWrap(NetWithLoss(MatMulNet())), 'block': GradWrap(NetWithLoss(MatMulNet())),
...@@ -422,6 +431,11 @@ test_case_math_ops = [ ...@@ -422,6 +431,11 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))], 'desc_inputs': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
'desc_bprop': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))], 'desc_bprop': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
'skip': ['backward']}), 'skip': ['backward']}),
('Erfc', {
'block': ErfcNet(),
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
'desc_bprop': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
}),
] ]
test_case_lists = [test_case_math_ops] test_case_lists = [test_case_math_ops]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册