diff --git a/.gitignore b/.gitignore index b5d31931012c42dd8a5e1b0e7a4514147aa1f9ce..77ff222a1aedb1d9b98090869c367a847730c63f 100644 --- a/.gitignore +++ b/.gitignore @@ -65,6 +65,7 @@ test_temp_summary_event_file/ *.ckpt *.shp *.pkl +*.pb .clangd mindspore/version.py mindspore/default_config.py diff --git a/mindspore/ccsrc/kernel/common_utils.cc b/mindspore/ccsrc/kernel/common_utils.cc index 54980c2cb713b5a3bafaa2b7fbd0d3fc4316777f..2769e0c42aa5da6efd866daba650a2ac281e6306 100644 --- a/mindspore/ccsrc/kernel/common_utils.cc +++ b/mindspore/ccsrc/kernel/common_utils.cc @@ -253,7 +253,7 @@ std::string Dtype2String(const std::string &dtypes) { std::string TypeId2String(TypeId type_id) { auto iter = type_id_str_map.find(type_id); if (iter == type_id_str_map.end()) { - MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id); + return std::string(TypeIdLabel(type_id)); } return iter->second; } diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index 305f07584f69ca6853becbf3132c28b0d33cc22d..9591fef10da17a4f00bdef8e612b9ea03d3c3047 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -47,16 +47,6 @@ const std::vector &GetSignature(const ValuePtr &function) { return empty; } -const std::string GetOpName(const ValuePtr &function) { - std::string name = ""; - if (function->isa()) { - name = function->cast()->name(); - } else if (function->isa()) { - name = function->cast()->name(); - } - return name; -} - void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, const std::vector &signature, bool has_var, std::vector *const op_inputs) { std::size_t sig_size = signature.size(); @@ -93,7 +83,8 @@ void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, *max_type_number = type_number; } -TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indexs) { +TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indexs, + const std::set &write_indexs) { TypeId max_type_id = kTypeUnknown; TypeId max_type = kTypeUnknown; size_t max_type_number = 0; @@ -103,7 +94,12 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve TypeId arg_type = kTypeUnknown; AbstractBasePtr arg_value = args_spec_list[index]; if (arg_value->isa()) { - arg_value = arg_value->cast()->ref(); + auto is_write = (write_indexs.find(index) != write_indexs.end()); + if (is_write) { + arg_value = arg_value->cast()->ref_origin(); + } else { + arg_value = arg_value->cast()->ref(); + } } if (arg_value->isa()) { auto tensor = arg_value->cast(); @@ -157,7 +153,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve // Get the largest type of index in the same SignatureEnumDType of arguments. std::map GetMaxDtype(const std::vector &dtypes, - const abstract::AbstractBasePtrList &args_spec_list) { + const abstract::AbstractBasePtrList &args_spec_list, + const std::set &write_indexs) { // record index for signature.dtypes of the same type // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} std::map> type_indexs; @@ -192,7 +189,7 @@ std::map GetMaxDtype(const std::vector &signature, const abstract::AbstractBasePtrList &args_spec_list, - const FuncGraphPtr &graph, std::vector *const op_inputs, - const std::set &write_indexs) { +void DoAutoCast(const std::string &func_name, const std::vector &signature, + const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, + std::vector *const op_inputs, const std::set &write_indexs) { std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), [](const Signature &sig) { return sig.dtype; }); @@ -216,16 +213,23 @@ void DoAutoCast(const std::vector &signature, const abstract::Abstrac return; } // Stat the index of the arguments with the largest type in the same SignatureEnumDType. - std::map dst_type = GetMaxDtype(dtypes, args_spec_list); + std::map dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs); // Identify which arg requires auto cast for (size_t i = 0; i < args_spec_list.size(); ++i) { auto it = dst_type.find(dtypes[i]); if (it == dst_type.end() || it->second == kTypeUnknown) { continue; } + auto rw_it = write_indexs.find(i); + auto is_write = (rw_it != write_indexs.end()); + AbstractBasePtr arg_value = args_spec_list[i]; if (arg_value->isa()) { - arg_value = arg_value->cast()->ref(); + if (is_write) { + arg_value = arg_value->cast()->ref_origin(); + } else { + arg_value = arg_value->cast()->ref(); + } } TypeId arg_type_id = kTypeUnknown; if (arg_value->isa()) { @@ -243,10 +247,9 @@ void DoAutoCast(const std::vector &signature, const abstract::Abstrac if (it_map == type_map.end()) { continue; } - auto rw_it = write_indexs.find(i); - if (rw_it != write_indexs.end()) { + if (is_write) { if (arg_type_id != it->second) { - MS_LOG(EXCEPTION) << "In op '" << GetOpName(graph) << "', argument '" << args_spec_list[i] + MS_LOG(EXCEPTION) << "In op '" << func_name << "', argument '" << args_spec_list[i] << "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '" << TypeIdLabel(it->second) << "' automatically."; } @@ -299,8 +302,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func if (sig == SignatureEnumRW::kRWRead) { param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); } else if (sig == SignatureEnumRW::kRWWrite) { + param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); write_indexs.insert(i); - param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param}); } // If sig is SignatureEnumRW::kRWRef, not do anything. } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { @@ -310,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func } // process default ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); - DoAutoCast(signature, args_spec_list, func_graph, &op_inputs, write_indexs); + DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs); return func_graph->NewCNode(op_inputs); } } // namespace diff --git a/mindspore/ccsrc/operator/prim_others.cc b/mindspore/ccsrc/operator/prim_others.cc index 84144380f8fa101e21e156b5aa795de5cd3ccada..b8e89378e6a387f329a33e31a834e7559159846f 100644 --- a/mindspore/ccsrc/operator/prim_others.cc +++ b/mindspore/ccsrc/operator/prim_others.cc @@ -160,7 +160,7 @@ AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const Primitive const AbstractBasePtrList &args_spec_list) { // arguments: value if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size() + MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size() << "."; } TypePtr type = args_spec_list[0]->GetTypeTrack(); diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 0195562f16dc5c19f844d2d9236a88bf48563fd3..8a5c6342260bdb5fad8c54d89f45f0ca528eecfb 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -81,8 +81,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // Ref eliminate make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); - get_make_ref_eliminate_ = - MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue}); + get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", + {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); replace_refkey_by_param_ = MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode, opt::FORCE_RENORM); diff --git a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h index 01bdd0906e27140f18b582d1cd8de6a2cbfd4d92..201992ef13d460b206e836327beac1177203fb49 100644 --- a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h @@ -48,6 +48,7 @@ class MakeRefEliminater : public AnfVisitor { // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y +// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z class GetMakeRefEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -71,6 +72,10 @@ class GetMakeRefEliminater : public AnfVisitor { return ref->input(2); } + if (cnode->IsApply(prim::kPrimGetRefOrigin)) { + return ref->input(3); + } + return nullptr; } }; diff --git a/mindspore/ccsrc/pipeline/parse/function_block.cc b/mindspore/ccsrc/pipeline/parse/function_block.cc index 24e7ae74fb7d034bc07b5a4d732aab4089401563..66534390a0ec47a2fce1662780b6513df5d26823 100644 --- a/mindspore/ccsrc/pipeline/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/parse/function_block.cc @@ -315,7 +315,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); - ValueNodePtr get_refkey_op = NewValueNode(prim::kPrimGetRefKey); + ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin); ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); const std::string primitive_name("assign"); const std::string module_name("mindspore.ops.functional"); @@ -329,8 +329,8 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { vec_states.emplace_back(make_tuple_op); for (auto &item : state_assign_) { auto source = ReadVariable(item.second); - auto refkey = func_graph()->NewCNode({get_refkey_op, item.first}); - auto assign = func_graph()->NewCNode({assign_op, refkey, source}); + auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first}); + auto assign = func_graph()->NewCNode({assign_op, origin, source}); MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; vec_states.emplace_back(assign); } diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc index e7f6579b95b47f2f3fb5de399fd5898a485703c9..d4f0c6f8d4dbd80937d6483bfeb0e0965f39fdf9 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc @@ -801,8 +801,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const { std::string AbstractRef::ToString() const { std::ostringstream buffer; buffer << type_name() << "(" - << "key: " << ref_key_->ToString() << "ref_value: " << ref_->ToString() - << "origin_value: " << ref_origin_->ToString(); + << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString() + << " origin_value: " << ref_origin_->ToString(); auto value = GetValueTrack(); if (value) { buffer << ", value: " << value->ToString(); diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 8d2efe46a80cfcdb3c0c30e2327cd1acaea7b6ef..bcd02884240235f70f20c63b7c7e84514e7d7ed7 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -783,7 +783,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract(); AbstractRefPtr ref_abs = abs->cast(); if (ref_abs == nullptr) { - MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref."; + MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); return nullptr; } auto key_abs = ref_abs->ref_key(); diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index 2c1ff60298ff957c52219d012f7f149518641393..02a27591d42ae74746b031fc72b503c4337eaed8 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -170,7 +170,7 @@ def get_py_obj_dtype(obj): Type of MindSpore type. """ # Tensor - if hasattr(obj, 'dtype'): + if hasattr(obj, 'dtype') and callable(obj.dtype) and isinstance(obj.dtype(), typing.Type): return tensor_type(obj.dtype()) if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): return function diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index e3cc057b494708ff6456ed3bcd557f7ab3a1ef2e..a9c856b7c54097ef57fdbfd3ca2a426c3abe2569 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -31,7 +31,9 @@ from ...common.tensor import Tensor from ..operations.math_ops import _infer_shape_reduce from .._utils import get_concat_offset from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register - +from ..._c_expression import signature_rw as sig_rw +from ..._c_expression import signature_kind as sig_kind +from ..._c_expression import signature_dtype as sig_dtype def _check_infer_attr_reduce(axis, keep_dims, prim_name): validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) @@ -2156,13 +2158,17 @@ class ScatterUpdate(PrimitiveWithInfer): >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) - >>> op = P.ScatterNdUpdate() + >>> op = P.ScatterUpdate() >>> output = op(input_x, indices, update) """ - + __mindspore_signature__ = ( + ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), + ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + ) @prim_attr_register def __init__(self, use_locking=True): - """Init ScatterNdUpdate""" + """Init ScatterUpdate""" self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) def infer_shape(self, x_shape, indices_shape, value_shape): @@ -2201,7 +2207,11 @@ class ScatterNdUpdate(PrimitiveWithInfer): >>> op = P.ScatterNdUpdate() >>> output = op(input_x, indices, update) """ - + __mindspore_signature__ = ( + ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), + ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + ) @prim_attr_register def __init__(self, use_locking=True): """Init ScatterNdUpdate""" diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 8c9b10b2466d81e016b31dc692fe0965bd715e54..c3820f27bff9de1d5a1bff4a8d5568b33e40c583 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -179,7 +179,7 @@ class AssignAdd(PrimitiveWithInfer): return value def infer_dtype(self, variable, value): - args = {"value": value} + args = {"variable": variable, "value": value} validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) return value @@ -222,7 +222,7 @@ class AssignSub(PrimitiveWithInfer): return value def infer_dtype(self, variable, value): - args = {"value": value} + args = {"variable": variable, "value": value} validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) return value diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 95f9df440c589c8ffb1f68a4a60c87f5c322696e..d73f53eb6a3cc8fffc986fa92b8b59d912f69d3e 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -58,6 +58,8 @@ class Assign(PrimitiveWithInfer): return variable def infer_dtype(self, variable, value): + args = {"variable": variable, "value": value} + validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) return variable diff --git a/tests/ut/python/ops/test_layer_switch.py b/tests/ut/python/ops/test_layer_switch.py index 35636637a4c3689d3640be40b3e9d678ff966b52..82aa6db39f568dfe1a05873f14b6c8c4c06d5100 100644 --- a/tests/ut/python/ops/test_layer_switch.py +++ b/tests/ut/python/ops/test_layer_switch.py @@ -1,3 +1,18 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test layer switch""" import numpy as np import mindspore diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index 15ff49e2c0e0533523411237edd3490a13f58c1c..2c97b49e1520f4f31088e6194802f2f90620a1ac 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -345,19 +345,6 @@ class Conv2dNativeNet(nn.Cell): return self.flatten(self.conv(input_x, self.weight)) -class MakeRefKeyNet(nn.Cell): - """ MakeRefKeyNet definition """ - - def __init__(self): - super(MakeRefKeyNet, self).__init__() - self.y = Parameter(Tensor([1.0], mindspore.float32), name="y") - - def construct(self, x): - key = P.MakeRefKey("y")() - P.Assign()(key, x) - return x - - class StateNet(nn.Cell): """ StateTestTensor definition """ @@ -538,10 +525,6 @@ test_cases = [ 'block': Grad(NetWithLossClass(Conv2dNativeNet())), 'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))], }), - ('MakeRefKey', { - 'block': MakeRefKeyNet(), - 'desc_inputs': [Tensor([2.0], mindspore.float32)], - }), ('StateTest', { 'block': StateNet(), 'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))], diff --git a/tests/ut/python/ops/test_signature.py b/tests/ut/python/ops/test_signature.py new file mode 100644 index 0000000000000000000000000000000000000000..e6447be8f37b1617bd87a460c488c21308160a74 --- /dev/null +++ b/tests/ut/python/ops/test_signature.py @@ -0,0 +1,75 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test assign sub +""" +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations as P +from mindspore import Tensor +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +import mindspore as ms + +class AssignW(nn.Cell): + def __init__(self): + super(AssignW, self).__init__() + self.assign = P.Assign() + + def construct(self, x, w): + self.assign(x, w) + return x + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.b = Parameter(initializer('ones', [5]), name='b') + self.assign = AssignW() + + def construct(self, value): + return self.assign(self.b, value) + + +def test_assign_through_cell(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + net = Net() + net.to_float(ms.float16) + net.add_flags_recursive(fp16=False) + input_data = Tensor(np.ones([5]).astype(np.float32)) + net(input_data) + with pytest.raises(TypeError): + net(None) + + +class NetScatterNdUpdate(nn.Cell): + def __init__(self): + super(NetScatterNdUpdate, self).__init__() + self.b = Parameter(initializer('ones', [5, 5]), name='b') + self.scatter = P.ScatterNdUpdate() + + def construct(self, idx, x): + return self.scatter(self.b, idx, x) + + +def test_scatter_nd_update(): + context.set_context(mode=context.GRAPH_MODE) + net = NetScatterNdUpdate() + x = Tensor(np.ones([5]).astype(np.float16)) + idx = Tensor(np.ones([1]).astype(np.int32)) + net(idx, x)