Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4ae9dd00
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4ae9dd00
编写于
6月 26, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): add external transform
GitOrigin-RevId: e8e3ebe9c86afc9fb97900b5b5af9778cc1354e5
上级
9914129a
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
332 addition
and
17 deletion
+332
-17
imperative/python/megengine/jit/xla_backend.py
imperative/python/megengine/jit/xla_backend.py
+45
-11
imperative/python/megengine/xla/compile.py
imperative/python/megengine/xla/compile.py
+15
-4
imperative/python/src/external_convert.h
imperative/python/src/external_convert.h
+153
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+69
-2
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+3
-0
imperative/python/src/transformation.h
imperative/python/src/transformation.h
+1
-0
imperative/python/test/unit/xla/functional/test_xla_convert.py
...ative/python/test/unit/xla/functional/test_xla_convert.py
+46
-0
未找到文件。
imperative/python/megengine/jit/xla_backend.py
浏览文件 @
4ae9dd00
from
collections
import
OrderedDict
,
defaultdict
from
collections
import
OrderedDict
,
defaultdict
from
..
import
tensor
from
..core._imperative_rt
import
CompNode
from
..core._imperative_rt
import
CompNode
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
(
is_external_convert
,
set_external_convert
,
set_external_convert_hook
,
set_py_external_type
,
unset_external_convert
,
)
from
..core._trace_option
import
set_use_xla_backend
from
..core._trace_option
import
set_use_xla_backend
from
..device
import
get_default_device
from
..device
import
get_default_device
from
..utils.dlpack
import
from_dlpack
,
to_dlpack
from
..utils.dlpack
import
from_dlpack
,
to_dlpack
from
.tracing
import
trace
from
.tracing
import
trace
# try:
# from mge_xlalib.xla_extension import ArrayImpl
# from ..xla.lib import xla_client as xc
# except ImportError:
# pass
from
mge_xlalib.xla_extension
import
ArrayImpl
from
..xla.lib
import
xla_client
as
xc
xla_client_compute_stream
=
None
def
apply_external_convert_hook
(
input
,
cn
):
stream
=
xla_client_compute_stream
assert
isinstance
(
input
,
ArrayImpl
)
dlpack_capsule
=
xc
.
_xla
.
buffer_to_dlpack_managed_tensor
(
input
,
take_ownership
=
True
)
output
=
from_dlpack
(
dlpack_capsule
,
stream
).
to
(
cn
,
_borrow
=
True
)
return
output
class
xla_trace
(
trace
):
class
xla_trace
(
trace
):
r
"""Wraps a callable, and provides accelerated evaluation compiled by xla.
r
"""Wraps a callable, and provides accelerated evaluation compiled by xla.
...
@@ -48,6 +77,12 @@ class xla_trace(trace):
...
@@ -48,6 +77,12 @@ class xla_trace(trace):
def
__init__
(
self
,
function
,
*
,
without_host
=
True
,
symbolic_shape
=
False
,
**
kwargs
):
def
__init__
(
self
,
function
,
*
,
without_host
=
True
,
symbolic_shape
=
False
,
**
kwargs
):
assert
without_host
,
"xla trace only support without host mode"
assert
without_host
,
"xla trace only support without host mode"
assert
not
symbolic_shape
,
"xla doesn't support dynamic shape currently"
assert
not
symbolic_shape
,
"xla doesn't support dynamic shape currently"
set_external_convert_hook
(
apply_external_convert_hook
)
set_py_external_type
(
ArrayImpl
)
set_external_convert
()
super
().
__init__
(
super
().
__init__
(
function
,
without_host
=
without_host
,
symbolic_shape
=
symbolic_shape
,
**
kwargs
function
,
without_host
=
without_host
,
symbolic_shape
=
symbolic_shape
,
**
kwargs
)
)
...
@@ -142,8 +177,8 @@ class xla_trace(trace):
...
@@ -142,8 +177,8 @@ class xla_trace(trace):
return
xc
.
_xla
.
buffer_to_dlpack_managed_tensor
(
x
,
take_ownership
=
take_ownership
)
return
xc
.
_xla
.
buffer_to_dlpack_managed_tensor
(
x
,
take_ownership
=
take_ownership
)
def
execute
(
self
,
*
args
,
**
kwargs
):
def
execute
(
self
,
*
args
,
**
kwargs
):
from
..traced_module.pytree
import
tree_flatten
from
..tensor
import
Tensor
from
..tensor
import
Tensor
from
..traced_module.pytree
import
tree_flatten
from
..utils.module_utils
import
get_expand_structure
from
..utils.module_utils
import
get_expand_structure
inputs
,
_
=
tree_flatten
((
args
,
kwargs
))
inputs
,
_
=
tree_flatten
((
args
,
kwargs
))
...
@@ -161,6 +196,8 @@ class xla_trace(trace):
...
@@ -161,6 +196,8 @@ class xla_trace(trace):
arrays
=
self
.
prepare_xla_inputs
(
arrays
)
arrays
=
self
.
prepare_xla_inputs
(
arrays
)
outputs
=
self
.
xla_exec
(
*
arrays
)
outputs
=
self
.
xla_exec
(
*
arrays
)
global
xla_client_compute_stream
xla_client_compute_stream
=
xla_stream
return_vals
=
[]
return_vals
=
[]
for
i
in
self
.
out_list
:
for
i
in
self
.
out_list
:
if
i
==
-
1
:
if
i
==
-
1
:
...
@@ -170,28 +207,25 @@ class xla_trace(trace):
...
@@ -170,28 +207,25 @@ class xla_trace(trace):
return_vals
.
append
(
outputs
[
self
.
outkey2idx
[
i
]])
return_vals
.
append
(
outputs
[
self
.
outkey2idx
[
i
]])
keeped_features
=
[]
keeped_features
=
[]
for
i
in
self
.
keeped_activation
:
for
i
in
self
.
keeped_activation
:
capsule
=
self
.
to_dlpack
(
outputs
[
self
.
outkey2idx
[
i
]])
keeped_features
.
append
(
outputs
[
self
.
outkey2idx
[
i
]])
t
=
from_dlpack
(
capsule
,
xla_stream
).
to
(
cn
,
_borrow
=
True
)
keeped_features
.
append
(
t
)
out_tensors
=
[]
out_tensors
=
[]
for
array
in
return_vals
:
for
array
in
return_vals
:
if
array
is
not
None
:
if
array
is
not
None
:
capsule
=
self
.
to_dlpack
(
array
)
t
=
tensor
(
array
,
device
=
cn
)
t
=
from_dlpack
(
capsule
,
xla_stream
)
out_tensors
.
append
(
t
)
out_tensors
.
append
(
t
.
to
(
cn
,
_borrow
=
True
))
else
:
else
:
out_tensors
.
append
(
array
)
out_tensors
.
append
(
array
)
if
self
.
overall
:
if
self
.
overall
:
for
attr
,
key
in
self
.
update_param_dict
.
items
():
for
attr
,
key
in
self
.
update_param_dict
.
items
():
param
=
get_expand_structure
(
attr
[
0
],
attr
[
1
])
param
=
get_expand_structure
(
attr
[
0
],
attr
[
1
])
xla_array
=
outputs
[
self
.
outkey2idx
[
key
]]
xla_array
=
outputs
[
self
.
outkey2idx
[
key
]]
capsule
=
self
.
to_dlpack
(
xla_array
)
t
=
tensor
(
xla_array
,
device
=
cn
)
param
.
_reset
(
from_dlpack
(
capsule
).
to
(
cn
,
_borrow
=
True
)
)
param
.
_reset
(
t
)
for
state
,
key
in
self
.
update_opt_param_dict
.
items
():
for
state
,
key
in
self
.
update_opt_param_dict
.
items
():
xla_array
=
outputs
[
self
.
outkey2idx
[
key
]]
xla_array
=
outputs
[
self
.
outkey2idx
[
key
]]
capsule
=
self
.
to_dlpack
(
xla_array
)
t
=
tensor
(
xla_array
,
device
=
cn
)
state
.
_reset
(
from_dlpack
(
capsule
).
to
(
cn
,
_borrow
=
True
)
)
state
.
_reset
(
t
)
rst
=
(
rst
=
(
self
.
outdef
.
unflatten
(
out_tensors
)
self
.
outdef
.
unflatten
(
out_tensors
)
if
hasattr
(
self
,
"outdef"
)
if
hasattr
(
self
,
"outdef"
)
...
...
imperative/python/megengine/xla/compile.py
浏览文件 @
4ae9dd00
...
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Set,
...
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Set,
import
numpy
as
np
import
numpy
as
np
from
..
import
tensor
from
..distributed
import
is_distributed
from
..distributed
import
is_distributed
from
..utils.dlpack
import
from_dlpack
,
to_dlpack
from
..utils.dlpack
import
from_dlpack
,
to_dlpack
from
.
import
ir_utils
from
.
import
ir_utils
...
@@ -68,10 +69,20 @@ class InputsHandler:
...
@@ -68,10 +69,20 @@ class InputsHandler:
def
__call__
(
self
,
input_buffers
):
def
__call__
(
self
,
input_buffers
):
rst
=
[]
rst
=
[]
for
ibuf
in
input_buffers
:
for
idx
,
i
in
enumerate
(
input_buffers
):
capsule
=
to_dlpack
(
ibuf
)
if
i
.
_is_external_value
():
xla_array
=
self
.
from_dlpack
(
capsule
)
rst
.
append
([
i
.
_external_obj
()])
rst
.
append
([
xla_array
])
else
:
if
"gpu"
in
i
.
device
.
physical_name
:
capsule
=
to_dlpack
(
i
)
xla_array
=
self
.
from_dlpack
(
capsule
)
rst
.
append
([
xla_array
])
else
:
r
=
self
.
handler
(
self
.
local_devices
,
[
self
.
input_indices
[
idx
],],
[
i
,]
)[
0
]
rst
.
append
(
r
)
i
.
_reset
(
tensor
(
r
[
0
]))
return
rst
return
rst
def
__str__
(
self
):
def
__str__
(
self
):
...
...
imperative/python/src/external_convert.h
0 → 100644
浏览文件 @
4ae9dd00
#pragma once
#include <list>
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/map.h"
#include "./tensor.h"
namespace
mgb
::
imperative
::
python
{
namespace
py
=
pybind11
;
class
CreateExternalWrapper
final
:
public
OperatorImpl
<
CreateExternalWrapper
>
{
private:
py
::
object
m_object
;
CompNode
m_device
;
public:
CreateExternalWrapper
(
py
::
object
obj
,
CompNode
device
)
:
m_object
(
obj
),
m_device
(
device
)
{}
py
::
object
object
()
const
{
return
m_object
;
}
CompNode
device
()
const
{
return
m_device
;
}
std
::
string
raw_type
()
const
{
return
"CreateExternalWrapper"
;
}
std
::
string
to_string
()
const
{
return
"CreateExternalWrapper"
;
};
};
class
GetExternalVal
final
:
public
OperatorImpl
<
GetExternalVal
,
Operator
::
GetAttrLike
>
{
public:
std
::
string
to_string
()
const
{
return
"GetExternalVal"
;
};
std
::
string
raw_type
()
const
{
return
"GetExternalVal"
;
}
};
class
PyobjectStorage
{
private:
py
::
object
m_object
;
public:
PyobjectStorage
()
=
default
;
PyobjectStorage
(
py
::
object
object
)
:
m_object
(
object
)
{}
py
::
object
object
()
const
{
return
m_object
;
}
std
::
string
to_string
()
const
{
return
"PyobjectStorage"
;
}
};
class
PyobjectValue
final
:
public
PrimitiveValue
<
PyobjectValue
,
PyobjectStorage
>
{
public:
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
{
return
PyobjectStorage
::
to_string
();
}
};
class
ExternalValue
final
:
public
ObjectValue
<
ExternalValue
>
{
private:
py
::
object
m_obj
;
mutable
CompNodeValue
::
ref_t
m_device
;
public:
ExternalValue
(
py
::
object
obj
,
CompNode
device
)
:
m_obj
(
obj
),
m_device
(
CompNodeValue
::
make
(
device
))
{}
py
::
object
object
()
const
{
return
m_obj
;
}
CompNodeValue
::
ref_t
device
()
const
{
return
m_device
;
}
std
::
string
to_string
()
const
override
{
return
"ExternalValue"
;
}
void
clear
()
override
{}
};
class
ExternalConvertTransformation
final
:
public
Transformation
{
private:
py
::
function
m_hook_fn
;
int
m_enabled
=
0
;
ObjectType
<
ExternalValue
>
m_value_type
{
"ExternalValue"
};
public:
ValueRefList
apply_external_imperative_hook
(
const
Operator
&
op
,
Span
<
ValueRef
>
input_values
)
{
for
(
int
i
=
0
;
i
<
input_values
.
size
();
i
++
)
{
if
(
auto
*
val
=
input_values
[
i
].
as
(
m_value_type
))
{
CompNode
cn
=
*
(
val
->
device
());
py
::
object
fn_res
=
m_hook_fn
(
val
->
object
(),
cn
);
auto
*
tw
=
TensorWrapper
::
try_cast
(
fn_res
.
ptr
());
mgb_assert
(
tw
,
"expect Tensor"
);
auto
external_input
=
input_values
[
i
].
as_ref
(
m_value_type
);
external_input
.
reset
(
tw
->
m_tensor
->
data
());
}
}
auto
outputs
=
imperative
::
apply
(
op
,
input_values
);
return
outputs
;
}
ExternalConvertTransformation
(
py
::
function
hook_fn
)
:
m_hook_fn
(
hook_fn
)
{}
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
{
if
(
!
m_enabled
)
{
return
imperative
::
apply
(
op
,
inputs
);
}
bool
has_external_inp
=
false
;
if
(
auto
*
obj_value
=
op
.
as
<
CreateExternalWrapper
>
())
{
return
m_value_type
.
make
(
obj_value
->
object
(),
obj_value
->
device
());
}
for
(
auto
&&
input
:
inputs
)
{
if
(
input
.
is
(
m_value_type
))
{
has_external_inp
=
true
;
break
;
}
}
if
(
!
has_external_inp
)
{
return
imperative
::
apply
(
op
,
inputs
);
}
else
if
(
op
.
is
<
GetExternalVal
>
())
{
py
::
object
m_object
=
inputs
.
item
().
cast
(
m_value_type
).
object
();
PyobjectStorage
inp_obj
=
PyobjectStorage
(
m_object
);
return
{
PyobjectValue
::
make
(
inp_obj
)};
}
else
if
(
op
.
is
<
RenameValue
>
())
{
return
{
inputs
[
0
]};
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
auto
&
input
=
inputs
.
item
().
cast
(
m_value_type
);
ValueRefList
outputs
;
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
Device
:
outputs
=
{
input
.
device
()};
break
;
default:
outputs
=
apply_external_imperative_hook
(
op
,
inputs
);
break
;
}
return
outputs
;
}
else
{
auto
outputs
=
apply_external_imperative_hook
(
op
,
inputs
);
return
outputs
;
}
}
void
enable
()
{
m_enabled
=
1
;
}
void
disable
()
{
m_enabled
=
0
;
}
bool
enabled
()
const
{
return
m_enabled
;
}
ValueRef
unwrap
(
ValueRef
value
)
override
{
return
value
;
}
const
Type
<
ExternalValue
>&
value_type
()
const
{
return
m_value_type
;
}
std
::
string
name
()
const
override
{
return
"ExternalConvertTransformation"
;
}
};
}
// namespace mgb::imperative::python
\ No newline at end of file
imperative/python/src/tensor.cpp
浏览文件 @
4ae9dd00
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include "./common.h"
#include "./common.h"
#include "./dlpack.h"
#include "./dlpack.h"
#include "./dlpack_convertor.h"
#include "./dlpack_convertor.h"
#include "./external_convert.h"
#include "./grad.h"
#include "./grad.h"
#include "./graph_rt.h"
#include "./graph_rt.h"
#include "./helper.h"
#include "./helper.h"
...
@@ -61,6 +62,7 @@ namespace mgb::imperative::python {
...
@@ -61,6 +62,7 @@ namespace mgb::imperative::python {
interpreter
::
Interpreter
::
Channel
*
interpreter_for_py
=
nullptr
;
interpreter
::
Interpreter
::
Channel
*
interpreter_for_py
=
nullptr
;
PyTypeObject
*
py_tensor_type
=
nullptr
;
PyTypeObject
*
py_tensor_type
=
nullptr
;
PyTypeObject
*
py_varnode_type
=
nullptr
;
PyTypeObject
*
py_varnode_type
=
nullptr
;
PyTypeObject
*
py_external_type
=
nullptr
;
pybind11
::
handle
py_device_type
=
nullptr
;
pybind11
::
handle
py_device_type
=
nullptr
;
PyObject
*
cpp_use_symbolic_shape
;
PyObject
*
cpp_use_symbolic_shape
;
...
@@ -589,7 +591,13 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
...
@@ -589,7 +591,13 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
:
no_cache
?
CreateTensor
::
Unique
:
no_cache
?
CreateTensor
::
Unique
:
CreateTensor
::
Common
;
:
CreateTensor
::
Common
;
ValueRef
val
;
ValueRef
val
;
if
(
py
::
isinstance
(
data
,
Py_Varnode
))
{
bool
use_external_inp
=
py_external_type
!=
nullptr
;
if
(
use_external_inp
&&
PyObject_TypeCheck
(
py
::
handle
(
data
).
ptr
(),
py_external_type
))
{
val
=
imperative
::
apply
(
CreateExternalWrapper
(
data
,
cn
),
Span
<
ValueRef
>
(
nullptr
,
nullptr
))[
0
];
}
else
if
(
py
::
isinstance
(
data
,
Py_Varnode
))
{
cg
::
VarNode
*
m_node
=
py
::
handle
(
data
).
cast
<
cg
::
VarNode
*>
();
cg
::
VarNode
*
m_node
=
py
::
handle
(
data
).
cast
<
cg
::
VarNode
*>
();
val
=
imperative
::
apply
(
val
=
imperative
::
apply
(
CreateNode
(
m_node
),
Span
<
ValueRef
>
(
nullptr
,
nullptr
))[
0
];
CreateNode
(
m_node
),
Span
<
ValueRef
>
(
nullptr
,
nullptr
))[
0
];
...
@@ -750,6 +758,27 @@ PyObject* TensorWrapper::_graph() {
...
@@ -750,6 +758,27 @@ PyObject* TensorWrapper::_graph() {
return
py
::
cast
(
graph
).
release
().
ptr
();
return
py
::
cast
(
graph
).
release
().
ptr
();
}
}
PyObject
*
TensorWrapper
::
_external_obj
()
{
TypedValueRef
<
PyobjectValue
>
value
=
imperative
::
apply
(
GetExternalVal
(),
m_tensor
->
data
())[
0
]
.
as_ref
<
PyobjectValue
>
();
return
value
->
object
().
release
().
ptr
();
}
PyObject
*
TensorWrapper
::
_is_external_value
()
{
auto
&&
external_tsf
=
TransformationManager
::
get_instance
()
.
segments
[
TransformationManager
::
Segment
::
ExternalConvert
];
auto
*
tsf
=
reinterpret_cast
<
ExternalConvertTransformation
*>
(
external_tsf
[
0
].
get
());
mgb_assert
(
tsf
->
enabled
());
auto
valueref
=
m_tensor
->
data
();
if
(
valueref
.
is
(
tsf
->
value_type
()))
{
Py_RETURN_TRUE
;
}
else
{
Py_RETURN_FALSE
;
}
}
void
dlpack_capsule_destructor
(
PyObject
*
data
)
{
void
dlpack_capsule_destructor
(
PyObject
*
data
)
{
if
(
!
PyCapsule_IsValid
(
data
,
"dltensor"
))
{
if
(
!
PyCapsule_IsValid
(
data
,
"dltensor"
))
{
// early out, see DLPack spec: if a consuming library sets the capsule
// early out, see DLPack spec: if a consuming library sets the capsule
...
@@ -931,6 +960,8 @@ void init_tensor(py::module m) {
...
@@ -931,6 +960,8 @@ void init_tensor(py::module m) {
.
def
<&
TensorWrapper
::
_var
>
(
"var"
)
.
def
<&
TensorWrapper
::
_var
>
(
"var"
)
.
def
<&
TensorWrapper
::
_graph
>
(
"graph"
)
.
def
<&
TensorWrapper
::
_graph
>
(
"graph"
)
.
def
<&
TensorWrapper
::
value_id
>
(
"value_id"
)
.
def
<&
TensorWrapper
::
value_id
>
(
"value_id"
)
.
def
<&
TensorWrapper
::
_is_external_value
>
(
"_is_external_value"
)
.
def
<&
TensorWrapper
::
_external_obj
>
(
"_external_obj"
)
.
def_getset
<
.
def_getset
<
&
TensorWrapper
::
module_trace_info
,
&
TensorWrapper
::
module_trace_info
,
&
TensorWrapper
::
set_module_trace_info
>
(
"_NodeMixin__node"
)
&
TensorWrapper
::
set_module_trace_info
>
(
"_NodeMixin__node"
)
...
@@ -1150,6 +1181,10 @@ void init_tensor(py::module m) {
...
@@ -1150,6 +1181,10 @@ void init_tensor(py::module m) {
py_varnode_type
=
reinterpret_cast
<
PyTypeObject
*>
(
type_obj
.
inc_ref
().
ptr
());
py_varnode_type
=
reinterpret_cast
<
PyTypeObject
*>
(
type_obj
.
inc_ref
().
ptr
());
});
});
m
.
def
(
"set_py_external_type"
,
[](
py
::
object
type_obj
)
{
py_external_type
=
reinterpret_cast
<
PyTypeObject
*>
(
type_obj
.
inc_ref
().
ptr
());
});
m
.
def
(
"set_py_device_type"
,
m
.
def
(
"set_py_device_type"
,
[](
py
::
object
type_obj
)
{
py_device_type
=
type_obj
.
inc_ref
();
});
[](
py
::
object
type_obj
)
{
py_device_type
=
type_obj
.
inc_ref
();
});
...
@@ -1705,6 +1740,24 @@ void init_tensor(py::module m) {
...
@@ -1705,6 +1740,24 @@ void init_tensor(py::module m) {
return
module_trace_transformation
;
return
module_trace_transformation
;
};
};
static
py
::
function
external_convert_hook
;
static
auto
get_external_convert
=
[]
{
static
std
::
shared_ptr
<
ExternalConvertTransformation
>
external_convert_transformation
;
if
(
!
external_convert_transformation
)
{
mgb_assert
(
external_convert_hook
);
external_convert_transformation
=
std
::
make_shared
<
ExternalConvertTransformation
>
(
external_convert_hook
);
MGB_MARK_USED_VAR
(
transformations
.
register_at
<
Segment
::
ExternalConvert
>
(
external_convert_transformation
)
.
release
());
}
return
external_convert_transformation
;
};
m
.
def
(
"set_cpp_use_symbolic_shape"
,
&
set_cpp_use_symbolic_shape
);
m
.
def
(
"set_cpp_use_symbolic_shape"
,
&
set_cpp_use_symbolic_shape
);
m
.
def
(
"set_module_tracing"
,
[
=
]
{
get_module_trace
()
->
enable
();
});
m
.
def
(
"set_module_tracing"
,
[
=
]
{
get_module_trace
()
->
enable
();
});
...
@@ -1712,6 +1765,12 @@ void init_tensor(py::module m) {
...
@@ -1712,6 +1765,12 @@ void init_tensor(py::module m) {
m
.
def
(
"unset_module_tracing"
,
[
=
]
{
get_module_trace
()
->
disable
();
});
m
.
def
(
"unset_module_tracing"
,
[
=
]
{
get_module_trace
()
->
disable
();
});
m
.
def
(
"is_tracing_module"
,
[
=
]
{
return
get_module_trace
()
->
enabled
();
});
m
.
def
(
"is_tracing_module"
,
[
=
]
{
return
get_module_trace
()
->
enabled
();
});
m
.
def
(
"set_external_convert"
,
[
=
]
{
get_external_convert
()
->
enable
();
});
m
.
def
(
"unset_external_convert"
,
[
=
]
{
get_external_convert
()
->
disable
();
});
m
.
def
(
"is_external_convert"
,
[
=
]
{
return
get_external_convert
()
->
enabled
();
});
m
.
def
(
"set_python_backtrace_enabled"
,
&
set_python_backtrace_enabled
);
m
.
def
(
"set_python_backtrace_enabled"
,
&
set_python_backtrace_enabled
);
m
.
def
(
"set_transformation_backtrace_enabled"
,
m
.
def
(
"set_transformation_backtrace_enabled"
,
&
set_transformation_backtrace_enabled
);
&
set_transformation_backtrace_enabled
);
...
@@ -1723,8 +1782,16 @@ void init_tensor(py::module m) {
...
@@ -1723,8 +1782,16 @@ void init_tensor(py::module m) {
module_trace_hook
.
inc_ref
();
module_trace_hook
.
inc_ref
();
});
});
m
.
def
(
"set_external_convert_hook"
,
[](
py
::
function
function
)
{
external_convert_hook
=
function
;
external_convert_hook
.
inc_ref
();
});
auto
atexit
=
py
::
module
::
import
(
"atexit"
);
auto
atexit
=
py
::
module
::
import
(
"atexit"
);
atexit
.
attr
(
"register"
)(
py
::
cpp_function
([]()
{
module_trace_hook
=
{};
}));
atexit
.
attr
(
"register"
)(
py
::
cpp_function
([]()
{
module_trace_hook
=
{};
external_convert_hook
=
{};
}));
m
.
def
(
"begin_record_values"
,
[]
{
Value
::
begin_record_values
();
});
m
.
def
(
"begin_record_values"
,
[]
{
Value
::
begin_record_values
();
});
m
.
def
(
"end_record_values"
,
[]
{
m
.
def
(
"end_record_values"
,
[]
{
...
...
imperative/python/src/tensor.h
浏览文件 @
4ae9dd00
...
@@ -31,6 +31,7 @@ namespace mgb::imperative::python {
...
@@ -31,6 +31,7 @@ namespace mgb::imperative::python {
extern
interpreter
::
Interpreter
::
Channel
*
interpreter_for_py
;
extern
interpreter
::
Interpreter
::
Channel
*
interpreter_for_py
;
extern
PyTypeObject
*
py_tensor_type
;
extern
PyTypeObject
*
py_tensor_type
;
extern
PyTypeObject
*
py_varnode_type
;
extern
PyTypeObject
*
py_varnode_type
;
extern
PyTypeObject
*
py_external_type
;
extern
pybind11
::
handle
py_device_type
;
extern
pybind11
::
handle
py_device_type
;
extern
PyObject
*
cpp_use_symbolic_shape
;
extern
PyObject
*
cpp_use_symbolic_shape
;
extern
PyObject
*
cpp_astensor1d
;
extern
PyObject
*
cpp_astensor1d
;
...
@@ -142,6 +143,8 @@ public:
...
@@ -142,6 +143,8 @@ public:
PyObject
*
_detail
();
PyObject
*
_detail
();
PyObject
*
_var
();
PyObject
*
_var
();
PyObject
*
_graph
();
PyObject
*
_graph
();
PyObject
*
_is_external_value
();
PyObject
*
_external_obj
();
void
_watch
();
void
_watch
();
};
};
...
...
imperative/python/src/transformation.h
浏览文件 @
4ae9dd00
...
@@ -22,6 +22,7 @@ public:
...
@@ -22,6 +22,7 @@ public:
Complex
,
Complex
,
Format
,
Format
,
Grad
,
Grad
,
ExternalConvert
,
Scalar
,
Scalar
,
Symbol
,
Symbol
,
Trace
,
Trace
,
...
...
imperative/python/test/unit/xla/functional/test_xla_convert.py
0 → 100644
浏览文件 @
4ae9dd00
import
platform
import
numpy
as
np
import
pytest
import
megengine.functional
as
F
import
megengine.jit
as
jit
import
megengine.tensor
as
tensor
from
megengine
import
autodiff
,
is_cuda_available
from
megengine.autodiff.grad_manager
import
GradManager
from
meg_xlalib.xla_extension
import
ArrayImpl
def
test_external_flag_set
():
@
xla_trace
(
capture_as_const
=
True
)
def
test_fun
():
pass
def
test_external_value
():
m
=
Conv2d
(
9
,
9
,
3
,
groups
=
9
)
gm
=
GradManager
()
gm
.
attach
(
m
.
parameters
())
@
xla_trace
(
capture_as_const
=
True
)
def
conv_grad
(
inp
,
model
):
with
gm
:
gm
.
attach
(
inp
)
rst
=
model
(
inp
)
gm
.
backward
(
rst
.
mean
())
ig
=
inp
.
grad
wg
=
model
.
weight
.
grad
inp
.
grad
=
None
model
.
weight
.
grad
=
None
return
ig
,
wg
inp
=
tensor
(
np
.
random
.
random
((
9
,
9
,
32
,
32
)))
*
100
a
,
b
=
conv_grad
(
inp
,
m
)
a1
,
b1
=
conv_grad
(
inp
,
m
)
np
.
testing
.
assert_allclose
(
a
.
numpy
(),
a1
.
numpy
())
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录