Commit 5b6c8fad authored by liuxiao

Add Erf/FillD operators for VM

Parent 715c0735
......@@ -57,6 +57,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"strided_slice", "strided_slice_d"},
{"strided_slice_grad", "strided_slice_grad_d"},
{"transpose", "transpose_d"},
{"fill", "fill_d"},
{"unsorted_segment_sum", "unsorted_segment_sum_d"},
{"concat", "concat_d"},
{"slice", "slice_d"},
......
......@@ -53,6 +53,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
  Register(kExpandDimsOpName, {1});
  Register(kSplitOpName, {0});
  Register(kTopKOpName, {1});
  Register(kErfOpName, {1});
  Register(kSparseApplyAdagradOpName, {2});
  Register(kResizeNearestNeighborGrad, {1});
}
......
......@@ -92,6 +92,7 @@ constexpr auto kClipByNormNoDivSumOpName = "ClipByNormNoDivSum";
constexpr auto kGreaterOpName = "Greater";
constexpr auto kSqrtOpName = "Sqrt";
constexpr auto kRsqrtOpName = "Rsqrt";
constexpr auto kErfOpName = "Erf";
constexpr auto kRealDivOpName = "RealDiv";
constexpr auto kLambUpdateWithLROpName = "LambUpdateWithLR";
constexpr auto kLambNextMVWithDecayOpName = "LambNextMVWithDecay";
......
......@@ -17,6 +17,7 @@
from functools import reduce
import numpy as np
from .. import functional as F
from .. import operations as P
from ..operations import _grad_ops as G
......@@ -333,6 +334,23 @@ def get_bprop_log(self):
    return bprop


@bprop_getters.register(P.Erf)
def get_bprop_erf(self):
    """Grad definition for `Erf` operation."""
    exp = P.Exp()
    square = P.Square()
    sqrt = P.Sqrt()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, out, dout):
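        # d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2); the constant factor is
        # cast to the input dtype so the elementwise multiply stays type-consistent.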
        half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
        x_square = square(x)
        dx = dout * half_root_pi * exp(-x_square)
        return (dx,)

    return bprop


@bprop_getters.register(P.Pow)
def get_bprop_pow(self):
    """Grad definition for `Pow` operation."""
......
......@@ -139,6 +139,8 @@ from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
from .fused_mul_add import _fused_mul_add_tbe
from .fused_mul_add_n import _fused_mul_add_n_tbe
from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
from .fill_d import _fill_d_op_tbe
from .erf import _erf_op_tbe
from .depthwise_conv2d import _depthwise_conv2d_tbe
from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe
from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Erf op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
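
# Registration info binding the "Erf" primitive to the TBE kernel "erf";
# op_pattern("formatAgnostic") marks the kernel as insensitive to the
# physical data format of its float16/float32 input.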
erf_op_info = TBERegOp("Erf") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("erf.so") \
    .compute_cost(10) \
    .kernel_name("erf") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(erf_op_info)
def _erf_op_tbe():
    """Erf TBE register"""
    return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FillD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
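
# Registration info binding the "FillD" primitive to the TBE kernel "fill_d".
# The static output shape is supplied through the required "dims" attribute
# rather than a tensor input, which is why the dynamic Fill op is adapted
# to FillD for the VM backend.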
fill_d_op_info = TBERegOp("FillD") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("fill_d.so") \
    .compute_cost(10) \
    .kernel_name("fill_d") \
    .partial_flag(True) \
    .attr("dims", "required", "listInt", "all") \
    .input(0, "value", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \
    .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \
    .dtype_format(DataType.I8_C1HWNCoC0, DataType.I8_C1HWNCoC0) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \
    .dtype_format(DataType.U8_C1HWNCoC0, DataType.U8_C1HWNCoC0) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .get_op_info()


@op_info_register(fill_d_op_info)
def _fill_d_op_tbe():
    """FillD TBE register"""
    return
......@@ -39,7 +39,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
                       ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
                       Cos, Div, Equal, EqualCount, Exp, Floor, FloorDiv, FloorMod, Acosh,
                       Cos, Div, Equal, EqualCount, Exp, Erf, Floor, FloorDiv, FloorMod, Acosh,
                       Greater, GreaterEqual, Less, LessEqual, Log, LogicalAnd,
                       LogicalNot, LogicalOr, MatMul, Maximum,
                       Minimum, Mul, Neg, NMSWithMask, NotEqual,
......@@ -139,6 +139,7 @@ __all__ = [
    'ReLU',
    'ReLU6',
    'Elu',
    'Erf',
    'Sigmoid',
    'HSwish',
    'HSigmoid',
......
......@@ -1007,6 +1007,36 @@ class Log(PrimitiveWithInfer):
        return x


class Erf(PrimitiveWithInfer):
    r"""
    Computes the Gauss error function of `input_x` element-wise.
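
    The error function is defined as:

    .. math::
        \operatorname{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt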

    Inputs:
        - **input_x** (Tensor) - The input tensor.

    Outputs:
        Tensor, has the same shape and dtype as `input_x`.

    Examples:
        >>> input_x = Tensor(np.array([-1, 0, 1, 2, 3]), mindspore.float32)
        >>> erf = P.Erf()
        >>> erf(input_x)
        [-0.8427168, 0., 0.8427168, 0.99530876, 0.99997765]
    """

    @prim_attr_register
    def __init__(self):
        """init Erf"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
        return x_type


class Minimum(_MathBinaryOp):
    """
    Computes the element-wise minimum of input tensors.
......
......@@ -250,6 +250,10 @@ test_case_math_ops = [
        'block': P.Exp(),
        'desc_inputs': [[2, 3]],
        'desc_bprop': [[2, 3]]}),
    ('Erf', {
        'block': P.Erf(),
        'desc_inputs': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))],
        'desc_bprop': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))]}),
    ('Floor', {
        'block': P.Floor(),
        'desc_inputs': [[2, 512, 56, 56]],
......