Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2484cd27
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2484cd27
编写于
6月 20, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(tensor): check args when construct tensor with existing tensor
GitOrigin-RevId: 03454540707f42d409fdfdf88b5c044c56cf43b5
上级
e7587617
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
56 addition
and
4 deletion
+56
-4
imperative/python/megengine/optimizer/optimizer.py
imperative/python/megengine/optimizer/optimizer.py
+1
-1
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+29
-2
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+1
-1
imperative/python/test/unit/core/test_tensor_wrapper.py
imperative/python/test/unit/core/test_tensor_wrapper.py
+15
-0
imperative/src/impl/transformations/eval.cpp
imperative/src/impl/transformations/eval.cpp
+5
-0
imperative/src/include/megbrain/imperative/basic_operators.h
imperative/src/include/megbrain/imperative/basic_operators.h
+5
-0
未找到文件。
imperative/python/megengine/optimizer/optimizer.py
浏览文件 @
2484cd27
...
...
@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is "
+
str
(
type
(
param
))
)
param
.
_reset
(
Tensor
(
param
.
numpy
(),
no_cache
=
True
,
format
=
param
.
format
)
)
param
[...]
=
Tensor
(
param
.
numpy
(),
no_cache
=
True
)
for
name
,
default
in
self
.
_defaults
.
items
():
if
default
is
required
and
name
not
in
param_group
:
...
...
imperative/python/src/tensor.cpp
浏览文件 @
2484cd27
...
...
@@ -525,7 +525,34 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
}
mgb_assert
(
tup
.
size
()
==
7
);
if
(
auto
*
t
=
try_cast
(
tup
[
0
].
ptr
()))
{
m_tensor
=
t
->
m_tensor
->
copy
();
m_tensor
=
t
->
m_tensor
;
// TODO: merge two path in arg parse
if
(
!
tup
[
1
].
is_none
())
{
auto
dtype
=
tup
[
1
].
cast
<
DType
>
();
mgb_assert
(
dtype
==
m_tensor
->
dtype
(),
"dtype mismatch: %s vs %s"
,
dtype
.
name
(),
m_tensor
->
dtype
().
name
());
}
if
(
!
tup
[
2
].
is_none
())
{
auto
device
=
as_comp_node
(
tup
[
2
]);
mgb_assert
(
device
==
m_tensor
->
comp_node
(),
"device mismatch: %s vs %s"
,
device
.
to_string
().
c_str
(),
m_tensor
->
comp_node
().
to_string
().
c_str
());
}
mgb_assert
(
!
tup
[
3
].
cast
<
bool
>
(),
"expect is_const == False, got True"
);
bool
no_cache
=
tup
[
4
].
cast
<
bool
>
();
if
(
no_cache
)
{
// always copy because it's hard to tell whether this tensor is cached
m_tensor
=
m_tensor
->
copy
();
}
// ignore name
if
(
!
tup
[
6
].
is_none
())
{
Format
format
=
tup
[
6
].
cast
<
std
::
string
>
();
mgb_assert
(
format
==
m_tensor
->
format
(),
"format mismatch: %s vs %s"
,
format
.
to_string
().
c_str
(),
m_tensor
->
format
().
to_string
().
c_str
());
}
}
else
{
auto
data
=
tup
[
0
];
DType
dtype
=
tup
[
1
].
cast
<
DType
>
();
...
...
@@ -1030,7 +1057,7 @@ void init_tensor(py::module m) {
try
{
self
.
compiled
->
compile
();
}
catch
(
const
std
::
exception
&
e
)
{
mgb_log_error
(
e
.
what
());
mgb_log_error
(
"error in trace: %s"
,
e
.
what
());
}
}
// register transformations
...
...
imperative/python/src/tensor.h
浏览文件 @
2484cd27
...
...
@@ -47,7 +47,7 @@ public:
~
Tensor
()
=
default
;
inline
Tensor
copy
()
{
return
*
this
;
}
inline
Tensor
copy
()
{
return
Tensor
(
imperative
::
apply
(
DupTensor
(),
data
())[
0
])
;
}
inline
DType
dtype
()
{
return
*
data
().
dtype
();
}
inline
CompNode
comp_node
()
{
return
*
data
().
device
();
}
...
...
imperative/python/test/unit/core/test_tensor_wrapper.py
浏览文件 @
2484cd27
...
...
@@ -5,7 +5,9 @@ import numpy as np
import
pytest
from
utils
import
get_var_value
,
make_tensor
from
megengine
import
_full_sync
from
megengine.core.tensor.dtype
import
get_scale
,
get_zero_point
,
qint8
,
quint8
from
megengine.device
import
get_default_device
from
megengine.tensor
import
Parameter
,
Tensor
from
megengine.utils.network
import
Network
...
...
@@ -220,3 +222,16 @@ def test_tensor_from_bool():
assert
x
.
dtype
==
np
.
bool_
x
=
Tensor
([
True
,
False
])
assert
x
.
dtype
==
np
.
bool_
def
test_tensor_construct_tensor
():
x
=
Tensor
(
0
,
dtype
=
np
.
float32
,
device
=
"xpu0:1"
,
name
=
"MyName"
)
assert
Tensor
(
x
.
astype
(
np
.
int32
)).
dtype
==
np
.
int32
with
pytest
.
raises
(
RuntimeError
):
Tensor
(
x
.
astype
(
np
.
int32
),
dtype
=
np
.
float32
)
assert
Tensor
(
x
).
name
==
""
assert
Tensor
(
x
,
name
=
"MyName2"
).
name
==
"MyName2"
with
pytest
.
raises
(
RuntimeError
):
assert
Tensor
(
x
.
to
(
"xpu0:2"
),
device
=
"xpu0:1"
).
device
==
"xpu0:1"
assert
Tensor
(
x
.
to
(
"xpu0:2"
)).
device
==
x
.
to
(
"xpu0:2"
).
device
_full_sync
()
imperative/src/impl/transformations/eval.cpp
浏览文件 @
2484cd27
...
...
@@ -126,6 +126,11 @@ ValueRefList InterpreterTransformation::apply_transformation(
}
else
{
return
{
ValueRef
()};
}
}
else
if
(
op
.
is
<
DupTensor
>
())
{
auto
&
input
=
inputs
[
0
].
cast
(
m_value_type
);
DeviceTensorND
dev_tensor
;
dev_tensor
.
copy_from
(
m_channel
->
get_dev_tensor
(
input
.
handle
()
->
handle
()));
return
m_value_type
.
make
(
share_handle
(
m_channel
->
put
(
dev_tensor
,
{})));
}
else
{
return
op
.
fallback
(
inputs
);
}
...
...
imperative/src/include/megbrain/imperative/basic_operators.h
浏览文件 @
2484cd27
...
...
@@ -196,5 +196,10 @@ public:
std
::
string
to_string
()
const
override
;
};
class
DupTensor
final
:
public
OperatorImpl
<
DupTensor
,
Operator
::
IdentityLike
>
{
public:
std
::
string
to_string
()
const
override
{
return
"DupTensor"
;
}
};
}
// namespace imperative
}
// namespace mgb
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录