提交 dd1fecdf 编写于 作者: M Megvii Engine Team

feat(mge/opr): add cumsum

GitOrigin-RevId: 740f00a8e5c66253934c676c2fcdf02fe8e2d313
上级 a0c7e047
......@@ -27,6 +27,7 @@ __all__ = [
"broadcast_to",
"concat",
"cond_take",
"cumsum",
"expand_dims",
"eye",
"flatten",
......@@ -1328,3 +1329,35 @@ def roll(
if shp_bak is not None:
out = out.reshape(shp_bak)
return out
def cumsum(inp: Tensor, axis: int):
    r"""
    Computes the cumulative sum of elements along the given axis.

    :param inp: input tensor.
    :param axis: axis along which cumsum is performed. A negative value
        counts from the last dimension (numpy convention), e.g. ``-1``
        means the last axis.
    :return: a tensor of the same shape as ``inp`` holding the running sums.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([[1, 2, 3], [4, 5, 6]], "int32")
        y = F.cumsum(x, 1)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [[ 1  3  6]
         [ 4  9 15]]
    """
    assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
    # Normalize a negative axis to its non-negative equivalent before the
    # bounds check, so the builtin op always receives a valid index.
    if axis < 0:
        axis += inp.ndim
    assert 0 <= axis < inp.ndim, "input axis {} out of bound".format(axis)
    # exclusive=False: the first output element equals the first input element;
    # reverse=False: accumulation runs from index 0 upward along `axis`.
    op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
    return apply(op, inp)[0]
......@@ -673,4 +673,16 @@ OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
.fallback();
}} // sliding_window_transpose
namespace {
namespace cumsum {
// Bridges the imperative-mode Cumsum OpDef onto the static computing graph:
// downcasts the generic OpDef, then builds a graph-level Cumsum operator
// node over the single input variable with the op's parameter pack.
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Cumsum&>(def);
// Carry the op's name into the node config so the graph node keeps a
// human-readable identity (useful for debugging/profiling dumps).
OperatorNodeConfig config{op.make_name()};
return opr::Cumsum::make(inputs[0], op.param(), config);
}
// Register the trait for Cumsum; .fallback() supplies default
// implementations for any trait hooks not explicitly provided here.
OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback();
} // namespace cumsum
} // namespace
} // namespace mgb::imperative
......@@ -377,4 +377,6 @@ def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>;
def FastpathCopy: MgbHashableOp<"FastpathCopy">;
// Hashable op record for Cumsum, carrying CumsumParam
// (axis / exclusive / reverse — see the Python-side builtin.Cumsum call).
def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>;
#endif // MGB_OPS
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册