Commit 71d27c68 authored by M mindspore-ci-bot, committed by Gitee

!2068 Add InplaceAddD and InplaceSubD vm ops

Merge pull request !2068 from liuwenhao/master
......@@ -80,6 +80,8 @@ static std::map<string, string> tbe_func_adapter_map = {
{"concat", "concat_d"},
{"slice", "slice_d"},
{"reduce_sum", "reduce_sum_d"},
{"inplace_add", "inplace_add_d"},
{"inplace_sub", "inplace_sub_d"},
{"one_hot", "one_hot_d"},
{"sum", "reduce_sum_d"},
{"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"},
......
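For context: this adapter map translates framework-level op names into the names of their compiled TBE kernels, where the `_d` suffix denotes the static-shape ("D") kernel variants. A minimal Python sketch of the lookup behavior, using a hypothetical `adapt_tbe_name` helper (not the actual MindSpore API) and assuming unmapped names pass through unchanged:

# Hypothetical sketch of the adapter-map lookup; not the actual MindSpore code.
tbe_func_adapter_map = {
    "inplace_add": "inplace_add_d",
    "inplace_sub": "inplace_sub_d",
}

def adapt_tbe_name(op_name):
    # Fall back to the original name when no adaptation is registered.
    return tbe_func_adapter_map.get(op_name, op_name)

assert adapt_tbe_name("inplace_add") == "inplace_add_d"
assert adapt_tbe_name("some_other_op") == "some_other_op"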
......@@ -171,6 +171,8 @@ const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
// NN
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
......
......@@ -180,6 +180,8 @@ extern const PrimitivePtr kPrimLessEqual;
extern const PrimitivePtr kPrimCumSum;
extern const PrimitivePtr kPrimCumProd;
extern const PrimitivePtr kPrimSubscalar;
extern const PrimitivePtr kPrimInplaceAdd;
extern const PrimitivePtr kPrimInplaceSub;
// NN
extern const PrimitivePtr kPrimFlatten;
......
......@@ -133,6 +133,8 @@ constexpr auto kResizeNearestNeighborV2OpName = "ResizeNearestNeighborV2";
constexpr auto kResizeNearestNeighborV2GradOpName = "ResizeNearestNeighborV2Grad";
constexpr auto kApplyRMSPropOpname = "ApplyRMSProp";
constexpr auto kCumsumOpName = "Cumsum";
constexpr auto kInplaceAddOpName = "InplaceAdd";
constexpr auto kInplaceSubOpName = "InplaceSub";
constexpr auto kResizeBilinearV2OpName = "kResizeBilinearV2";
constexpr auto kReduceProdOpName = "ReduceProd";
constexpr auto kCumprodOpName = "Cumprod";
......
......@@ -15,6 +15,8 @@
"""tbe ops"""
from .abs import _abs_tbe
from .inplace_add import _inplace_add_tbe
from .inplace_sub import _inplace_sub_tbe
from .abs_grad import _abs_grad_tbe
from .acos import _acos_tbe
from .acos_grad import _acos_grad_tbe
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InplaceAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
inplace_add_op_info = TBERegOp("InplaceAdd") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("inplace_add_d.so") \
.compute_cost(10) \
.kernel_name("inplace_add_d") \
.partial_flag(True) \
.attr("indices", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "v", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(inplace_add_op_info)
def _inplace_add_tbe():
"""InplaceAdd TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InplaceSub op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
inplace_sub_op_info = TBERegOp("InplaceSub") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("inplace_sub_d.so") \
.compute_cost(10) \
.kernel_name("inplace_sub_d") \
.partial_flag(True) \
.attr("indices", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "v", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(inplace_sub_op_info)
def _inplace_sub_tbe():
"""InplaceSub TBE register"""
return
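The two registrations above are structurally identical: each binds a primitive (InplaceAdd/InplaceSub) to its static-shape TBE kernel (inplace_add_d/inplace_sub_d), declares the required `indices` attribute, two required tensor inputs (x, v) and one output (y), and restricts dtypes to int32/float16/float32. As a plain-NumPy illustration of the semantics these kernels implement (a reference sketch, not the kernel code), consistent with the docstring examples further down:

import numpy as np

def inplace_add_ref(x, v, indices):
    # y = x; y[indices[k], :] += v[k, :] for each k.
    indices = (indices,) if isinstance(indices, int) else tuple(indices)
    y = x.copy()
    for k, row in enumerate(indices):
        y[row] += v[k]
    return y

def inplace_sub_ref(x, v, indices):
    # y = x; y[indices[k], :] -= v[k, :] for each k.
    indices = (indices,) if isinstance(indices, int) else tuple(indices)
    y = x.copy()
    for k, row in enumerate(indices):
        y[row] -= v[k]
    return y

x = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
v = np.array([[0.5, 1.0], [1.0, 1.5]], dtype=np.float32)
print(inplace_add_ref(x, v, (0, 1)))  # [[1.5 3. ] [4.  5.5] [5.  6. ]]
print(inplace_sub_ref(x, v, (0, 1)))  # [[0.5 1. ] [2.  2.5] [5.  6. ]]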
......@@ -41,7 +41,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
BitwiseXor, Inv, Invert, ApproximateEqual,
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
......@@ -178,6 +178,8 @@ __all__ = [
'DropoutGrad',
'Dropout',
'Neg',
'InplaceAdd',
'InplaceSub',
'Slice',
'DType',
'NPUAllocFloatStatus',
......
......@@ -772,6 +772,125 @@ class Neg(PrimitiveWithInfer):
return input_x
class InplaceAdd(PrimitiveWithInfer):
"""
Adds v to specified rows of x. Computes y = x; y[i, :] += v.
Args:
- **indices** (Union[int, tuple]) - Indices into the left-most dimension of x, determining which rows of x
to add v to. It is an int or tuple whose values must be in [0, the first dimension size of x).
Inputs:
- **input_x** (Tensor) - The first input is a tensor whose data type is number.
- **input_v** (Tensor) - The second input is a tensor that has the same dimension sizes as x except for
the first dimension, which must equal the size of indices.
Outputs:
Tensor, has the same shape and dtype as input_x.
Examples:
>>> indices = (0, 1)
>>> input_x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
>>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
>>> inplaceAdd = P.InplaceAdd(indices)
>>> inplaceAdd(input_x, input_v)
[[1.5 3.]
[4. 5.5]
[5. 6.]]
"""
@prim_attr_register
def __init__(self, indices):
"""init InplaceAdd"""
self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
self.indices = indices
def infer_shape(self, x_shape, v_shape):
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
if isinstance(self.indices, int):
validator.check("size of indices", 1, "v's first dimension", v_shape[0],
Rel.EQ, self.name)
if self.indices < 0 or self.indices >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {self.indices}.')
else:
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
if len(x_shape) > 1:
validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:],
Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_value_type('indices', self.indices, [tuple, int], self.name)
return x_dtype
class InplaceSub(PrimitiveWithInfer):
"""
Subtracts v from specified rows of x. Computes y = x; y[i, :] -= v.
Args:
- **indices** (Union[int, tuple]) - Indices into the left-most dimension of x, determining which rows of x
to subtract v from. It is an int or tuple whose values must be in [0, the first dimension size of x).
Inputs:
- **input_x** (Tensor) - The first input is a tensor whose data type is number.
- **input_v** (Tensor) - The second input is a tensor that has the same dimension sizes as x except for
the first dimension, which must equal the size of indices.
Outputs:
Tensor, has the same shape and dtype as input_x.
Examples:
>>> indices = (0, 1)
>>> input_x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
>>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
>>> inplaceSub = P.InplaceSub(indices)
>>> inplaceSub(input_x, input_v)
[[0.5 1.]
[2. 2.5]
[5. 6.]]
"""
@prim_attr_register
def __init__(self, indices):
"""init InplaceSub"""
self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
self.indices = indices
def infer_shape(self, x_shape, v_shape):
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
if isinstance(self.indices, int):
validator.check("size of indices", 1, "v's first dimension", v_shape[0],
Rel.EQ, self.name)
if self.indices < 0 or self.indices >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {self.indices}.')
else:
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
if len(x_shape) > 1:
validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:],
Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_value_type('indices', self.indices, [tuple, int], self.name)
return x_dtype
class Sub(_MathBinaryOp):
"""
Subtracts the second input tensor from the first input tensor element-wise.
......
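Putting the new primitives together, the docstring examples above correspond to the following end-to-end usage sketch (it assumes a MindSpore build whose backend provides these kernels, e.g. Ascend):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

inplace_add = P.InplaceAdd(indices=(0, 1))
input_x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
print(inplace_add(input_x, input_v))
# [[1.5 3. ]
#  [4.  5.5]
#  [5.  6. ]]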
......@@ -368,6 +368,26 @@ class ApplyRMSNet(nn.Cell):
return out
class InplaceAddNet(nn.Cell):
def __init__(self):
super(InplaceAddNet, self).__init__()
self.inplace_add = P.InplaceAdd(indices=(0, 1))
def construct(self, x, v):
out = self.inplace_add(x, v)
return out
class InplaceSubNet(nn.Cell):
def __init__(self):
super(InplaceSubNet, self).__init__()
self.inplace_sub = P.InplaceSub(indices=(0, 1))
def construct(self, x, v):
out = self.inplace_sub(x, v)
return out
test_case_math_ops = [
('BitwiseAnd', {
'block': P.BitwiseAnd(),
......@@ -493,6 +513,16 @@ test_case_math_ops = [
'desc_inputs': [[2, 512, 56, 56]],
'desc_bprop': [[2, 512, 56, 56]],
'skip': ['backward']}),
('InplaceAdd', {
'block': InplaceAddNet(),
'desc_inputs': [Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32)),
Tensor(np.array([[0.5, 1], [1, 1.5]]).astype(np.float32))],
'skip': ['backward']}),
('InplaceSub', {
'block': InplaceSubNet(),
'desc_inputs': [Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32)),
Tensor(np.array([[0.5, 1], [1, 1.5]]).astype(np.float32))],
'skip': ['backward']}),
('ACos', {
'block': P.ACos(),
'desc_inputs': [Tensor(np.array([2., 3.]).astype(np.float32))],
......
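The new test entries follow the harness's existing pattern: forward-only checks ('skip': ['backward']) with concrete input tensors. The same nets can also be exercised directly; a short sketch under the same backend assumption as above:

import numpy as np
from mindspore import Tensor

net = InplaceAddNet()
x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32))
v = Tensor(np.array([[0.5, 1], [1, 1.5]]).astype(np.float32))
out = net(x, v)  # equivalent to P.InplaceAdd(indices=(0, 1))(x, v)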