Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
dd1fecdf
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
dd1fecdf
编写于
7月 30, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/opr): add cumsum
GitOrigin-RevId: 740f00a8e5c66253934c676c2fcdf02fe8e2d313
上级
a0c7e047
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
47 addition
and
0 deletion
+47
-0
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+33
-0
imperative/src/impl/ops/specializations.cpp
imperative/src/impl/ops/specializations.cpp
+12
-0
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+2
-0
未找到文件。
imperative/python/megengine/functional/tensor.py
浏览文件 @
dd1fecdf
...
...
@@ -27,6 +27,7 @@ __all__ = [
"broadcast_to"
,
"concat"
,
"cond_take"
,
"cumsum"
,
"expand_dims"
,
"eye"
,
"flatten"
,
...
...
@@ -1328,3 +1329,35 @@ def roll(
if
shp_bak
is
not
None
:
out
=
out
.
reshape
(
shp_bak
)
return
out
def
cumsum
(
inp
:
Tensor
,
axis
:
int
):
"""
Computes the cumulative sum of elements along given axis.
:param inp: input tensor.
:param axis: axis along which cumsum is performed.
Examples:
.. testcode::
from megengine import tensor
import megengine.functional as F
x = tensor([[1, 2, 3], [4, 5, 6]], "int32")
y = F.cumsum(x, 1)
print(y.numpy())
Outputs:
.. testoutput::
[[ 1 3 6]
[ 4 9 15]]
"""
assert
isinstance
(
inp
,
Tensor
),
"input of cumsum must be type of Tensor"
assert
axis
>=
0
and
axis
<
inp
.
ndim
,
"input axis {} out of bound"
.
format
(
axis
)
op
=
builtin
.
Cumsum
(
axis
=
axis
,
exclusive
=
False
,
reverse
=
False
)
return
apply
(
op
,
inp
)[
0
]
imperative/src/impl/ops/specializations.cpp
浏览文件 @
dd1fecdf
...
...
@@ -673,4 +673,16 @@ OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
.
fallback
();
}}
// sliding_window_transpose
namespace
{
namespace
cumsum
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
Cumsum
&>
(
def
);
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
Cumsum
::
make
(
inputs
[
0
],
op
.
param
(),
config
);
}
OP_TRAIT_REG
(
Cumsum
,
Cumsum
).
apply_on_var_node
(
apply_on_var_node
).
fallback
();
}
// namespace cumsum
}
// namespace
}
// namespace mgb::imperative
src/core/include/megbrain/ir/ops.td
浏览文件 @
dd1fecdf
...
...
@@ -377,4 +377,6 @@ def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>;
def FastpathCopy: MgbHashableOp<"FastpathCopy">;
def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>;
#endif // MGB_OPS
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录