提交 b285c8fa 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1973 develop TensorScatterUpdate op and access ge and vm

Merge pull request !1973 from zhangbuxue/develop_TensorScatterUpdate_op_and_access_ge_and_vm
......@@ -103,6 +103,7 @@ const char kNameReLU6[] = "ReLU6";
const char kNameReLU6Grad[] = "ReLU6Grad";
const char kNameElu[] = "Elu";
const char kNameEluGrad[] = "EluGrad";
const char kNameTensorScatterUpdate[] = "TensorScatterUpdate";
const char kNameScatterUpdate[] = "ScatterUpdate";
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
const char kNameScatterMax[] = "ScatterMax";
......@@ -261,6 +262,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
{string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)},
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
......
......@@ -525,6 +525,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
// TensorScatterUpdate
INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}};
// ScatterUpdate
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
......
......@@ -134,6 +134,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
DECLARE_OP_USE_OUTPUT(ZerosLike)
DECLARE_OP_ADAPTER(OnesLike)
DECLARE_OP_USE_OUTPUT(OnesLike)
DECLARE_OP_ADAPTER(TensorScatterUpdate)
DECLARE_OP_USE_OUTPUT(TensorScatterUpdate)
DECLARE_OP_ADAPTER(ScatterUpdate)
DECLARE_OP_USE_OUTPUT(ScatterUpdate)
DECLARE_OP_ADAPTER(ScatterNdUpdate)
......
......@@ -456,6 +456,20 @@ def get_bprop_scatter_nd_update(self):
return bprop
@bprop_getters.register(P.TensorScatterUpdate)
def get_bprop_tensor_scatter_update(self):
"""Generate bprop for TensorScatterUpdate"""
gather_nd = P.GatherNd()
tensor_scatter_update = P.TensorScatterUpdate()
def bprop(x, indices, update, out, dout):
x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
update_grad = gather_nd(dout, indices)
return x_grad, zeros_like(indices), update_grad
return bprop
@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
"""Generate bprop for Argmax"""
......
......@@ -255,3 +255,4 @@ from .lamb_next_right import _lamb_next_right_tbe
from .sparse_gather_v2 import _sparse_gather_v2_tbe
from .data_format_dim_map import _data_format_dim_map_tbe
from .histogram_fixed_width import _histogram_fixed_width_tbe
from .tensor_scatter_update import _tensor_scatter_update_tbe
......@@ -31,7 +31,7 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
......
......@@ -31,7 +31,7 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorScatterUpdate op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
tensor_scatter_update_op_info = TBERegOp("TensorScatterUpdate") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("tensor_scatter_update.so") \
.compute_cost(10) \
.kernel_name("tensor_scatter_update") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(1, "updates", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(tensor_scatter_update_op_info)
def _tensor_scatter_update_tbe():
"""TensorScatterUpdate TBE register"""
return
......@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, EmbeddingLookup,
Squeeze, StridedSlice, Tile,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo)
......@@ -212,6 +212,7 @@ __all__ = [
'Pad',
'MirrorPad',
'GatherNd',
'TensorScatterUpdate',
'ScatterUpdate',
'ScatterNdUpdate',
'Floor',
......
......@@ -2187,6 +2187,47 @@ class GatherNd(PrimitiveWithInfer):
return x_dtype
class TensorScatterUpdate(PrimitiveWithInfer):
"""
Update tensor value by using input indices and value.
Using given values to update tensor value, along with the input indices.
Inputs:
- **input_x** (Tensor) - The target tensor.
- **indices** (Tensor) - The index of input tensor whose data type is int32.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].
Outputs:
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.TensorScatterUpdate()
>>> output = op(input_x, indices, update)
"""
@prim_attr_register
def __init__(self):
"""Init TensorScatterUpdate"""
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
class ScatterUpdate(PrimitiveWithInfer):
"""
Update tensor value by using input indices and value.
......@@ -2227,7 +2268,7 @@ class ScatterUpdate(PrimitiveWithInfer):
def infer_shape(self, x_shape, indices_shape, value_shape):
if indices_shape + x_shape[1:] != value_shape:
raise ValueError('Input value are not match with input indices.')
raise ValueError("For 'ScatterUpdate', input value are not match with input indices.")
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
......@@ -2277,7 +2318,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
raise ValueError('Input value are not match with input indices.')
raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.")
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
......
......@@ -34,6 +34,25 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
def test_tensor_scatter_update():
class TensorScatterUpdateNet(nn.Cell):
"""TensorScatterUpdate net definition"""
def __init__(self):
super(TensorScatterUpdateNet, self).__init__()
self.tensor_scatter_update = P.TensorScatterUpdate()
def construct(self, x, i, u):
out = self.tensor_scatter_update(x, i, u)
return out
net = TensorScatterUpdateNet()
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
updates = Tensor(np.ones([2, 5], np.float32))
net(x, indices, updates)
class InputBackward(nn.Cell):
def __init__(self, network):
super(InputBackward, self).__init__()
......@@ -1537,6 +1556,12 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
Tensor(np.ones((2,), np.int32))),
'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
('TensorScatterUpdate', {
'block': P.TensorScatterUpdate(),
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.ones([2, 5], np.float32) * 99)),
'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
('ScatterMax', {
'block': ScatterMax(),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册