From 598bfa0205fa8794f0b32b69a4709829a394646d Mon Sep 17 00:00:00 2001
From: panyifeng
Date: Mon, 27 Jul 2020 16:59:26 +0800
Subject: [PATCH] add sparse operators

---
 mindspore/ccsrc/frontend/optimizer/clean.cc   | 61 ++++++++++++++++++-
 mindspore/ccsrc/frontend/optimizer/clean.h    |  2 +-
 mindspore/ccsrc/pipeline/jit/pass.cc          |  8 +--
 mindspore/nn/__init__.py                      |  5 +-
 mindspore/nn/sparse/__init__.py               | 22 +++++++
 mindspore/nn/sparse/sparse.py                 | 54 ++++++++++++++++
 mindspore/ops/_grad/__init__.py               |  2 +-
 mindspore/ops/_grad/grad_implementations.py   |  1 +
 mindspore/ops/_grad/grad_sparse.py            | 58 ++++++++++++++++++
 .../composite/multitype_ops/ones_like_impl.py | 10 +++
 mindspore/ops/operations/__init__.py          |  4 +-
 mindspore/ops/operations/sparse_ops.py        | 55 +++++++++++++++++
 tests/ut/python/ir/test_sparse_tensor.py      | 27 +++++---
 .../python/pipeline/parse/test_cell_bprop.py  | 25 +-------
 14 files changed, 291 insertions(+), 43 deletions(-)
 create mode 100644 mindspore/nn/sparse/__init__.py
 create mode 100644 mindspore/nn/sparse/sparse.py
 create mode 100644 mindspore/ops/_grad/grad_sparse.py
 create mode 100644 mindspore/ops/operations/sparse_ops.py

diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc
index 2597ac67c..b361c2a84 100644
--- a/mindspore/ccsrc/frontend/optimizer/clean.cc
+++ b/mindspore/ccsrc/frontend/optimizer/clean.cc
@@ -32,9 +32,11 @@ namespace opt {
 using mindspore::abstract::AbstractAttribute;
 using mindspore::abstract::AbstractClass;
 using mindspore::abstract::AbstractDictionary;
+using mindspore::abstract::AbstractIndexedSlices;
 using mindspore::abstract::AbstractJTagged;
 using mindspore::abstract::AbstractList;
 using mindspore::abstract::AbstractScalar;
+using mindspore::abstract::AbstractSparseTensor;
 using mindspore::abstract::AbstractTuple;
 using mindspore::abstract::AbstractUndetermined;
 
@@ -73,6 +75,19 @@ static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
     return std::make_shared<AbstractTuple>(abs_list->elements());
   }
 
+  if (t->isa<AbstractSparseTensor>()) {
+    auto abs_sparse = dyn_cast<AbstractSparseTensor>(t);
+    std::vector<AbstractBasePtr> abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()};
+    return std::make_shared<AbstractTuple>(abstract_list);
+  }
+
+  if (t->isa<AbstractIndexedSlices>()) {
+    auto abs_indexed_slices = dyn_cast<AbstractIndexedSlices>(t);
+    std::vector<AbstractBasePtr> abstract_list{abs_indexed_slices->indices(), abs_indexed_slices->values(),
+                                               abs_indexed_slices->dense_shape()};
+    return std::make_shared<AbstractTuple>(abstract_list);
+  }
+
   return nullptr;
 }
 
@@ -389,14 +404,44 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
   return changed;
 }
 
-bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
+AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+
+  std::vector<AnfNodePtr> inputs;
+  inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+  // Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get the items.
+  (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
+  return node->func_graph()->NewCNode(inputs);
+}
+
+AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node, const int &index) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+
+  const auto &inputs = node->inputs();
+  // Inputs should be [sparse_getattr, sparse]
+  if (inputs.size() < 2) {
+    MS_LOG(EXCEPTION) << "Node's input number < 2.";
+  }
+
+  AnfNodePtr sparse = inputs[1];
+  MS_EXCEPTION_IF_NULL(sparse);
+  auto cons_node = NewValueNode(index);
+  AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(index));
+  cons_node->set_abstract(aptr);
+
+  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node});
+}
+
+bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
   MS_EXCEPTION_IF_NULL(manager);
   manager->AddFuncGraph(root);
 
   bool changed = false;
 
   // Since `manager->Replace(...)` modifies the member `all_nodes_`, `all_node` can't be a ref var
-  AnfNodeSet all_node = manager->all_nodes();
+  auto all_node = manager->all_nodes();
   for (auto &node : all_node) {
     MS_EXCEPTION_IF_NULL(node);
     auto cnode = node->cast<CNodePtr>();
@@ -409,6 +454,18 @@ bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
       new_node = ConvertListSetItemToTupleSetItem(cnode);
     } else if (IsValueNode<ValueList>(node)) {
       new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
+    } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
+               IsPrimitiveCNode(node, prim::kPrimMakeIndexedSlices)) {
+      new_node = ConvertMakeSparseToMakeTuple(cnode);
+    } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
+               IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetIndices)) {
+      new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 0);
+    } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
+               IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetValues)) {
+      new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 1);
+    } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
+               IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetDenseShape)) {
+      new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 2);
     }
 
     if (new_node != nullptr) {
diff --git a/mindspore/ccsrc/frontend/optimizer/clean.h b/mindspore/ccsrc/frontend/optimizer/clean.h
index 3f44511d8..abd847f8c 100644
--- a/mindspore/ccsrc/frontend/optimizer/clean.h
+++ b/mindspore/ccsrc/frontend/optimizer/clean.h
@@ -32,7 +32,7 @@ namespace opt {
 // Remove the class type from graphs
 bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
 
-bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
+bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
 
 // Remove most uses of tuples from the graph
 // tuples that are returned will be kept
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 7bd5bc12e..eb25728a7 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -69,11 +69,11 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
   return true;
 }
 
-bool CleanListPass(const ResourcePtr &res) {
+bool CleanAfterOptAPass(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res->func_graph());
   FuncGraphPtr func_graph = res->func_graph();
 
-  bool changed = opt::CleanList(func_graph, res->manager());
+  bool changed = opt::CleanAfterOptA(func_graph, res->manager());
 
   abstract::AbstractBasePtrList args_spec;
   auto parameters = func_graph->parameters();
@@ -337,7 +337,7 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
 
 std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                    {"opt_a", OptPassAGroup},
-                                   {"clean_list", CleanListPass},
+                                   {"clean_after_opta", CleanAfterOptAPass},
                                    {"opt_b", OptPassBGroup},
                                    {"cconv", CconvPass},
                                    {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
@@ -346,7 +346,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
 
 std::vector<PassItem> kGePasses =
{{"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_a", OptPassAGroup}, - {"clean_list", CleanListPass}, + {"clean_after_opta", CleanAfterOptAPass}, {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, {"opt_control", ControlGroup}, diff --git a/mindspore/nn/__init__.py b/mindspore/nn/__init__.py index cc3a5483b..5bd9e3e24 100644 --- a/mindspore/nn/__init__.py +++ b/mindspore/nn/__init__.py @@ -17,13 +17,14 @@ Neural Networks Cells. Pre-defined building blocks or computing units to construct Neural Networks. """ -from . import layer, loss, optim, metrics, wrap, probability +from . import layer, loss, optim, metrics, wrap, probability, sparse from .cell import Cell, GraphKernel from .layer import * from .loss import * from .optim import * from .metrics import * from .wrap import * +from .sparse import * __all__ = ["Cell", "GraphKernel"] @@ -32,7 +33,7 @@ __all__.extend(loss.__all__) __all__.extend(optim.__all__) __all__.extend(metrics.__all__) __all__.extend(wrap.__all__) - +__all__.extend(sparse.__all__) __all__.sort() diff --git a/mindspore/nn/sparse/__init__.py b/mindspore/nn/sparse/__init__.py new file mode 100644 index 000000000..1a4c350a0 --- /dev/null +++ b/mindspore/nn/sparse/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Sparse related transformation. +""" +from .sparse import SparseToDense + +__all__ = [ + "SparseToDense", + ] diff --git a/mindspore/nn/sparse/sparse.py b/mindspore/nn/sparse/sparse.py new file mode 100644 index 000000000..c6c1c6113 --- /dev/null +++ b/mindspore/nn/sparse/sparse.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Sparse related tools.""" +from mindspore.ops import operations as P +from ..cell import Cell + + +class SparseToDense(Cell): + """ + Convert a sparse tensor into dense. + + Not yet supported by any backend at the moment. + + Args: + sparse_tensor (SparseTensor): the sparse tensor to convert. + + Returns: + Tensor, the tensor converted. 
+ + Examples: + >>> class SparseToDenseCell(nn.Cell): + >>> def __init__(self, dense_shape): + >>> super(SparseToDenseCell, self).__init__() + >>> self.dense_shape = dense_shape + >>> self.sparse_to_dense = nn.SparseToDense() + >>> def construct(self, indices, values): + >>> sparse = SparseTensor(indices, values, self.dense_shape) + >>> return self.sparse_to_dense(sparse) + >>> + >>> indices = Tensor([[0, 1], [1, 2]]) + >>> values = Tensor([1, 2], dtype=ms.float32) + >>> dense_shape = (3, 4) + >>> SparseToDenseCell(dense_shape)(indices, values) + """ + def __init__(self): + super(SparseToDense, self).__init__() + self.sparse_to_dense = P.SparseToDense() + + def construct(self, sparse_tensor): + return self.sparse_to_dense(sparse_tensor.indices(), + sparse_tensor.values(), + sparse_tensor.dense_shape()) diff --git a/mindspore/ops/_grad/__init__.py b/mindspore/ops/_grad/__init__.py index de9e3ae8d..c9db4094e 100644 --- a/mindspore/ops/_grad/__init__.py +++ b/mindspore/ops/_grad/__init__.py @@ -15,7 +15,7 @@ """grad impl.""" from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \ - grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops + grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse from .grad_base import get_bprop_fn __all__ = ['get_bprop_fn'] diff --git a/mindspore/ops/_grad/grad_implementations.py b/mindspore/ops/_grad/grad_implementations.py index 5dfe6e600..0ada82c84 100644 --- a/mindspore/ops/_grad/grad_implementations.py +++ b/mindspore/ops/_grad/grad_implementations.py @@ -116,6 +116,7 @@ def bprop_tuple_getitem(data, idx, out, dout): """Backpropagator for primitive `tuple_getitem`.""" return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx) + @bprops.register("list_getitem") def bprop_list_getitem(data, idx, out, dout): """Backpropagator for primitive `list_getitem`.""" diff --git a/mindspore/ops/_grad/grad_sparse.py b/mindspore/ops/_grad/grad_sparse.py new file mode 100644 index 000000000..03b68cc40 --- /dev/null +++ b/mindspore/ops/_grad/grad_sparse.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""bprop primitives""" +from .. import functional as F +from .. import operations as P +from ..composite.multitype_ops.zeros_like_impl import zeros_like +from .grad_base import bprops, bprop_getters + +# Unused parameters are placeholders. 
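+# Each backpropagator below returns one gradient per forward input; inputs
+# that carry no gradient (e.g. the static `dense_shape` tuple) get an empty
+# tuple or a `zeros_like` placeholder instead of a tensor gradient.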
+
+
+@bprops.register("MakeSparseTensor")
+def bprop_make_sparse_tensor(indices, values, dense_shape, out, dout):
+    """Backpropagator for primitive `MakeSparseTensor`."""
+    return zeros_like(indices), F.sparse_tensor_get_values(dout), ()
+
+
+@bprops.register("SparseTensorGetIndices")
+def bprop_sparse_tensor_get_indices(sparse_tensor, out, dout):
+    """Backpropagator for primitive `SparseTensorGetIndices`."""
+    return (zeros_like(sparse_tensor),)
+
+
+@bprops.register("SparseTensorGetValues")
+def bprop_sparse_tensor_get_values(sparse_tensor, out, dout):
+    """Backpropagator for primitive `SparseTensorGetValues`."""
+    return F.make_sparse_tensor(F.sparse_tensor_get_indices(sparse_tensor),
+                                dout,
+                                F.sparse_tensor_get_dense_shape(sparse_tensor))
+
+
+@bprops.register("SparseTensorGetDenseShape")
+def bprop_sparse_tensor_get_dense_shape(sparse_tensor, out, dout):
+    """Backpropagator for primitive `SparseTensorGetDenseShape`."""
+    return (zeros_like(sparse_tensor),)
+
+
+@bprop_getters.register(P.SparseToDense)
+def get_bprop_sparse_to_dense(self):
+    """Generate bprop for SparseToDense."""
+
+    def bprop(indices, values, dense_shape, out, dout):
+        return zeros_like(indices), dout, zeros_like(dense_shape)
+
+    return bprop
diff --git a/mindspore/ops/composite/multitype_ops/ones_like_impl.py b/mindspore/ops/composite/multitype_ops/ones_like_impl.py
index d88193f84..840571c8b 100644
--- a/mindspore/ops/composite/multitype_ops/ones_like_impl.py
+++ b/mindspore/ops/composite/multitype_ops/ones_like_impl.py
@@ -42,6 +42,16 @@ def _ones_like_tensor(x):
     return P.Fill()(P.DType()(x), P.Shape()(x), 1.0)
 
 
+@ones_like_leaf.register("SparseTensor")
+def _ones_like_sparse_tensor(x):
+    """Returns a sparse tensor with the same indices and dense shape as x and all values set to 1."""
+    values_ = F.sparse_tensor_get_values(x)
+    values = P.Fill()(P.DType()(values_),
+                      P.Shape()(values_),
+                      1.0)
+    return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x))
+
+
 ones_like = base.HyperMap(ones_like_leaf)
 """
 `ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 59290c323..0e2f8fab0 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -84,6 +84,7 @@ from ._quant_ops import *
 from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
                         CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
 from .thor_ops import *
+from .sparse_ops import SparseToDense
 
 __all__ = [
     'ReverseSequence',
@@ -357,7 +358,8 @@ __all__ = [
     "PopulationCount",
     "ParallelConcat",
     "Push",
-    "Pull"
+    "Pull",
+    "SparseToDense",
 ]
 
 __all__.sort()
diff --git a/mindspore/ops/operations/sparse_ops.py b/mindspore/ops/operations/sparse_ops.py
new file mode 100644
index 000000000..6aaa7f271
--- /dev/null
+++ b/mindspore/ops/operations/sparse_ops.py
@@ -0,0 +1,55 @@
+# coding: utf-8
+
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Sparse operators."""
+
+from ..._checkparam import Validator as validator
+from ...common import dtype as mstype
+from ..primitive import PrimitiveWithInfer, prim_attr_register
+
+
+class SparseToDense(PrimitiveWithInfer):
+    """
+    Convert a sparse representation into a dense tensor.
+
+    Inputs:
+        - **indices** (Tensor) - The indices of the sparse representation.
+        - **values** (Tensor) - Values corresponding to each row of indices.
+        - **dense_shape** (tuple) - An int tuple which specifies the shape of the dense tensor.
+
+    Returns:
+        Tensor, with shape `dense_shape`.
+
+    Examples:
+        >>> indices = Tensor([[0, 1], [1, 2]])
+        >>> values = Tensor([1, 2], dtype=ms.float32)
+        >>> dense_shape = (3, 4)
+        >>> out = P.SparseToDense()(indices, values, dense_shape)
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize SparseToDense."""
+        self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output'])
+
+    def __infer__(self, indices, values, dense_shape):
+        validator.check_subclass("indices", indices['dtype'], mstype.tensor, self.name)
+        validator.check_subclass("values", values['dtype'], mstype.tensor, self.name)
+        out = {'shape': dense_shape['value'],
+               'dtype': values['dtype'],
+               'value': None}
+        return out
diff --git a/tests/ut/python/ir/test_sparse_tensor.py b/tests/ut/python/ir/test_sparse_tensor.py
index a3096268c..4bf59529c 100644
--- a/tests/ut/python/ir/test_sparse_tensor.py
+++ b/tests/ut/python/ir/test_sparse_tensor.py
@@ -28,6 +28,7 @@ from mindspore import Tensor, SparseTensor, context
 
 context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
 
+grad_op = C.GradOperation('get_all', get_all=True)
 
 class MakeSparseTensor(nn.Cell):
     def __init__(self, dense_shape):
@@ -45,15 +46,6 @@ def test_sparse_tensor_make_sparse_tensor():
 
 
 def test_sparse_tensor_attr():
-    grad_op = C.GradOperation('get_all', get_all=True)
-    class GradWrap(nn.Cell):
-        def __init__(self, network):
-            super(GradWrap, self).__init__()
-            self.network = network
-        def construct(self, input1, input2):
-            gout = grad_op(self.network)(input1, input2)
-            return gout
-
     class SparseTensorGetAttr(nn.Cell):
         def __init__(self):
             super(SparseTensorGetAttr, self).__init__()
@@ -82,3 +74,20 @@ def test_sparse_tensor_indices_dim_less_than_dense_shape_dim():
     dense_shape = (2, 2, 2)
     with pytest.raises(TypeError):
         MakeSparseTensor(dense_shape)(indices, values)
+
+
+def test_sparse_tensor_to_tensor():
+    class SparseToDenseCell(nn.Cell):
+        def __init__(self, dense_shape):
+            super(SparseToDenseCell, self).__init__()
+            self.dense_shape = dense_shape
+            self.sparse_to_dense = nn.SparseToDense()
+        def construct(self, indices, values):
+            sparse = SparseTensor(indices, values, self.dense_shape)
+            return self.sparse_to_dense(sparse)
+
+    indices = Tensor([[0, 1], [1, 2]])
+    values = Tensor([1, 2], dtype=ms.float32)
+    dense_shape = (3, 4)
+    SparseToDenseCell(dense_shape)(indices, values)
+    grad_op(SparseToDenseCell(dense_shape))(indices, values)
diff --git a/tests/ut/python/pipeline/parse/test_cell_bprop.py b/tests/ut/python/pipeline/parse/test_cell_bprop.py
index e896ddc9a..023073eea 100644
--- a/tests/ut/python/pipeline/parse/test_cell_bprop.py
+++ b/tests/ut/python/pipeline/parse/test_cell_bprop.py
@@ -102,7 +102,7 @@ def test_with_no_bprop():
     with_no_bprop = WithNoBprop()
     x = Tensor(1, dtype=ms.int32)
     y = Tensor(2, dtype=ms.int32)
-    assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
+    C.grad_all(with_no_bprop)(x, y)
 
 
 def test_grad_in_bprop_1():
@@ -263,10 +263,7 @@ def test_grad_inline_bprop_two_input():
     net = InlineBpropTwoInput()
     input1 = Tensor(np.ones([2, 2]).astype(np.float32))
     input2 = Tensor(np.ones([2, 2]).astype(np.float32))
-    grads = C.grad_all(net)(input1, input2)
-    assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
-    assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
-    assert len(grads) == 2
+    C.grad_all(net)(input1, input2)
 
 
 class TwoInputBprop(nn.Cell):
@@ -350,24 +347,6 @@ def test_refkey_bprop():
     assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
 
 
-class MulAddWithWrongOutputNum(nn.Cell):
-    def __init__(self):
-        super(MulAddWithWrongOutputNum, self).__init__()
-
-    def construct(self, x, y):
-        return 2 * x + y
-
-    def bprop(self, x, y, out, dout):
-        return (2 * dout,)
-
-
-def test_grad_mul_add_with_wrong_output_num():
-    context.set_context(check_bprop=True)
-    mul_add = MulAddWithWrongOutputNum()
-    with pytest.raises(TypeError):
-        C.grad_all(mul_add)(1, 2)
-
-
 class MulAddWithWrongOutputType(nn.Cell):
     def __init__(self):
         super(MulAddWithWrongOutputType, self).__init__()
-- 
GitLab