Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4f77509e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4f77509e
编写于
4月 28, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/opr): allow empty ImmutableTensor
Fixes MGE-675. GitOrigin-RevId: c6771740fc48226f1b7c79d519de61445e671290
上级
6a7e7ce1
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
29 addition
and
11 deletion
+29
-11
src/gopt/impl/framework.cpp
src/gopt/impl/framework.cpp
+2
-1
src/opr/impl/io.cpp
src/opr/impl/io.cpp
+16
-10
src/opr/test/io.cpp
src/opr/test/io.cpp
+11
-0
未找到文件。
src/gopt/impl/framework.cpp
浏览文件 @
4f77509e
...
...
@@ -761,7 +761,8 @@ ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
if
(
is_const_var
(
m_const_var_type
,
opr
))
{
auto
sz
=
var_mem_size
(
opr
->
output
(
0
));
mgb_assert
(
sz
);
mgb_assert
(
sz
||
opr
->
output
(
0
)
->
contain_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
));
info
.
is_const
=
true
;
info
.
max_size
=
sz
;
return
make_ret
();
...
...
src/opr/impl/io.cpp
浏览文件 @
4f77509e
...
...
@@ -382,7 +382,7 @@ class ImmutableTensor::Value {
void
setup
(
CompNode
cn
,
const
HostTensorND
&
val
);
bool
initialized
()
const
{
return
!
m_dev
.
empty
();
return
m_dev
.
shape_valid
();
}
//! value on comp node
...
...
@@ -400,8 +400,9 @@ class ImmutableTensor::Value {
};
void
ImmutableTensor
::
Value
::
setup
(
CompNode
cn
,
const
HostTensorND
&
val
)
{
mgb_assert
(
m_dev
.
empty
()
&&
!
val
.
empty
());
mgb_assert
(
m_dev
.
empty
()
&&
!
m_dev
.
shape_valid
());
m_dev
.
comp_node
(
cn
).
copy_from
(
val
).
sync
();
mgb_assert
(
val
.
empty
()
==
m_dev
.
empty
());
auto
one_elem
=
[](
const
TensorShape
&
shape
)
{
for
(
size_t
i
=
0
;
i
<
shape
.
ndim
;
++
i
)
{
...
...
@@ -446,6 +447,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
HostTensorND
m_val_ref
;
const
dt_byte
*
val_ptr
()
const
{
mgb_assert
(
m_trait
.
size_bytes
);
return
m_val
.
empty
()
?
m_val_ref
.
raw_ptr
()
:
m_val
.
data
();
}
...
...
@@ -454,9 +456,8 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
TensorKey
(
const
HostTensorND
&
v
)
:
m_val_ref
{
v
}
{
mgb_assert
(
v
.
layout
().
is_contiguous
());
mgb_assert
(
v
.
layout
().
is_contiguous
()
||
v
.
layout
().
is_empty
()
);
m_trait
.
size_bytes
=
v
.
layout
().
span
().
high_byte
;
mgb_assert
(
m_trait
.
size_bytes
);
auto
&&
layout
=
m_trait
.
layout
;
// zero to enable byte-comparison
...
...
@@ -467,15 +468,19 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
layout
.
shape
[
i
]
=
v
.
layout
().
shape
[
i
];
layout
.
stride
[
i
]
=
v
.
layout
().
stride
[
i
];
}
m_trait
.
hash
=
XXHash
{}.
update
(
v
.
raw_ptr
(),
m_trait
.
size_bytes
).
update
(
&
m_trait
.
layout
,
sizeof
(
m_trait
.
layout
)).
digest
();
XXHash
hasher
;
if
(
!
v
.
empty
())
{
hasher
.
update
(
v
.
raw_ptr
(),
m_trait
.
size_bytes
);
}
hasher
.
update
(
&
m_trait
.
layout
,
sizeof
(
m_trait
.
layout
));
m_trait
.
hash
=
hasher
.
digest
();
}
bool
operator
==
(
const
TensorKey
&
rhs
)
const
{
return
!
memcmp
(
&
m_trait
,
&
rhs
.
m_trait
,
sizeof
(
Trait
))
&&
!
memcmp
(
val_ptr
(),
rhs
.
val_ptr
(),
m_trait
.
size_bytes
);
((
m_trait
.
size_bytes
==
0
&&
rhs
.
m_trait
.
size_bytes
==
0
)
||
!
memcmp
(
val_ptr
(),
rhs
.
val_ptr
(),
m_trait
.
size_bytes
));
}
size_t
hash
()
const
{
...
...
@@ -485,6 +490,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
//! copy from m_val_ref to m_val, to avoid refed value being
//! modified
void
copy_val_permanent
()
{
if
(
m_trait
.
size_bytes
==
0
)
return
;
mgb_assert
(
m_val
.
empty
());
m_val
.
resize
(
m_trait
.
size_bytes
);
memcpy
(
m_val
.
data
(),
m_val_ref
.
raw_ptr
(),
m_trait
.
size_bytes
);
...
...
@@ -544,7 +550,6 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
}
const
Value
&
get
(
const
HostTensorND
&
tensor
)
{
mgb_assert
(
!
tensor
.
empty
());
if
(
tensor
.
shape
().
is_scalar
())
{
return
get
(
DTypeScalar
::
make_from_raw
(
tensor
.
dtype
(),
tensor
.
raw_ptr
()));
...
...
@@ -595,6 +600,7 @@ ImmutableTensor::ImmutableTensor(ComputingGraph &graph,
add_output
(
value
.
dev
().
dtype
());
add_equivalence_component
<
ScalarHash
<
const
void
*>>
(
&
value
);
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
}
ImmutableTensor
::~
ImmutableTensor
()
noexcept
=
default
;
...
...
src/opr/test/io.cpp
浏览文件 @
4f77509e
...
...
@@ -177,6 +177,17 @@ TEST(TestOprIO, ImmutableTensorLarge) {
}
}
TEST
(
TestOprIO
,
ImmutableTensorEmpty
)
{
HostTensorGenerator
<>
gen
;
auto
graph
=
ComputingGraph
::
make
();
auto
host_x
=
gen
({
1
,
9
,
1
,
9
,
8
,
1
,
0
});
auto
x
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
*
host_x
);
HostTensorND
host_x2
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
x
,
host_x2
)});
func
->
execute
();
ASSERT_TRUE
(
host_x2
.
shape
().
is_empty
());
}
TEST
(
TestOprIO
,
SharedDeviceTensor
)
{
HostTensorGenerator
<>
gen
;
auto
hv
=
gen
({
123
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录