Commit d98be080 (MegEngine 天元 / MegEngine)
Authored Mar 09, 2022 by Megvii Engine Team

perf(mge): move Const into C++

GitOrigin-RevId: 31a443cffdc1b6d1470b5e0fd5ed49ab350cb4ff
Parent: 1709b394

Showing 10 changed files with 71 additions and 64 deletions (+71 -64)
imperative/python/megengine/core/ops/special.py          +0  -40
imperative/python/megengine/core/tensor/utils.py         +5  -5
imperative/python/megengine/functional/math.py           +2  -3
imperative/python/megengine/functional/nn.py             +3  -4
imperative/python/megengine/functional/tensor.py         +3  -3
imperative/python/megengine/functional/tensor_cache.py   +3  -3
imperative/python/megengine/traced_module/expr.py        +2  -2
imperative/python/src/tensor.cpp                         +2  -0
imperative/python/src/tensor_utils.cpp                   +49 -4
imperative/python/src/tensor_utils.h                     +2  -0
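The substance of the change: `Const` moves from a Python functor class in `megengine/core/ops/special.py` to a C++ function exposed through `megengine.core._imperative_rt.core2`, cutting one layer of Python overhead from every constant-tensor creation. The calling convention changes from keyword arguments plus a separate call step to four positional arguments returning the tensor directly. A minimal before/after sketch (values are illustrative; running the "after" half assumes a MegEngine build containing this commit):

```python
# Before (parent commit 1709b394): Const was a Python functor returning a 1-tuple.
#     from megengine.core.ops.special import Const
#     (x,) = Const(value, dtype="float32", device="xpux")()   # no reference
#     (x,) = Const(value, dtype="float32", device=None)(ref)  # device taken from ref
#
# After (this commit): Const is a C++ function with four positional arguments.
from megengine.core._imperative_rt.core2 import Const

x = Const(1.0, "float32", None, None)    # (value, dtype, device, reference)
y = Const([1, 2], "int32", None, (x,))   # reference may be a tensor or a tuple
```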
imperative/python/megengine/core/ops/special.py (deleted, 100644 → 0)
```python
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor


class Const:
    def __init__(self, value=None, *, dtype=None, device=None):
        self.value = np.asarray(value, dtype=dtype)
        self.dtype = dtype
        self.device = device

    def __call__(self, *reference):
        from ...tensor import Tensor

        device = self.device

        if len(reference) != 0:
            reference = reference[0]
            assert isinstance(
                reference, (SymbolVar, Tensor)
            ), "Reference should be Tensor or VarNode"

            if device is None:
                device = reference.device

            if isinstance(reference, SymbolVar):
                cls = type(reference)
                rst = cls(make_const(reference.graph, self.value, device, self.dtype))
                return (rst,)

        return (Tensor(self.value, self.dtype, self.device, True),)
```
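For reference, the deleted functor had two paths: with no reference it wrapped the stored `np.asarray` result in an eager `Tensor`, and with a `SymbolVar` reference it instead emitted an `ImmutableTensor` op into the reference's graph via `make_const`. A hedged sketch of the eager path (import path valid only before this commit):

```python
from megengine.core.ops.special import Const  # removed by this commit

const = Const([1, 2, 3], dtype="int32", device=None)
(t,) = const()                 # eager path: call sites always unpacked a 1-tuple
assert t.numpy().tolist() == [1, 2, 3]
```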
imperative/python/megengine/core/tensor/utils.py
```diff
@@ -14,6 +14,7 @@ import numpy as np
 from .._imperative_rt import make_const
 from .._imperative_rt.core2 import (
+    Const,
     SymbolVar,
     Tensor,
     _get_convert_inputs,
@@ -28,7 +29,6 @@ from .._imperative_rt.ops import jit_supported
 from .._wrap import as_device
 from ..autodiff.grad import Function
 from ..ops import builtin
-from ..ops.special import Const
 from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
 from .dtype import is_dtype_equal, is_quantize
@@ -67,7 +67,7 @@ def convert_single_value(v, *, dtype=None, device=None):
         if not is_quantize(v.dtype):
             v = astype(v, dtype)
     else:
-        (v,) = Const(v, dtype=dtype, device=device)()
+        v = Const(v, dtype, device, None)
     return v
@@ -155,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
         if ndim != 0 and ndim != 1:
             raise ValueError("ndim != 1 or 0, get : %d" % ndim)
         if not isinstance(x, (Tensor, SymbolVar)):
-            (x,) = Const(x, dtype=dtype, device=device)(*reference)
+            x = Const(x, dtype, device, reference)
         return x
     if not isinstance(x, collections.abc.Sequence):
@@ -166,7 +166,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
         if dtype is not None:
             x = astype(x, dtype)
         return x
-    (x,) = Const(x, dtype=dtype, device=device)(*reference)
+    x = Const(x, dtype, device, reference)
     return x
@@ -337,7 +337,7 @@ def interpret_subgraph(func, dtype, device):
         return results

     def apply_const(value, dtype=dtype, device=device):
-        return Const(value, dtype=dtype, device=device)()[0]
+        return Const(value, dtype, device, None)

     outputs, outputs_has_grad = func(args, apply_expr, apply_const)
     outputs = [
```
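Note that `astensor1d` now hands its whole `reference` varargs tuple to `Const` instead of unpacking it in Python; the tuple unpacking moved into `_Const` on the C++ side (see `tensor_utils.cpp` below). A hedged sketch of the user-visible behavior (tensor values illustrative):

```python
from megengine import Tensor
from megengine.core.tensor.utils import astensor1d

inp = Tensor([[1.0, 2.0], [3.0, 4.0]])
# Internally this is now one positional call: Const([0, 1], "int32", None, (inp,)),
# with the reference varargs tuple passed through unmodified.
axis = astensor1d([0, 1], inp, dtype="int32", device=inp.device)
print(axis.numpy())  # [0 1]
```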
imperative/python/megengine/functional/math.py
```diff
@@ -10,10 +10,9 @@ import collections
 import math
 from typing import Iterable, Optional, Sequence, Tuple, Union

-from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core.ops import builtin
-from ..core.ops.special import Const
 from ..core.tensor.array_method import _matmul
 from ..core.tensor.utils import _normalize_axis
 from ..tensor import Tensor
@@ -729,7 +728,7 @@ def topk(
     op = builtin.TopK(mode=mode)

     if not isinstance(k, Tensor):
-        (k,) = Const(k, dtype="int32", device=inp.device)()
+        k = Const(k, "int32", inp.device, None)

     if len(inp.shape) == 1:
         if kth_only:
```
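In `topk`, a Python-int `k` is promoted to an int32 tensor on the input's device through the new positional `Const` call before the op is applied. A hedged usage sketch (the output comment assumes `topk`'s default ascending mode):

```python
import megengine.functional as F
from megengine import Tensor

x = Tensor([5.0, 1.0, 3.0, 2.0, 4.0])
# k is a plain int; topk converts it internally via Const(k, "int32", x.device, None)
values, indices = F.topk(x, k=2)
print(values.numpy())  # expected: the two smallest values, [1. 2.]
```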
imperative/python/megengine/functional/nn.py
```diff
@@ -11,7 +11,7 @@ from functools import lru_cache
 from typing import NamedTuple, Optional, Sequence, Tuple, Union

 from ..core import _config
-from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
 from ..core.ops import builtin
@@ -26,7 +26,6 @@ from ..core.ops.builtin import (
     Reshape,
     TypeCvt,
 )
-from ..core.ops.special import Const
 from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
 from ..core.tensor.utils import (
@@ -1317,7 +1316,7 @@ def batch_norm(
         raise ValueError("Invalid param_dim {}".format(param_dim))
     if x is None:
-        (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
+        x = Const(value, inp.dtype, inp.device, None)
         shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
         (result,) = apply(builtin.Broadcast(), x, shape)
         return result
@@ -1541,7 +1540,7 @@ def sync_batch_norm(
     def _make_full_if_none(x, value):
         if x is None:
-            (x,) = Const(value, dtype=inp.dtype, device=_device)()
+            x = Const(value, inp.dtype, _device, None)
             (result,) = apply(builtin.Broadcast(), x, reduce_shape)
             return result
         elif x.ndim == 1:
```
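Both normalization paths rely on the same fill-then-broadcast idiom: materialize a scalar constant with `Const`, then `Broadcast` it to the parameter shape. A hedged user-level rendering of that idiom (`inp` and the fill value are illustrative):

```python
import megengine.functional as F
from megengine import Tensor

inp = Tensor([[1.0, 2.0], [3.0, 4.0]])
# Equivalent of: x = Const(0.0, inp.dtype, inp.device, None) followed by Broadcast
x = F.broadcast_to(Tensor(0.0, dtype=inp.dtype, device=inp.device), inp.shape)
print(x.numpy())  # 2x2 zeros, matching inp's dtype and device
```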
imperative/python/megengine/functional/tensor.py
```diff
@@ -13,6 +13,7 @@ import numpy as np
 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import (
+    Const,
     SymbolVar,
     apply,
     broadcast_cpp,
@@ -24,7 +25,6 @@ from ..core._imperative_rt.core2 import (
 from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
-from ..core.ops.special import Const
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -177,7 +177,7 @@ def full(
         shape = (shape,)
     if device is None:
         device = get_default_device()
-    (x,) = Const(value, dtype=dtype, device=device)()
+    x = Const(value, dtype, device, None)
     if type(shape) in (list, tuple) and len(shape) == 0:
         return x
     return broadcast_to(x, shape)
@@ -325,7 +325,7 @@ def full_like(
          [2 2 2]]
     """
-    (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
+    x = Const(value, inp.dtype, inp.device, inp)
     if inp.ndim == 0:
         return x
     return broadcast_to(x, inp.shape)
```
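`full` and `full_like` are the public entry points for this path; note that `full_like` passes `inp` itself as the reference argument so the constant lands on `inp`'s device. A hedged usage sketch:

```python
import megengine.functional as F
from megengine import Tensor

a = F.full((2, 3), 2, dtype="int32")          # scalar Const broadcast to (2, 3)
b = F.full_like(Tensor([[1, 2], [3, 4]]), 7)  # inp itself is the Const reference
print(a.numpy())
print(b.numpy())
```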
imperative/python/megengine/functional/tensor_cache.py
```diff
-from ..core.ops.special import Const
+from ..core._imperative_rt.core2 import Const
 from ..jit.tracing import is_tracing

 small_tensor_cache = {}
@@ -7,11 +7,11 @@ small_tensor_cache = {}
 def _get_scalar_tensor_with_value(value, dtype=None, device=None):
     global small_tensor_cache
     if is_tracing():
-        (ret,) = Const(value, dtype=dtype, device=device)()
+        ret = Const(value, dtype, device, None)
     else:
         cache_key = (value, dtype, device)
         if cache_key not in small_tensor_cache:
-            (ret,) = Const(value, dtype=dtype, device=device)()
+            ret = Const(value, dtype, device, None)
             small_tensor_cache[cache_key] = ret
         else:
             ret = small_tensor_cache[cache_key]
```
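The cache memoizes small constants by `(value, dtype, device)` but bypasses itself under tracing, where each trace step must record its own constant. The same pattern as a self-contained sketch (`make_const` and `tracing` are illustrative stand-ins for `core2.Const` and `is_tracing()`):

```python
# Generic memoization of scalar constants, mirroring tensor_cache.py.
_small_tensor_cache = {}

def get_scalar(value, dtype, device, tracing, make_const):
    if tracing:                        # while tracing, never reuse cached tensors
        return make_const(value, dtype, device, None)
    key = (value, dtype, device)
    if key not in _small_tensor_cache:  # miss: build once and remember
        _small_tensor_cache[key] = make_const(value, dtype, device, None)
    return _small_tensor_cache[key]
```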
imperative/python/megengine/traced_module/expr.py
```diff
@@ -16,6 +16,7 @@ from importlib import import_module
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union

 from ..core._imperative_rt import OpDef
+from ..core._imperative_rt.core2 import Const
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import (
     apply,
@@ -25,7 +26,6 @@ from ..core._imperative_rt.core2 import (
     unset_module_tracing,
 )
 from ..core.ops.builtin import FakeQuant
-from ..core.ops.special import Const
 from ..module import Module
 from ..tensor import Parameter, Tensor
 from ..version import __version__
@@ -764,7 +764,7 @@ class Constant(Expr):
     def interpret(self, *inputs):
         if isinstance(self.value, RawTensor):
-            return Const(self.value.numpy())()
+            return (Const(self.value.numpy(), None, None, None),)
         return (self.value,)

     def __repr__(self):
```
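`Expr.interpret` implementations must return a tuple of outputs. The old functor already returned a 1-tuple; the new `Const` returns the tensor itself, hence the explicit re-wrapping. A hedged mirror of that contract (`make_const` is an illustrative stand-in for `core2.Const`):

```python
import numpy as np

def interpret_constant(value, make_const):
    """Illustrative mirror of Constant.interpret: outputs are always a tuple."""
    if isinstance(value, np.ndarray):  # RawTensor case in the real code
        return (make_const(value, None, None, None),)
    return (value,)
```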
imperative/python/src/tensor.cpp
```diff
@@ -639,6 +639,7 @@ WRAP_FUNC_PY35(squeeze_cpp);
 WRAP_FUNC_PY35(transpose_cpp);
 WRAP_FUNC_PY35(broadcast_cpp);
 WRAP_FUNC_PY35(reshape_cpp);
+WRAP_FUNC_PY35(Const);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -777,6 +778,7 @@ void init_tensor(py::module m) {
             MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
             MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
             MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
+            MGE_PY_INTERFACE(Const, Const),
             {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {
```
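These two registrations expose the function to Python via METH_VARARGS under the name `Const` in the `core2` module, callable with exactly four positional arguments. A hedged sanity check (assumes a build containing this commit):

```python
# Assumes a MegEngine build that includes this commit.
from megengine.core._imperative_rt.core2 import Const

t = Const(1.0, "float32", None, None)  # (value, dtype, device, reference)
print(type(t), t.numpy())
```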
imperative/python/src/tensor_utils.cpp
```diff
@@ -94,7 +94,7 @@ bool is_bool_dtype(PyObject* args) {
 }

 py::object _Const(
-        py::handle value, py::handle dtype, py::handle device, py::handle ref) {
+        py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
     py::object val = py::reinterpret_borrow<py::object>(value);
     if (PyArray_Check(value.ptr())) {
         py::tuple strides =
@@ -107,21 +107,56 @@ py::object _Const(
         }
         if (need_squeeze) {
             val = py::reinterpret_borrow<py::array>(value);
+            py::object orig_shp = val.attr("shape");
             val = val.attr("squeeze")();
-            val = val.attr("reshape")(val.attr("shape"));
+            val = val.attr("reshape")(orig_shp);
         }
     }
+    py::object ref;
+    if (py::isinstance<py::tuple>(ref_hdl)) {
+        py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
+        if (tup.size()) {
+            ref = tup[0];
+        } else {
+            ref = py::none();
+        }
+    } else {
+        ref = py::reinterpret_borrow<py::object>(ref_hdl);
+    }
     if (py::isinstance<PySymbolVar>(ref)) {
         auto ref_var = ref.cast<PySymbolVar*>();
         auto* graph = ref_var->m_node->owner_graph();
-        auto cn = device.cast<CompNode>();
+        CompNode cn;
+        if (device.ptr() == Py_None) {
+            cn = ref_var->m_node->comp_node();
+        } else {
+            cn = device.cast<CompNode>();
+        }
         OperatorNodeConfig config(cn);
         auto hv = npy::np2tensor(
                 val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
         auto typeobj = ref.get_type();
         return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
     }
-    py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
+    py::object device_obj;
+    if (device.ptr() == Py_None) {
+        device_obj = py::cast(CompNode::load(get_default_device()));
+    } else if (py::isinstance<py::str>(device)) {
+        py::object dmap = getattr(
+                py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type),
+                "dmap_callback");
+        if (dmap.ptr() != Py_None) {
+            device_obj = dmap(device);
+            py::print(device_obj);
+        } else {
+            device_obj = py::cast(CompNode::load(device.cast<std::string>()));
+        }
+    } else if (py::isinstance<CompNode>(device)) {
+        device_obj = py::reinterpret_borrow<py::object>(device);
+    } else {
+        device_obj = getattr(device, "_cn");
+    }
+    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
     return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
 }
@@ -1107,4 +1142,14 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }

+PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _Const(
+                       py::handle(args[0]), py::handle(args[1]), py::handle(args[2]),
+                       py::handle(args[3]))
+                .release()
+                .ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 }  // namespace mgb::imperative::python
```
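The new `device_obj` ladder replicates in C++ what the Python-side tensor construction previously handled: `None` falls back to the default device, strings go through the optional `dmap_callback` remapping or `CompNode::load`, a `CompNode` passes through unchanged, and anything else is assumed to be a device wrapper exposing `_cn`. A hedged Python rendering of that decision ladder (the helper arguments are illustrative stand-ins for `CompNode::load`, the `CompNode` type check, and `Tensor.dmap_callback`):

```python
def resolve_device(device, default_device, dmap_callback=None,
                   load_comp_node=str, is_comp_node=lambda d: False):
    """Illustrative mirror of the device_obj branch in _Const."""
    if device is None:
        return load_comp_node(default_device)  # CompNode::load(get_default_device())
    if isinstance(device, str):
        if dmap_callback is not None:
            return dmap_callback(device)       # user-installed device remapping
        return load_comp_node(device)          # CompNode::load(device)
    if is_comp_node(device):
        return device                          # already a CompNode
    return device._cn                          # device wrapper exposing _cn
```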
imperative/python/src/tensor_utils.h
```diff
@@ -20,4 +20,6 @@ PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs);

 PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);

+PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);
+
 }  // namespace mgb::imperative::python
\ No newline at end of file
```