Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d3247bee
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d3247bee
编写于
6月 13, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dtr): always write shape when tensor produced
GitOrigin-RevId: d2b23b5c25bb509456b3d77a56f20ec985efa458
上级
0a266d7a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
26 addition
and
10 deletion
+26
-10
imperative/src/impl/interpreter/interpreter_impl.cpp
imperative/src/impl/interpreter/interpreter_impl.cpp
+5
-10
imperative/src/impl/interpreter/tensor_info.h
imperative/src/impl/interpreter/tensor_info.h
+21
-0
未找到文件。
imperative/src/impl/interpreter/interpreter_impl.cpp
浏览文件 @
d3247bee
...
...
@@ -312,7 +312,7 @@ void ChannelImpl::dispatch_default_cpu(
HostTensorND
::
make_proxy
(
tensornd
).
proxy_to_comp_node
(
output_cn
);
// use `put` for consistency
auto
info
=
reinterpret_cast
<
TensorInfo
*>
(
put_impl
(
host_tensornd
,
false
));
mgb_assert
(
info
->
desc
.
layout
.
ndim
!=
0
);
mgb_assert
(
info
->
shape_valid
()
);
output_infos
.
push_back
(
info
);
outputs
->
push_back
(
reinterpret_cast
<
Handle
>
(
info
));
}
...
...
@@ -406,7 +406,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
MGB_LOCK_GUARD
(
m_spin
);
mgb_assert
(
check_available
(),
"Channel already closed"
);
auto
*
input
=
reinterpret_cast
<
TensorInfo
*>
(
inputs
[
0
]);
if
(
op
->
same_type
<
GetVarShape
>
()
&&
input
->
desc
.
layout
.
ndim
)
{
if
(
op
->
same_type
<
GetVarShape
>
()
&&
input
->
shape_valid
()
)
{
size_t
ndim
=
input
->
desc
.
layout
.
ndim
;
auto
&
gvs
=
op
->
cast_final_safe
<
GetVarShape
>
();
if
(
gvs
.
axis
==
MEGDNN_MAX_NDIM
)
{
...
...
@@ -477,11 +477,11 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
m_valid_handle
.
find
(
handle
)
!=
m_valid_handle
.
end
(),
"invalid handle: %p"
,
handle
);
auto
info
=
reinterpret_cast
<
TensorInfo
*>
(
handle
);
if
(
info
->
desc
.
layout
.
ndim
!=
0
)
{
if
(
info
->
shape_valid
()
)
{
return
info
->
desc
.
layout
;
}
TensorShape
ret
=
wait_tensor
(
info
,
TensorProp
::
Shape
)
->
layout
();
mgb_assert
(
ret
.
ndim
!=
0
);
mgb_assert
(
ret
.
ndim
>
0
);
return
ret
;
}
...
...
@@ -694,12 +694,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
TensorProduceEvent
,
dest
->
id
,
ptr
->
layout
(),
ptr
->
comp_node
(),
ptr
->
raw_ptr_not_for_readwrite
());
// update tensor desc for static infer
if
(
dest
->
desc
.
layout
.
ndim
)
{
mgb_assert
(
dest
->
desc
.
layout
.
eq_shape
(
ptr
->
layout
()),
"shape infer error, %s vs %s"
,
dest
->
desc
.
layout
.
to_string
().
c_str
(),
ptr
->
layout
().
to_string
().
c_str
());
}
dest
->
update_layout
(
ptr
->
layout
());
// in order to avoid performance impact,
// memory forwarding is disabled when DTR is enabled
if
(
state
.
options
.
enable_dtr_auto_drop
||
state
.
options
.
disable_memory_forwarding
)
{
...
...
imperative/src/impl/interpreter/tensor_info.h
浏览文件 @
d3247bee
...
...
@@ -48,6 +48,7 @@ struct TensorInfo {
// Lock interpreter when visiting `ptr`.
TensorPtr
ptr
;
LogicalTensorDesc
desc
;
Spinlock
lock
;
double
compute_time
;
size_t
memory
;
...
...
@@ -158,6 +159,26 @@ struct TensorInfo {
// UINT_MAX as a magic default value
size_t
cand_index
=
UINT_MAX
;
bool
shape_valid
()
{
MGB_LOCK_GUARD
(
lock
);
return
desc
.
layout
.
ndim
;
}
void
update_layout
(
const
TensorLayout
&
layout
)
{
MGB_LOCK_GUARD
(
lock
);
mgb_assert
(
desc
.
layout
.
dtype
==
layout
.
dtype
,
"dtype mismatch"
);
mgb_assert
(
desc
.
layout
.
format
==
layout
.
format
,
"format mismatch"
);
if
(
desc
.
layout
.
ndim
)
{
mgb_assert
(
desc
.
layout
.
eq_shape
(
layout
),
"shape infer error, %s vs %s"
,
desc
.
layout
.
to_string
().
c_str
(),
layout
.
to_string
().
c_str
());
// ignore strides
}
else
{
static_cast
<
TensorShape
&>
(
desc
.
layout
)
=
layout
;
desc
.
layout
.
init_contiguous_stride
();
}
}
};
}
// namespace interpreter::intl
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录