提交 c36e7f15 编写于 作者: J jiangjinsheng

support vm for log1p

上级 4a8fcf5d
......@@ -331,6 +331,19 @@ def get_bprop_log(self):
return bprop
@bprop_getters.register(P.Log1p)
def get_bprop_log1p(self):
"""Grad definition for `Log1p` operation."""
reciprocal = P.Reciprocal()
def bprop(x, out, dout):
x_1p = x + 1
g = reciprocal(x_1p)
dx = g * dout
return dx, 0
return bprop
@bprop_getters.register(P.Erf)
def get_bprop_erf(self):
"""Grad definition for `Erf` operation."""
......
......@@ -159,3 +159,4 @@ from .ones_like import _ones_like_tbe
from .batch_to_space import _batch_to_space_tbe
from .space_to_batch import _space_to_batch_tbe
from .floor import _floor_tbe
from .log1p import _log1p_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log1p op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
log1p_op_info = TBERegOp("Log1p") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("log1p.so") \
.compute_cost(10) \
.kernel_name("log1p") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(log1p_op_info)
def _log1p_tbe():
"""Log1p TBE register"""
return
......@@ -40,7 +40,7 @@ from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
Cos, Div, Equal, EqualCount, Exp, Erf, Floor, FloorDiv, FloorMod, Acosh,
Greater, GreaterEqual, Less, LessEqual, Log, LogicalAnd,
Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
LogicalNot, LogicalOr, MatMul, Maximum,
Minimum, Mul, Neg, NMSWithMask, NotEqual,
NPUAllocFloatStatus, NPUClearFloatStatus,
......
......@@ -1007,6 +1007,35 @@ class Log(PrimitiveWithInfer):
return x
class Log1p(PrimitiveWithInfer):
"""
Returns the natural logarithm of one plus the input tensor element-wise.
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, has the same shape as the `input_x`.
Examples:
>>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> log1p = P.Log1p()
>>> log1p(input_x)
[0.6931472, 1.0986123, 1.609438]
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x):
return x
def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.name)
return x
class Erf(PrimitiveWithInfer):
r"""
Computes the Gauss error function of `input_x` element-wise.
......
......@@ -359,6 +359,14 @@ class FloorNet(nn.Cell):
def construct(self, x):
return self.floor(x)
class Log1pNet(nn.Cell):
def __init__(self):
super(Log1pNet, self).__init__()
self.log1p = P.Log1p()
def construct(self, x):
return self.log1p(x)
test_case_math_ops = [
('MatMulGrad', {
......@@ -405,6 +413,11 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[1., 0., -2.]], np.float32))],
'desc_bprop': [Tensor(np.array([[1., 0., -2.]], np.float32))],
'skip': ['backward']}),
('Log1p', {
'block': Log1pNet(),
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
'desc_bprop': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
'skip': ['backward']}),
]
test_case_lists = [test_case_math_ops]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册