Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
dc250745
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
dc250745
编写于
12月 25, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge): add python custom op
GitOrigin-RevId: 35da0bb3017bdf90f7074bc84d9f3321672aad79
上级
60c44b08
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
254 addition
and
26 deletion
+254
-26
imperative/python/src/grad.cpp
imperative/python/src/grad.cpp
+102
-14
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+38
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+38
-10
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+16
-1
imperative/python/src/trace.cpp
imperative/python/src/trace.cpp
+1
-1
imperative/src/impl/ops/utility.cpp
imperative/src/impl/ops/utility.cpp
+21
-0
imperative/src/include/megbrain/imperative/ops/utility.h
imperative/src/include/megbrain/imperative/ops/utility.h
+38
-0
未找到文件。
imperative/python/src/grad.cpp
浏览文件 @
dc250745
...
...
@@ -12,6 +12,7 @@
#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h"
#include "range/v3/all.hpp"
...
...
@@ -21,6 +22,9 @@ namespace views = ranges::views;
namespace
mgb
::
imperative
::
python
{
using
scoped_disable
=
ApplyContext
::
scoped_disable
;
using
Flags
=
Tensor
::
Flags
;
namespace
{
struct
GradSlotWeakPtr
{
...
...
@@ -78,6 +82,21 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
return
result
;
}
struct
BackwardContext
{
PyTypeObject
*
pytype
=
nullptr
;
auto
wrap_tensor
(
std
::
shared_ptr
<
Tensor
>
t
)
{
if
(
pytype
)
{
return
TensorWrapper
::
make
(
pytype
,
std
::
move
(
t
));
}
return
TensorWrapper
::
make
(
std
::
move
(
t
));
}
auto
wrap_tensor
(
Tensor
*
t
)
{
return
wrap_tensor
(
t
->
shared_from_this
());
}
};
struct
BackwardGraphWithClosure
{
std
::
shared_ptr
<
BackwardGraphResult
>
backward_graph
;
SmallVector
<
std
::
shared_ptr
<
Tensor
>>
closure
;
...
...
@@ -119,7 +138,7 @@ struct BackwardGraphWithClosure {
}
template
<
typename
T
,
typename
R
>
void
operator
()(
T
&&
grads
,
R
&&
receiver
)
{
void
operator
()(
BackwardContext
&
,
T
&&
grads
,
R
&&
receiver
)
{
Tensor
*
args
[
closure
.
size
()
+
grads
.
size
()];
size_t
nargs
=
0
;
for
(
auto
&&
t
:
closure
)
{
...
...
@@ -143,7 +162,7 @@ struct BackwardGraphWithClosure {
ApplyContext
ctx
;
ctx
.
op
=
backward_graph
->
backward
;
ctx
.
flags
=
is_tracing
?
Tensor
::
Flags
::
TRACE
:
0
;
ctx
.
flags
=
is_tracing
?
Flags
::
TRACE
:
0
;
ctx
.
nargs
=
nargs
;
ctx
.
args
=
args
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
...
...
@@ -174,6 +193,47 @@ struct BackwardGraphWithClosure {
}
};
struct
PythonBackward
{
py
::
object
pyfunc
;
size_t
input_size
;
PythonBackward
(
py
::
object
f
,
size_t
nin
)
:
pyfunc
(
f
),
input_size
(
nin
)
{}
template
<
typename
T
,
typename
R
>
void
operator
()(
BackwardContext
&
ctx
,
T
&&
grads
,
R
&&
receiver
)
{
auto
args
=
py
::
tuple
(
grads
.
size
());
for
(
size_t
i
=
0
;
i
<
grads
.
size
();
++
i
)
{
auto
&&
g
=
grads
[
i
];
args
[
i
]
=
g
?
ctx
.
wrap_tensor
(
g
)
:
py
::
none
();
}
auto
input_grads
=
py
::
reinterpret_steal
<
py
::
object
>
(
PyObject_Call
(
pyfunc
.
ptr
(),
args
.
ptr
(),
nullptr
));
if
(
input_grads
.
is_none
())
return
;
if
(
auto
*
tw
=
TensorWrapper
::
try_cast
(
input_grads
.
ptr
()))
{
if
(
input_size
!=
1
)
{
throw
py
::
value_error
(
"custom grad rule returned wrong number of grads"
);
}
receiver
(
0
,
tw
->
m_tensor
);
return
;
}
if
(
py
::
len
(
input_grads
)
!=
input_size
)
{
throw
py
::
value_error
(
"custom grad rule returned wrong number of grads"
);
}
for
(
auto
[
i
,
g
]
:
views
::
enumerate
(
input_grads
))
{
if
(
g
.
is_none
())
continue
;
auto
*
tw
=
TensorWrapper
::
try_cast
(
g
.
ptr
());
if
(
!
tw
)
{
throw
py
::
type_error
(
"custom grad rule returned non-tensor"
);
}
receiver
(
i
,
tw
->
m_tensor
);
}
}
static
constexpr
bool
input_has_grad
(
size_t
)
{
return
true
;}
static
constexpr
bool
output_requires_grad
(
size_t
)
{
return
true
;}
static
constexpr
bool
output_captured
(
size_t
)
{
return
true
;}
};
}
// namespace
struct
GradProducerRecord
:
intrusive_list
::
Node
<
GradProducerRecord
>
{
...
...
@@ -210,7 +270,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
// same length as inputs (of forward op)
SmallVector
<
GradSlotProducerPtr
>
dsts
;
// encapsules actual function to compute gradient
std
::
variant
<
std
::
monostate
,
BackwardGraphWithClosure
>
backward
;
std
::
variant
<
std
::
monostate
,
BackwardGraphWithClosure
,
PythonBackward
>
backward
;
// a flag used during backward
bool
in_ref_keeper
=
false
;
...
...
@@ -268,6 +328,30 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
return
outputs
;
}
apply_result_t
python_grad_rule
(
ApplyContext
&
ctx
,
GradFnHelper
&
ret_grad_fn
)
{
auto
*
op
=
ctx
.
op
->
try_cast_final
<
GenericPyOp
>
();
py
::
tuple
pyin
(
ctx
.
nargs
);
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
++
i
)
{
pyin
[
i
]
=
TensorWrapper
::
make
(
ctx
.
pytype
,
ctx
.
args
[
i
]
->
shared_from_this
());
}
auto
grad_rule
=
py
::
getattr
(
op
->
obj
,
"_grad_rule"
);
auto
pyret
=
(
scoped_disable
(
Flags
::
GRAD
),
py
::
reinterpret_steal
<
py
::
object
>
(
PyObject_Call
(
grad_rule
.
ptr
(),
pyin
.
ptr
(),
nullptr
)));
// comma expression
auto
[
outputs
,
backward
]
=
py
::
cast
<
std
::
tuple
<
py
::
object
,
py
::
function
>>
(
pyret
);
ret_grad_fn
.
emplace
<
PythonBackward
>
(
std
::
move
(
backward
),
ctx
.
nargs
);
if
(
auto
*
tw
=
TensorWrapper
::
try_cast
(
outputs
.
ptr
()))
{
return
{
tw
->
m_tensor
};
}
apply_result_t
ret
;
ret
.
reserve
(
py
::
len
(
outputs
));
for
(
auto
&&
i
:
outputs
)
{
auto
*
tw
=
TensorWrapper
::
try_cast
(
i
.
ptr
());
mgb_assert
(
tw
);
ret
.
push_back
(
tw
->
m_tensor
);
}
return
ret
;
}
}
// namespace
apply_result_t
apply_grad
(
ApplyContext
&
ctx
)
{
...
...
@@ -290,21 +374,23 @@ apply_result_t apply_grad(ApplyContext& ctx) {
// cleanup stale grad info
// under what condition?
tensor
->
m_grad_info
=
{};
tensor
->
m_flags
&=
~
Tensor
::
Flags
::
GRAD
;
tensor
->
m_flags
&=
~
Flags
::
GRAD
;
}
}
else
{
tensor
->
m_flags
&=
~
Tensor
::
Flags
::
GRAD
;
tensor
->
m_flags
&=
~
Flags
::
GRAD
;
}
}
ctx
.
flags
&=
~
Tensor
::
Flags
::
GRAD
;
ctx
.
flags
&=
~
Flags
::
GRAD
;
if
(
!
grad_key
)
{
return
apply
(
ctx
);
}
GradFnHelper
grad_fn_holder
;
auto
outputs
=
backward_graph_grad_rule
(
ctx
,
grad_fn_holder
);
auto
outputs
=
ctx
.
op
->
same_type
<
GenericPyOp
>
()
?
python_grad_rule
(
ctx
,
grad_fn_holder
)
:
backward_graph_grad_rule
(
ctx
,
grad_fn_holder
);
auto
&
grad_fn
=
grad_fn_holder
.
grad_fn
;
if
(
!
grad_fn
)
{
...
...
@@ -341,7 +427,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
grad_info
.
grad_fn
=
grad_fn
;
grad_info
.
idx
=
i
;
grad_info
.
insert_after
(
grad_key
->
free_vars_head
);
outputs
[
i
]
->
m_flags
|=
Tensor
::
Flags
::
GRAD
;
outputs
[
i
]
->
m_flags
|=
Flags
::
GRAD
;
}
}
}
...
...
@@ -357,7 +443,7 @@ void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
if
(
nargs
!=
2
)
{
throw
py
::
type_error
(
"expect 2 arguments"
);
}
auto
*
tw
=
TensorWrapper
::
cast_safe
(
args
[
0
]);
auto
*
tw
=
TensorWrapper
::
try_cast
(
args
[
0
]);
if
(
!
tw
)
{
throw
py
::
type_error
(
"argument 1 must be Tensor"
);
}
...
...
@@ -390,14 +476,15 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) {
grad_fn
->
key
=
shared_from_this
();
grad_fn
->
slots
.
resize
(
1
);
tensor
->
m_grad_info
.
insert_after
(
free_vars_head
);
tensor
->
m_flags
|=
Tensor
::
Flags
::
GRAD
;
tensor
->
m_flags
|=
Flags
::
GRAD
;
}
tensor
->
m_grad_info
.
grad_fn
->
slots
[
0
].
callback
=
std
::
move
(
callback
);
}
void
accum_grad
(
std
::
shared_ptr
<
Tensor
>&
grad
,
std
::
shared_ptr
<
Tensor
>&&
delta
)
{
template
<
typename
T
>
void
accum_grad
(
std
::
shared_ptr
<
Tensor
>&
grad
,
T
&&
delta
)
{
if
(
!
grad
)
{
grad
=
std
::
forward
<
decltype
(
delta
)
>
(
delta
);
grad
=
std
::
forward
<
T
>
(
delta
);
return
;
}
static
ApplyContext
ctx
;
...
...
@@ -409,7 +496,7 @@ void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta)
ctx
.
args
=
args
;
ctx
.
flags
=
grad
->
m_flags
|
delta
->
m_flags
;
if
(
is_tracing
)
{
ctx
.
flags
|=
Tensor
::
Flags
::
TRACE
;
ctx
.
flags
|=
Flags
::
TRACE
;
}
grad
=
apply
(
ctx
)[
0
];
}
...
...
@@ -440,6 +527,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
}
}
BackwardContext
bctx
{
pytype
};
std
::
vector
<
std
::
shared_ptr
<
GradFn
>>
ref_keeper
;
ref_keeper
.
reserve
(
tape
.
size
());
// back-propagation in reverse order
...
...
@@ -456,7 +544,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
mgb_assert
(
0
);
}
else
{
auto
&&
grads
=
views
::
transform
(
grad_fn
->
slots
,
[](
auto
&&
slot
)
{
return
slot
.
grad
.
get
();});
backward
(
std
::
forward
<
decltype
(
grads
)
>
(
grads
),
grad_receiver
);
backward
(
bctx
,
std
::
forward
<
decltype
(
grads
)
>
(
grads
),
grad_receiver
);
}
},
grad_fn
->
backward
);
...
...
imperative/python/src/ops.cpp
浏览文件 @
dc250745
...
...
@@ -14,6 +14,7 @@
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include <Python.h>
...
...
@@ -245,6 +246,35 @@ void _init_py_backward_graph(py::module m) {
mgb_assert
(
PyOp
(
OpDef
)
::
ctype2pytype
.
emplace
(
BackwardGraph
::
typeinfo
(),
&
py_type
).
second
);
}
struct
PyOpBase
:
PyOpDef
{
static
PyTypeObject
py_type
;
static
PyObject
*
tp_new
(
PyTypeObject
*
type
,
PyObject
*
,
PyObject
*
)
{
auto
*
obj
=
type
->
tp_alloc
(
type
,
0
);
if
(
obj
)
{
auto
*
self
=
reinterpret_cast
<
PyOpBase
*>
(
obj
);
new
(
&
self
->
op
)
decltype
(
self
->
op
);
}
return
obj
;
}
};
PyTypeObject
PyOpBase
::
py_type
;
void
_init_py_op_base
(
py
::
module
m
)
{
using
py_op
=
PyOpBase
;
auto
&
py_type
=
PyOpBase
::
py_type
;
py_type
=
{
PyVarObject_HEAD_INIT
(
NULL
,
0
)};
py_type
.
tp_name
=
"megengine.core._imperative_rt.ops.PyOpBase"
;
py_type
.
tp_basicsize
=
sizeof
(
py_op
);
py_type
.
tp_flags
=
Py_TPFLAGS_DEFAULT
|
Py_TPFLAGS_BASETYPE
;
py_type
.
tp_doc
=
"PyOpBase"
;
py_type
.
tp_base
=
&
PyOpType
(
OpDef
);
py_type
.
tp_dealloc
=
py_dealloc_generic
<
py_op
>
;
py_type
.
tp_new
=
py_op
::
tp_new
;
mgb_assert
(
PyType_Ready
(
&
py_type
)
>=
0
);
m
.
add_object
(
"PyOpBase"
,
reinterpret_cast
<
PyObject
*>
(
&
py_type
));
}
/*********** end of hand-write opdefs **************/
// auto generated opdefs
...
...
@@ -260,9 +290,16 @@ bool type_caster<OpDef>::load(handle src, bool convert) {
return
false
;
}
value
=
reinterpret_cast
<
PyOp
(
OpDef
)
*>
(
obj
)
->
op
;
if
(
!
value
)
{
// opdef only defined in Python
value
=
std
::
make_shared
<
GenericPyOp
>
(
reinterpret_borrow
<
object
>
(
src
));
}
return
true
;
}
handle
type_caster
<
OpDef
>::
cast
(
const
OpDef
&
op
,
return_value_policy
,
handle
)
{
if
(
auto
*
pyop
=
op
.
try_cast_final
<
GenericPyOp
>
())
{
return
object
(
pyop
->
obj
).
release
();
}
PyTypeObject
*
pytype
;
auto
&
c2p
=
PyOp
(
OpDef
)
::
ctype2pytype
;
auto
&&
iter
=
c2p
.
find
(
op
.
dyn_typeinfo
());
...
...
@@ -283,5 +320,6 @@ handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
void
init_ops
(
py
::
module
m
)
{
_init_py_op_def
(
m
);
_init_py_backward_graph
(
m
);
_init_py_op_base
(
m
);
INIT_ALL_OP
(
m
)
}
imperative/python/src/tensor.cpp
浏览文件 @
dc250745
...
...
@@ -11,6 +11,7 @@
#include "megbrain/dtype.h"
#include "megbrain/common.h"
#include "megbrain/imperative/ops/utility.h"
#include "./tensor.h"
#include "./grad.h"
...
...
@@ -22,10 +23,12 @@
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <range/v3/all.hpp>
#include <unordered_map>
namespace
py
=
pybind11
;
namespace
views
=
ranges
::
views
;
namespace
mgb
::
imperative
::
python
{
...
...
@@ -69,21 +72,45 @@ SET_UNSET_PROP(compiled)
bool
skip_tracing
=
false
;
Tensor
::
flags_t
ApplyContext
::
global_disable
=
0
;
apply_result_t
apply
(
ApplyContext
&
ctx
)
{
// emulating scalar should be put to specific op's apply, e.g.,
// elementwise, reduce, typecvt. Currently it's still handled at python
// side. It could be move to C++ side if it has an impact on performance
if
(
ctx
.
flags
&
Tensor
::
Flags
::
SCALAR
)
{
auto
flags
=
ctx
.
flags
&
~
ApplyContext
::
global_disable
;
if
(
flags
&
Tensor
::
Flags
::
SCALAR
)
{
// TODO: emulate scalar
}
if
(
ctx
.
flags
&
Tensor
::
Flags
::
GRAD
)
{
if
(
flags
&
Tensor
::
Flags
::
GRAD
)
{
return
apply_grad
(
ctx
);
}
if
(
ctx
.
flags
&
Tensor
::
Flags
::
TRACE
)
{
if
(
flags
&
Tensor
::
Flags
::
TRACE
)
{
return
apply_trace
(
ctx
);
}
else
{
if
(
auto
*
op
=
ctx
.
op
->
try_cast_final
<
GenericPyOp
>
())
{
py
::
tuple
pyin
(
ctx
.
nargs
);
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
++
i
)
{
pyin
[
i
]
=
TensorWrapper
::
make
(
ctx
.
pytype
,
ctx
.
args
[
i
]
->
shared_from_this
());
}
auto
f
=
py
::
getattr
(
op
->
obj
,
"_default_rule"
);
auto
pyout
=
py
::
reinterpret_steal
<
py
::
object
>
(
PyObject_Call
(
f
.
ptr
(),
pyin
.
ptr
(),
nullptr
));
if
(
auto
*
tw
=
TensorWrapper
::
try_cast
(
pyout
.
ptr
()))
{
return
{
tw
->
m_tensor
};
}
apply_result_t
ret
;
ret
.
reserve
(
py
::
len
(
pyout
));
for
(
auto
&&
i
:
pyout
)
{
auto
*
tw
=
TensorWrapper
::
try_cast
(
i
.
ptr
());
mgb_assert
(
tw
);
ret
.
push_back
(
tw
->
m_tensor
);
}
return
ret
;
}
SmallVector
<
interpreter
::
Interpreter
::
Handle
>
handles
(
ctx
.
nargs
);
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
++
i
)
{
handles
[
i
]
=
ctx
.
args
[
i
]
->
m_handle
.
get
();
...
...
@@ -125,12 +152,13 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
SmallVector
<
Tensor
*
,
64
>
tensors
(
nargs
);
ctx
.
args
=
&
tensors
[
0
];
ctx
.
nargs
=
nargs
;
ctx
.
pytype
=
pytype
;
if
(
strstr
(
op
->
ob_type
->
tp_name
,
"BackwardGraph"
))
{
ctx
.
backward
=
true
;
}
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
if
(
TensorWrapper
*
tw
=
TensorWrapper
::
cast_safe
(
args
[
i
]))
{
if
(
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
args
[
i
]))
{
auto
*
t
=
tensors
[
i
]
=
tw
->
m_tensor
.
get
();
ctx
.
flags
|=
t
->
m_flags
;
}
else
{
...
...
@@ -166,7 +194,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
if
(
nargs
==
0
)
{
throw
py
::
type_error
(
"too few arguments"
);
}
if
(
auto
*
t
=
cast_safe
(
tup
[
0
].
ptr
()))
{
if
(
auto
*
t
=
try_cast
(
tup
[
0
].
ptr
()))
{
if
(
nargs
>
1
)
{
throw
py
::
type_error
(
"expect 1 argument"
);
}
...
...
@@ -211,7 +239,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
auto
ret
=
pyf
(
*
tup
);
auto
py_ret
=
py
::
reinterpret_borrow
<
py
::
list
>
(
ret
);
if
(
auto
*
t
=
cast_safe
(
py_ret
[
0
].
ptr
()))
{
if
(
auto
*
t
=
try_cast
(
py_ret
[
0
].
ptr
()))
{
m_tensor
=
t
->
m_tensor
;
}
return
;
...
...
@@ -349,7 +377,7 @@ PyObject* TensorWrapper::varnode() {
}
void
TensorWrapper
::
reset
(
PyObject
*
tensor
)
{
TensorWrapper
*
t
=
TensorWrapper
::
cast_safe
(
tensor
);
TensorWrapper
*
t
=
TensorWrapper
::
try_cast
(
tensor
);
if
(
!
t
)
{
throw
py
::
type_error
(
"expect Tensor"
);
}
...
...
@@ -446,7 +474,7 @@ uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
}
}
// Returns the data type with sufficient size to hold all types of
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr
*
promote_types
(
SmallVector
<
PyArray_Descr
*>
types
,
uint8_t
cat
)
{
// Return value: New reference
...
...
@@ -507,7 +535,7 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
if
(
handle
==
Py_None
)
continue
;
TensorWrapper
*
tw
=
TensorWrapper
::
cast_safe
(
handle
);
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
handle
);
if
(
tw
)
{
mgb
::
DType
type
=
tw
->
m_tensor
->
dtype
();
auto
&&
descr
=
npy
::
dtype_mgb2np_descr
(
type
);
...
...
@@ -562,7 +590,7 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
CompNode
cn
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
TensorWrapper
*
tw
=
TensorWrapper
::
cast_safe
(
handle
);
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
handle
);
if
(
tw
)
{
if
(
!
valid
)
{
cn
=
tw
->
m_tensor
->
comp_node
();
...
...
imperative/python/src/tensor.h
浏览文件 @
dc250745
...
...
@@ -124,7 +124,7 @@ struct TensorWrapper {
friend
wrap_t
;
inline
static
TensorWrapper
*
cast
(
PyObject
*
op
)
{
return
reinterpret_cast
<
wrap_t
*>
(
op
)
->
inst
();}
inline
static
TensorWrapper
*
cast_safe
(
PyObject
*
op
)
{
inline
static
TensorWrapper
*
try_cast
(
PyObject
*
op
)
{
if
(
!
wrap_t
::
type
().
isinstance
(
op
))
return
nullptr
;
return
cast
(
op
);
}
...
...
@@ -173,11 +173,26 @@ struct TensorWrapper {
PyObject
*
py_apply
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
/* , PyObject* kwnames */
);
struct
ApplyContext
{
static
Tensor
::
flags_t
global_disable
;
Tensor
::
flags_t
flags
;
std
::
shared_ptr
<
OpDef
>
op
;
Tensor
*
const
*
args
;
size_t
nargs
;
PyTypeObject
*
pytype
=
nullptr
;
bool
backward
=
false
;
class
scoped_disable
:
NonCopyableObj
{
Tensor
::
flags_t
saved_flags
;
public:
scoped_disable
(
Tensor
::
flags_t
flags
)
:
saved_flags
(
ApplyContext
::
global_disable
)
{
ApplyContext
::
global_disable
|=
flags
;
}
~
scoped_disable
()
{
ApplyContext
::
global_disable
=
saved_flags
;
}
};
};
using
apply_result_t
=
SmallVector
<
std
::
shared_ptr
<
Tensor
>
,
8
>
;
...
...
imperative/python/src/trace.cpp
浏览文件 @
dc250745
...
...
@@ -85,7 +85,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
// assumption: python function always returns PyList
auto
tup
=
py
::
reinterpret_borrow
<
py
::
list
>
(
ret
);
for
(
auto
i
=
0
;
i
<
tup
.
size
();
i
++
)
{
auto
tw
=
TensorWrapper
::
cast_safe
(
tup
[
i
].
ptr
());
auto
tw
=
TensorWrapper
::
try_cast
(
tup
[
i
].
ptr
());
outputs
.
emplace_back
(
tw
->
m_tensor
);
}
return
outputs
;
...
...
imperative/src/impl/ops/utility.cpp
0 → 100644
浏览文件 @
dc250745
/**
* \file imperative/src/impl/ops/utility.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/utility.h"
#include "../op_trait.h"
namespace
mgb
::
imperative
{
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
GenericPyOp
);
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/ops/utility.h
0 → 100644
浏览文件 @
dc250745
/**
* \file imperative/src/include/megbrain/imperative/ops/utility.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
#include <pybind11/pybind11.h>
namespace
mgb
::
imperative
{
struct
GenericPyOp
final
:
OpDefImplBase
<
GenericPyOp
>
{
pybind11
::
object
obj
;
GenericPyOp
(
pybind11
::
object
obj_
)
:
obj
(
std
::
move
(
obj_
))
{};
size_t
hash
()
const
override
{
return
pybind11
::
hash
(
obj
);
}
bool
is_same_st
(
const
Hashable
&
rhs
)
const
override
{
return
obj
.
equal
(
static_cast
<
const
GenericPyOp
&>
(
rhs
).
obj
);
}
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
};
}
// namespace mgb::imperative
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录