Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
54633b92
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
54633b92
编写于
12月 29, 2022
作者:
M
Megvii Engine Team
提交者:
黄信达
1月 06, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): fix the split dump problem
GitOrigin-RevId: 0a0265e59819a89b12853229eff9d14d1e55ace6
上级
6da3de19
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
39 addition
and
3 deletion
+39
-3
imperative/python/src/tensor_utils.cpp
imperative/python/src/tensor_utils.cpp
+17
-3
imperative/python/test/unit/jit/test_tracing.py
imperative/python/test/unit/jit/test_tracing.py
+22
-0
未找到文件。
imperative/python/src/tensor_utils.cpp
浏览文件 @
54633b92
...
...
@@ -400,7 +400,6 @@ py::object get_res_by_refhdl(
ref
=
py
::
reinterpret_borrow
<
py
::
object
>
(
ref_hdl
);
}
if
(
PyObject_TypeCheck
(
ref
.
ptr
(),
py_varnode_type
))
{
auto
temp
=
dtype
.
cast
<
mgb
::
DType
>
();
ComputingGraph
*
graph
=
getattr
(
ref
,
"graph"
).
cast
<
ComputingGraph
*>
();
cg
::
VarNode
*
node
=
getattr
(
ref
,
"var"
).
cast
<
cg
::
VarNode
*>
();
CompNode
cn
;
...
...
@@ -1473,8 +1472,23 @@ py::object _split_cpp(
std
::
to_string
(
axis
)
+
" cannot be split into "
+
std
::
to_string
(
n_sections
)
+
" sections"
);
}
op
=
Split
::
make
(
axis
,
n_sections
);
p
.
resize
(
2
);
if
(
enable_fastpath
(
inp_hdl
))
{
op
=
Split
::
make
(
axis
,
n_sections
);
p
.
resize
(
2
);
}
else
{
size_t
n_total_
=
n_total
.
cast
<
int
>
();
for
(
size_t
idx
=
0
;
idx
<
n_sections
;
++
idx
)
{
auto
section_size
=
(
n_total_
+
n_sections
-
idx
-
1
)
/
n_sections
;
partitions
.
append
(
_Const
(
py
::
int_
(
section_size
),
py
::
cast
((
mgb
::
DType
)
dtype
::
Int32
()),
getattr
(
inp_hdl
,
"device"
)));
}
op
=
Split
::
make
(
axis
,
0
);
p
.
resize
(
partitions
.
size
()
+
2
);
for
(
size_t
idx
=
0
;
idx
<
partitions
.
size
();
++
idx
)
{
p
[
idx
+
2
]
=
partitions
[
idx
].
ptr
();
}
}
}
py
::
object
Op
=
py
::
cast
(
op
);
p
[
0
]
=
Op
.
ptr
();
...
...
imperative/python/test/unit/jit/test_tracing.py
浏览文件 @
54633b92
...
...
@@ -308,6 +308,28 @@ def test_dump_with_testcase():
f
.
dump
(
file
,
input_data
=
[
"#rand(0, 255, 1)"
])
def
test_split_dump
():
class
SimpleNet
(
Module
):
def
__init__
(
self
,
num_segments
:
int
=
3
):
super
().
__init__
()
self
.
num_segments
=
num_segments
def
forward
(
self
,
x
):
x
=
F
.
split
(
x
,
self
.
num_segments
,
axis
=
1
)
return
x
model
=
SimpleNet
()
model
.
eval
()
data
=
tensor
(
np
.
random
.
random
((
1
,
12
,
224
,
224
)))
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
fun
(
data
,
*
,
net
):
return
net
(
data
)
x
=
fun
(
data
,
net
=
model
)
fun
.
dump
(
io
.
BytesIO
(),
arg_names
=
[
"data"
])
@
pytest
.
mark
.
parametrize
(
"trace_mode"
,
[
False
,
True
])
def
test_trace_profiler
(
trace_mode
):
@
trace
(
symbolic
=
trace_mode
,
profiling
=
True
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录