diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index b9fc118f405efda771305938f4df79879f8543c8..37fb6e9db7b2eb2e0199be01193e5e5f62e007ff 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -27,6 +27,7 @@ __all__ = [ "broadcast_to", "concat", "cond_take", + "cumsum", "expand_dims", "eye", "flatten", @@ -1328,3 +1329,35 @@ def roll( if shp_bak is not None: out = out.reshape(shp_bak) return out + + +def cumsum(inp: Tensor, axis: int): + """ + Computes the cumulative sum of elements along given axis. + + :param inp: input tensor. + :param axis: axis along which cumsum is performed. + + Examples: + + .. testcode:: + + from megengine import tensor + import megengine.functional as F + + x = tensor([[1, 2, 3], [4, 5, 6]], "int32") + y = F.cumsum(x, 1) + print(y.numpy()) + + Outputs: + + .. testoutput:: + + [[ 1 3 6] + [ 4 9 15]] + + """ + assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor" + assert axis >= 0 and axis < inp.ndim, "input axis {} out of bound".format(axis) + op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False) + return apply(op, inp)[0] diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index abe07a00a8d037066594280551e11e698a1d5dd2..edb76da194df8ba2ceb585cfeee8f833358b8391 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -673,4 +673,16 @@ OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose) .fallback(); }} // sliding_window_transpose +namespace { +namespace cumsum { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + OperatorNodeConfig config{op.make_name()}; + return opr::Cumsum::make(inputs[0], op.param(), config); +} + +OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace cumsum +} // namespace + } // namespace mgb::imperative diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 7bb892e2eabe8fe0ac745b298555a061528a97d4..14aa42f835814f65a6c3d407efa4c79ace25f4b5 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -377,4 +377,6 @@ def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>; def FastpathCopy: MgbHashableOp<"FastpathCopy">; +def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>; + #endif // MGB_OPS