Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
177001d5
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
177001d5
编写于
2月 16, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dispatch): allow dynamic type creation
GitOrigin-RevId: 27dde05cff7e1e0bf61c652a3ae1fe4def829ada
上级
150a6a61
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
655 addition
and
568 deletion
+655
-568
imperative/python/src/grad.cpp
imperative/python/src/grad.cpp
+16
-11
imperative/python/src/grad.h
imperative/python/src/grad.h
+1
-0
imperative/python/src/grad_override.cpp
imperative/python/src/grad_override.cpp
+8
-8
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+36
-32
imperative/python/src/transformation.h
imperative/python/src/transformation.h
+2
-3
imperative/src/impl/basic_operators.cpp
imperative/src/impl/basic_operators.cpp
+1
-1
imperative/src/impl/basic_values.cpp
imperative/src/impl/basic_values.cpp
+5
-5
imperative/src/impl/dispatch.cpp
imperative/src/impl/dispatch.cpp
+10
-43
imperative/src/impl/interpreter/interpreter_impl.cpp
imperative/src/impl/interpreter/interpreter_impl.cpp
+3
-1
imperative/src/impl/ops/elemwise.cpp
imperative/src/impl/ops/elemwise.cpp
+2
-1
imperative/src/impl/subgraph_detail.cpp
imperative/src/impl/subgraph_detail.cpp
+7
-6
imperative/src/impl/transformations/eval.cpp
imperative/src/impl/transformations/eval.cpp
+13
-13
imperative/src/impl/transformations/grad.cpp
imperative/src/impl/transformations/grad.cpp
+23
-21
imperative/src/impl/transformations/lazy.cpp
imperative/src/impl/transformations/lazy.cpp
+5
-5
imperative/src/impl/transformations/scalar.cpp
imperative/src/impl/transformations/scalar.cpp
+52
-36
imperative/src/impl/transformations/trace.cpp
imperative/src/impl/transformations/trace.cpp
+22
-39
imperative/src/impl/value.cpp
imperative/src/impl/value.cpp
+11
-33
imperative/src/include/megbrain/imperative/basic_values.h
imperative/src/include/megbrain/imperative/basic_values.h
+57
-61
imperative/src/include/megbrain/imperative/subgraph.h
imperative/src/include/megbrain/imperative/subgraph.h
+18
-19
imperative/src/include/megbrain/imperative/transformations/eval.h
...ve/src/include/megbrain/imperative/transformations/eval.h
+7
-11
imperative/src/include/megbrain/imperative/transformations/grad.h
...ve/src/include/megbrain/imperative/transformations/grad.h
+20
-31
imperative/src/include/megbrain/imperative/transformations/lazy.h
...ve/src/include/megbrain/imperative/transformations/lazy.h
+7
-11
imperative/src/include/megbrain/imperative/transformations/scalar.h
.../src/include/megbrain/imperative/transformations/scalar.h
+6
-2
imperative/src/include/megbrain/imperative/transformations/symbol.h
.../src/include/megbrain/imperative/transformations/symbol.h
+9
-6
imperative/src/include/megbrain/imperative/transformations/trace.h
...e/src/include/megbrain/imperative/transformations/trace.h
+46
-22
imperative/src/include/megbrain/imperative/utils/allocator.h
imperative/src/include/megbrain/imperative/utils/allocator.h
+2
-0
imperative/src/include/megbrain/imperative/utils/span.h
imperative/src/include/megbrain/imperative/utils/span.h
+1
-1
imperative/src/include/megbrain/imperative/utils/stats.h
imperative/src/include/megbrain/imperative/utils/stats.h
+83
-24
imperative/src/include/megbrain/imperative/value.h
imperative/src/include/megbrain/imperative/value.h
+177
-117
imperative/src/test/backward_graph.cpp
imperative/src/test/backward_graph.cpp
+5
-5
未找到文件。
imperative/python/src/grad.cpp
浏览文件 @
177001d5
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "range/v3/all.hpp"
#include "range/v3/all.hpp"
#include "./helper.h"
#include "./transformation.h"
#include "./transformation.h"
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
...
@@ -30,9 +31,7 @@ namespace {
...
@@ -30,9 +31,7 @@ namespace {
std
::
unordered_map
<
std
::
shared_ptr
<
GradKey
>
,
GradKeyWrapper
*>
grad_key_map
;
std
::
unordered_map
<
std
::
shared_ptr
<
GradKey
>
,
GradKeyWrapper
*>
grad_key_map
;
}
}
GradKeyWrapper
::
GradKeyWrapper
()
:
m_key
(
std
::
make_shared
<
GradKey
>
())
{
GradKeyWrapper
::
GradKeyWrapper
()
{}
grad_key_map
[
m_key
]
=
this
;
}
void
GradKeyWrapper
::
attach
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
void
GradKeyWrapper
::
attach
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
if
(
nargs
!=
2
)
{
if
(
nargs
!=
2
)
{
...
@@ -77,8 +76,8 @@ pybind11::function GradKeyWrapper::get_backward_closure(
...
@@ -77,8 +76,8 @@ pybind11::function GradKeyWrapper::get_backward_closure(
for
(
auto
&&
tensor
:
tensors
)
{
for
(
auto
&&
tensor
:
tensors
)
{
args
.
push_back
(
TensorWrapper
::
try_cast
(
tensor
.
ptr
())
->
m_tensor
->
data
());
args
.
push_back
(
TensorWrapper
::
try_cast
(
tensor
.
ptr
())
->
m_tensor
->
data
());
}
}
auto
closure
=
imperative
::
apply
(
GetBackwardColsure
(
self
->
m_key
),
args
)[
0
]
auto
closure
_value
=
imperative
::
apply
(
GetBackwardColsure
(
self
->
m_key
),
args
)[
0
];
.
as
<
FunctionValue
>
();
auto
closure
=
closure_value
.
as_ref
<
FunctionValue
>
();
auto
py_function
=
[
closure
](
std
::
vector
<
TensorWrapper
*>
tensors
)
{
auto
py_function
=
[
closure
](
std
::
vector
<
TensorWrapper
*>
tensors
)
{
std
::
vector
<
ValueRef
>
args
;
std
::
vector
<
ValueRef
>
args
;
for
(
auto
*
tw
:
tensors
)
{
for
(
auto
*
tw
:
tensors
)
{
...
@@ -90,11 +89,14 @@ pybind11::function GradKeyWrapper::get_backward_closure(
...
@@ -90,11 +89,14 @@ pybind11::function GradKeyWrapper::get_backward_closure(
}
}
PyObject
*
GradKeyWrapper
::
get_name
()
{
PyObject
*
GradKeyWrapper
::
get_name
()
{
return
py
::
cast
(
m_
key
->
name
()
).
release
().
ptr
();
return
py
::
cast
(
m_
name
).
release
().
ptr
();
}
}
void
GradKeyWrapper
::
set_name
(
py
::
handle
name
)
{
void
GradKeyWrapper
::
set_name
(
py
::
handle
name
)
{
m_key
->
name
(
py
::
cast
<
std
::
string
>
(
name
));
m_name
=
py
::
cast
<
std
::
string
>
(
name
);
if
(
m_key
)
{
m_key
->
name
(
m_name
);
}
}
}
PyObject
*
GradKeyWrapper
::
is_attached_to
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
PyObject
*
GradKeyWrapper
::
is_attached_to
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
...
@@ -115,7 +117,10 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
...
@@ -115,7 +117,10 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
}
}
void
GradKeyWrapper
::
enter
()
{
void
GradKeyWrapper
::
enter
()
{
m_transformation
=
std
::
make_shared
<
GradTransformation
>
(
m_key
);
m_transformation
=
std
::
make_shared
<
GradTransformation
>
();
m_key
=
m_transformation
->
key
();
m_key
->
name
(
m_name
);
grad_key_map
[
m_key
]
=
this
;
TransformationManager
::
get_instance
().
register_at
<
TransformationManager
::
Grad
>
(
TransformationManager
::
get_instance
().
register_at
<
TransformationManager
::
Grad
>
(
m_transformation
);
m_transformation
);
}
}
...
@@ -123,6 +128,8 @@ void GradKeyWrapper::enter() {
...
@@ -123,6 +128,8 @@ void GradKeyWrapper::enter() {
void
GradKeyWrapper
::
exit
()
{
void
GradKeyWrapper
::
exit
()
{
TransformationManager
::
get_instance
().
unregister
<
TransformationManager
::
Grad
>
(
TransformationManager
::
get_instance
().
unregister
<
TransformationManager
::
Grad
>
(
m_transformation
);
m_transformation
);
grad_key_map
.
erase
(
m_key
);
m_key
=
{};
m_transformation
.
reset
();
m_transformation
.
reset
();
}
}
...
@@ -138,8 +145,6 @@ GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
...
@@ -138,8 +145,6 @@ GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
return
grad_key_map
.
at
(
key
);
return
grad_key_map
.
at
(
key
);
}
}
GradKeyWrapper
::~
GradKeyWrapper
()
{
GradKeyWrapper
::~
GradKeyWrapper
()
{}
grad_key_map
.
erase
(
m_key
);
}
}
// namespace mgb::imperative::python
}
// namespace mgb::imperative::python
imperative/python/src/grad.h
浏览文件 @
177001d5
...
@@ -26,6 +26,7 @@ struct GradKeyWrapper : NonCopyableObj {
...
@@ -26,6 +26,7 @@ struct GradKeyWrapper : NonCopyableObj {
using
wrap_t
=
pyext17
::
wrap
<
GradKeyWrapper
>
;
using
wrap_t
=
pyext17
::
wrap
<
GradKeyWrapper
>
;
static
constexpr
auto
tp_name
=
pybind11
::
detail
::
_
(
"GradKey"
);
static
constexpr
auto
tp_name
=
pybind11
::
detail
::
_
(
"GradKey"
);
std
::
string
m_name
;
std
::
shared_ptr
<
GradKey
>
m_key
;
std
::
shared_ptr
<
GradKey
>
m_key
;
std
::
shared_ptr
<
GradTransformation
>
m_transformation
;
std
::
shared_ptr
<
GradTransformation
>
m_transformation
;
...
...
imperative/python/src/grad_override.cpp
浏览文件 @
177001d5
...
@@ -117,7 +117,7 @@ std::optional<ValueRefList> elemwise_grad_rule(
...
@@ -117,7 +117,7 @@ std::optional<ValueRefList> elemwise_grad_rule(
maker
.
backward
([
shapes
=
std
::
move
(
input_shapes
)](
Span
<
ValueRef
>
grads
)
{
maker
.
backward
([
shapes
=
std
::
move
(
input_shapes
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
ValueRef
grad
=
grads
[
0
];
ValueRefList
ret
(
2
);
SmallVector
<
ValueRef
>
ret
(
2
);
if
(
!
grad
)
{
if
(
!
grad
)
{
return
ret
;
return
ret
;
}
}
...
@@ -147,7 +147,7 @@ std::optional<ValueRefList> reshape_grad_rule(
...
@@ -147,7 +147,7 @@ std::optional<ValueRefList> reshape_grad_rule(
maker
.
backward
([
shapes
=
std
::
move
(
input_shapes
)](
Span
<
ValueRef
>
grads
)
{
maker
.
backward
([
shapes
=
std
::
move
(
input_shapes
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
ValueRef
grad
=
grads
[
0
];
ValueRefList
ret
(
2
);
SmallVector
<
ValueRef
>
ret
(
2
);
if
(
!
grad
)
{
if
(
!
grad
)
{
return
ret
;
return
ret
;
}
}
...
@@ -180,7 +180,7 @@ std::optional<ValueRefList> subtensor_grad_rule(
...
@@ -180,7 +180,7 @@ std::optional<ValueRefList> subtensor_grad_rule(
grad_op_
=
std
::
move
(
grad_op
)](
Span
<
ValueRef
>
grads
)
{
grad_op_
=
std
::
move
(
grad_op
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
ValueRef
grad
=
grads
[
0
];
ValueRefList
ret
(
1
);
SmallVector
<
ValueRef
>
ret
(
1
);
if
(
grad
&&
inputs
[
0
])
{
if
(
grad
&&
inputs
[
0
])
{
ValueRefList
args_
(
inputs
.
size
()
+
1
);
ValueRefList
args_
(
inputs
.
size
()
+
1
);
auto
&&
zeros
=
make_empty_tensor
(
grad
.
device
(),
inputs
[
0
],
grad
.
dtype
());
auto
&&
zeros
=
make_empty_tensor
(
grad
.
device
(),
inputs
[
0
],
grad
.
dtype
());
...
@@ -215,7 +215,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
...
@@ -215,7 +215,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
grad_op_
=
std
::
move
(
grad_op
)](
Span
<
ValueRef
>
grads
)
{
grad_op_
=
std
::
move
(
grad_op
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
ValueRef
grad
=
grads
[
0
];
ValueRefList
ret
(
1
);
SmallVector
<
ValueRef
>
ret
(
1
);
if
(
grad
&&
inputs
[
0
])
{
if
(
grad
&&
inputs
[
0
])
{
ValueRefList
args_
(
inputs
.
size
()
+
1
);
ValueRefList
args_
(
inputs
.
size
()
+
1
);
auto
&&
zeros
=
make_empty_tensor
(
grad
.
device
(),
inputs
[
0
],
grad
.
dtype
());
auto
&&
zeros
=
make_empty_tensor
(
grad
.
device
(),
inputs
[
0
],
grad
.
dtype
());
...
@@ -251,7 +251,7 @@ std::optional<ValueRefList> reduce_grad_rule(
...
@@ -251,7 +251,7 @@ std::optional<ValueRefList> reduce_grad_rule(
maker
.
backward
([
shapes
=
std
::
move
(
input_shapes
)](
Span
<
ValueRef
>
grads
)
{
maker
.
backward
([
shapes
=
std
::
move
(
input_shapes
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
ValueRef
grad
=
grads
[
0
];
ValueRefList
ret
(
1
);
SmallVector
<
ValueRef
>
ret
(
1
);
if
(
grad
&&
shapes
[
0
])
{
if
(
grad
&&
shapes
[
0
])
{
ret
[
0
]
=
broadcast_to
(
grad
,
shapes
[
0
]);
ret
[
0
]
=
broadcast_to
(
grad
,
shapes
[
0
]);
}
}
...
@@ -274,7 +274,7 @@ std::optional<ValueRefList> addAxis_grad_rule(
...
@@ -274,7 +274,7 @@ std::optional<ValueRefList> addAxis_grad_rule(
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
),
flag_
=
flag
](
Span
<
ValueRef
>
grads
)
{
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
),
flag_
=
flag
](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
ValueRef
grad
=
grads
[
0
];
ValueRefList
ret
(
1
);
SmallVector
<
ValueRef
>
ret
(
1
);
if
(
grad
&&
flag_
)
{
if
(
grad
&&
flag_
)
{
ret
[
0
]
=
imperative
::
apply
(
*
grad_op_
,
grad
)[
0
];
ret
[
0
]
=
imperative
::
apply
(
*
grad_op_
,
grad
)[
0
];
}
}
...
@@ -297,7 +297,7 @@ std::optional<ValueRefList> removeAxis_grad_rule(
...
@@ -297,7 +297,7 @@ std::optional<ValueRefList> removeAxis_grad_rule(
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
),
flag_
=
flag
](
Span
<
ValueRef
>
grads
)
{
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
),
flag_
=
flag
](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
ValueRef
grad
=
grads
[
0
];
ValueRefList
ret
(
1
);
SmallVector
<
ValueRef
>
ret
(
1
);
if
(
grad
&&
flag_
)
{
if
(
grad
&&
flag_
)
{
ret
[
0
]
=
imperative
::
apply
(
*
grad_op_
,
grad
)[
0
];
ret
[
0
]
=
imperative
::
apply
(
*
grad_op_
,
grad
)[
0
];
}
}
...
@@ -316,7 +316,7 @@ std::optional<ValueRefList> fastpathcopy_grad_rule(
...
@@ -316,7 +316,7 @@ std::optional<ValueRefList> fastpathcopy_grad_rule(
maker
.
backward
([](
Span
<
ValueRef
>
grads
)
{
maker
.
backward
([](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
ValueRef
grad
=
grads
[
0
];
ValueRefList
ret
(
1
);
SmallVector
<
ValueRef
>
ret
(
1
);
if
(
grad
)
{
if
(
grad
)
{
ret
[
0
]
=
grad
;
ret
[
0
]
=
grad
;
}
}
...
...
imperative/python/src/tensor.cpp
浏览文件 @
177001d5
...
@@ -56,42 +56,44 @@ WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
...
@@ -56,42 +56,44 @@ WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
struct
SymbolVarContext
{
struct
SymbolVarContext
{
TransformationContext
context
;
TransformationContext
context
;
cg
::
ComputingGraph
*
graph
;
std
::
shared_ptr
<
SymbolTransformation
>
symbol_tsf
;
std
::
shared_ptr
<
ScalarTransformation
>
scalar_tsf
;
SymbolVarContext
(
cg
::
ComputingGraph
*
graph
)
:
graph
(
graph
)
{
SymbolVarContext
(
cg
::
ComputingGraph
*
graph
)
{
symbol_tsf
=
std
::
make_shared
<
SymbolTransformation
>
(
graph
);
scalar_tsf
=
std
::
make_shared
<
ScalarTransformation
>
();
Transformation
::
swap_context
(
context
);
Transformation
::
swap_context
(
context
);
}
}
void
init
()
{
void
init
()
{
std
::
make_shared
<
SymbolTransformation
>
(
graph
)
->
register_at
(
symbol_tsf
->
register_at
(
Transformation
::
top
());
Transformation
::
top
());
scalar_tsf
->
register_at
(
Transformation
::
top
());
std
::
make_shared
<
ScalarTransformation
>
()
->
register_at
(
Transformation
::
top
());
}
}
~
SymbolVarContext
()
{
Transformation
::
swap_context
(
context
);
}
ValueRef
symvar2val
(
py
::
handle
py_symbol_var
)
{
};
auto
*
symbol_var
=
py_symbol_var
.
cast
<
PySymbolVar
*>
();
ValueRef
value
=
symbol_tsf
->
value_type
().
make
(
symbol_var
->
m_node
);
if
(
symbol_var
->
is_scalar
)
{
value
=
scalar_tsf
->
value_type
().
make
(
value
);
}
return
value
;
}
ValueRef
symvar2val
(
py
::
handle
py_symbol_var
)
{
py
::
object
val2symvar
(
py
::
handle
typeobj
,
ValueRef
value
)
{
auto
*
symbol_var
=
py_symbol_var
.
cast
<
PySymbolVar
*>
();
bool
is_scalar
=
false
;
ValueRef
value
=
SymbolValue
::
make
(
symbol_var
->
m_node
);
if
(
auto
*
scalar_value
=
value
.
as
(
scalar_tsf
->
value_type
()))
{
if
(
symbol_var
->
is_scalar
)
{
value
=
scalar_value
->
value
();
value
=
ScalarValue
::
make
(
value
);
is_scalar
=
true
;
}
auto
*
node
=
value
.
cast
(
symbol_tsf
->
value_type
()).
node
();
auto
py_symbol_var
=
typeobj
(
pybind11
::
cast
(
node
,
pybind11
::
return_value_policy
::
automatic
));
py_symbol_var
.
cast
<
PySymbolVar
*>
()
->
is_scalar
=
is_scalar
;
return
py_symbol_var
;
}
}
return
value
;
}
py
::
object
val2symvar
(
py
::
handle
typeobj
,
ValueRef
value
)
{
~
SymbolVarContext
()
{
Transformation
::
swap_context
(
context
);
}
bool
is_scalar
=
false
;
};
if
(
auto
*
scalar_value
=
value
.
as
<
ScalarValue
>
())
{
value
=
scalar_value
->
value
();
is_scalar
=
true
;
}
auto
*
node
=
value
.
cast
<
SymbolValue
>
().
node
();
auto
py_symbol_var
=
typeobj
(
pybind11
::
cast
(
node
,
pybind11
::
return_value_policy
::
automatic
));
py_symbol_var
.
cast
<
PySymbolVar
*>
()
->
is_scalar
=
is_scalar
;
return
py_symbol_var
;
}
}
// namespace
}
// namespace
...
@@ -130,19 +132,21 @@ PyObject* py_apply(
...
@@ -130,19 +132,21 @@ PyObject* py_apply(
auto
op
=
py
::
handle
(
py_op
).
cast
<
std
::
shared_ptr
<
OpDef
>>
();
auto
op
=
py
::
handle
(
py_op
).
cast
<
std
::
shared_ptr
<
OpDef
>>
();
SmallVector
<
ValueRef
,
8
>
tensors
(
nargs
);
SmallVector
<
ValueRef
,
8
>
tensors
(
nargs
);
if
(
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
args
[
0
])))
{
bool
is_symbol_var
=
(
!
TensorWrapper
::
try_cast
(
args
[
0
]))
&&
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
args
[
0
]));
if
(
is_symbol_var
)
{
// swap to a special context to reuse scalar handle
// swap to a special context to reuse scalar handle
SymbolVarContext
context
(
SymbolVarContext
context
(
py
::
handle
(
args
[
0
]).
cast
<
PySymbolVar
*>
()
->
m_node
->
owner_graph
());
py
::
handle
(
args
[
0
]).
cast
<
PySymbolVar
*>
()
->
m_node
->
owner_graph
());
context
.
init
();
context
.
init
();
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
tensors
[
i
]
=
symvar2val
(
args
[
i
]);
tensors
[
i
]
=
context
.
symvar2val
(
args
[
i
]);
}
}
auto
outputs
=
imperative
::
apply
(
*
op
,
tensors
);
auto
outputs
=
imperative
::
apply
(
*
op
,
tensors
);
auto
ret
=
pybind11
::
tuple
(
outputs
.
size
());
auto
ret
=
pybind11
::
tuple
(
outputs
.
size
());
auto
typeobj
=
py
::
handle
(
args
[
0
]).
get_type
();
auto
typeobj
=
py
::
handle
(
args
[
0
]).
get_type
();
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
ret
[
i
]
=
val2symvar
(
typeobj
,
outputs
[
i
]);
ret
[
i
]
=
context
.
val2symvar
(
typeobj
,
outputs
[
i
]);
}
}
return
ret
.
release
().
ptr
();
return
ret
.
release
().
ptr
();
}
}
...
@@ -161,7 +165,7 @@ PyObject* py_apply(
...
@@ -161,7 +165,7 @@ PyObject* py_apply(
}
}
}
}
auto
outputs
=
imperative
::
apply
(
*
op
,
tensors
);
auto
outputs
=
[
&
]
{
return
imperative
::
apply
(
*
op
,
tensors
);
}(
);
size_t
nout
=
outputs
.
size
();
size_t
nout
=
outputs
.
size
();
auto
ret
=
py
::
tuple
(
nout
);
auto
ret
=
py
::
tuple
(
nout
);
for
(
size_t
i
=
0
;
i
<
nout
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nout
;
++
i
)
{
...
@@ -1573,9 +1577,9 @@ void init_tensor(py::module m) {
...
@@ -1573,9 +1577,9 @@ void init_tensor(py::module m) {
SymbolVarContext
context
(
graph
);
SymbolVarContext
context
(
graph
);
context
.
init
();
context
.
init
();
auto
output
=
reduce_to_scalar
(
auto
output
=
reduce_to_scalar
(
*
op
.
cast
<
std
::
shared_ptr
<
OpDef
>>
(),
symvar2val
(
tensor
));
*
op
.
cast
<
std
::
shared_ptr
<
OpDef
>>
(),
context
.
symvar2val
(
tensor
));
auto
typeobj
=
tensor
.
get_type
();
auto
typeobj
=
tensor
.
get_type
();
return
val2symvar
(
typeobj
,
output
);
return
context
.
val2symvar
(
typeobj
,
output
);
}
else
{
}
else
{
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
());
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
());
auto
output
=
reduce_to_scalar
(
auto
output
=
reduce_to_scalar
(
...
...
imperative/python/src/transformation.h
浏览文件 @
177001d5
...
@@ -67,10 +67,9 @@ struct TransformationManager {
...
@@ -67,10 +67,9 @@ struct TransformationManager {
}
}
};
};
class
PyValue
final
class
PyValue
final
:
public
PrimitiveValue
<
PyValue
,
pybind11
::
object
>
{
:
public
MixinValueImpl
<
PyValue
,
ValueKind
::
Object
,
pybind11
::
object
>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
{
std
::
string
to_string
()
const
{
return
pybind11
::
str
((
const
pybind11
::
object
&
)
*
this
).
cast
<
std
::
string
>
();
return
pybind11
::
str
((
const
pybind11
::
object
&
)
*
this
).
cast
<
std
::
string
>
();
...
...
imperative/src/impl/basic_operators.cpp
浏览文件 @
177001d5
...
@@ -63,7 +63,7 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
...
@@ -63,7 +63,7 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
MegBrainError
,
MegBrainError
,
"unknown input type, expects HostStorage or DeviceStorage, got "
"unknown input type, expects HostStorage or DeviceStorage, got "
"%s"
,
"%s"
,
input
.
name
()
->
c_str
());
input
.
to_string
().
c_str
());
}
}
}
}
mgb_assert
(
mgb_assert
(
...
...
imperative/src/impl/basic_values.cpp
浏览文件 @
177001d5
...
@@ -12,7 +12,7 @@ std::string CompNodeValue::to_string() const {
...
@@ -12,7 +12,7 @@ std::string CompNodeValue::to_string() const {
}
}
std
::
string
BoolValue
::
to_string
()
const
{
std
::
string
BoolValue
::
to_string
()
const
{
return
(
*
m_value
)
?
"true"
:
"false"
;
return
(
*
this
)
?
"true"
:
"false"
;
}
}
std
::
string
HostStorage
::
to_string
()
const
{
std
::
string
HostStorage
::
to_string
()
const
{
...
@@ -26,10 +26,10 @@ std::string DeviceStorage::to_string() const {
...
@@ -26,10 +26,10 @@ std::string DeviceStorage::to_string() const {
std
::
string
HostValue
::
to_string
()
const
{
std
::
string
HostValue
::
to_string
()
const
{
return
ssprintf
(
return
ssprintf
(
"HostValue{device=%s, dtype=%s, shape=%s}"
,
device
().
to_string
().
c_str
(),
"HostValue{device=%s, dtype=%s, shape=%s}"
,
device
().
to_string
().
c_str
(),
m_dtype
.
name
(),
m_shape
.
to_string
().
c_str
());
dtype
().
name
(),
shape
()
.
to_string
().
c_str
());
}
}
HostTensorND
Host
Value
::
as_nd
(
bool
allow_scalar
)
const
{
HostTensorND
Host
Tensor
::
as_nd
(
bool
allow_scalar
)
const
{
HostTensorND
nd
;
HostTensorND
nd
;
TensorShape
tensor_shape
;
TensorShape
tensor_shape
;
if
(
m_shape
.
is_scalar
())
{
if
(
m_shape
.
is_scalar
())
{
...
@@ -45,10 +45,10 @@ HostTensorND HostValue::as_nd(bool allow_scalar) const {
...
@@ -45,10 +45,10 @@ HostTensorND HostValue::as_nd(bool allow_scalar) const {
std
::
string
DeviceValue
::
to_string
()
const
{
std
::
string
DeviceValue
::
to_string
()
const
{
return
ssprintf
(
return
ssprintf
(
"DeviceValue{device=%s, dtype=%s, shape=%s}"
,
device
().
to_string
().
c_str
(),
"DeviceValue{device=%s, dtype=%s, shape=%s}"
,
device
().
to_string
().
c_str
(),
m_dtype
.
name
(),
m_shape
.
to_string
().
c_str
());
dtype
().
name
(),
shape
()
.
to_string
().
c_str
());
}
}
DeviceTensorND
Device
Value
::
as_nd
(
bool
allow_scalar
)
const
{
DeviceTensorND
Device
Tensor
::
as_nd
(
bool
allow_scalar
)
const
{
DeviceTensorND
nd
;
DeviceTensorND
nd
;
TensorShape
tensor_shape
;
TensorShape
tensor_shape
;
if
(
m_shape
.
is_scalar
())
{
if
(
m_shape
.
is_scalar
())
{
...
...
imperative/src/impl/dispatch.cpp
浏览文件 @
177001d5
...
@@ -19,46 +19,18 @@
...
@@ -19,46 +19,18 @@
namespace
mgb
{
namespace
mgb
{
namespace
imperative
{
namespace
imperative
{
namespace
{
MGB_NOINLINE
void
copy_outputs
(
ForwardAllocator
<
ValueRef
>&
allocator
,
ValueRefList
&
outputs
)
{
size_t
nr_outputs
=
outputs
.
size
();
if
(
mgb_likely
(
nr_outputs
==
1
))
{
ValueRef
output_copy
;
output_copy
=
outputs
[
0
];
allocator
.
clear
();
outputs
=
ValueRefList
({
output_copy
});
}
else
if
(
!
outputs
.
empty
())
{
SmallVector
<
ValueRef
>
outputs_copy
(
nr_outputs
);
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
outputs_copy
[
i
]
=
outputs
[
i
];
}
outputs
.
clear
();
allocator
.
clear
();
outputs
=
{
outputs_copy
.
begin
(),
outputs_copy
.
end
()};
}
else
{
allocator
.
clear
();
}
}
}
// namespace
ValueRefList
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
auto
&
context
=
Transformation
::
get_context
();
auto
&
context
=
Transformation
::
get_context
();
size_t
&
depth
=
context
.
next_transformation
;
size_t
&
depth
=
context
.
next_transformation
;
bool
top
=
depth
==
0
;
// TODO: add fallback transformation
auto
outputs
=
([
&
]
{
bool
fallback
=
depth
>=
context
.
transformations
.
size
();
if
(
mgb_unlikely
(
depth
>=
context
.
transformations
.
size
()))
{
if
(
mgb_unlikely
(
fallback
))
{
return
op
.
fallback
(
inputs
);
return
op
.
fallback
(
inputs
);
}
else
{
}
else
{
auto
&
transformation
=
*
context
.
transformations
[
depth
++
];
auto
&
transformation
=
*
context
.
transformations
[
depth
++
];
CleanupGuard
_
{[
&
]
{
--
depth
;
}};
CleanupGuard
_
{[
&
]
{
--
depth
;
}};
return
transformation
.
apply_transformation
(
op
,
inputs
);
return
transformation
.
apply_transformation
(
op
,
inputs
);
}
})();
if
(
mgb_unlikely
(
top
))
{
copy_outputs
(
context
.
allocator
,
outputs
);
}
}
return
outputs
;
}
}
ValueRefList
apply
(
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
apply
(
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
)
{
...
@@ -66,12 +38,7 @@ ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) {
...
@@ -66,12 +38,7 @@ ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) {
}
}
ValueRefList
apply
(
const
Subgraph
&
graph
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
apply
(
const
Subgraph
&
graph
,
Span
<
ValueRef
>
inputs
)
{
SmallVector
<
ValueRef
>
inputs_storage
;
auto
apply_functor
=
[](
std
::
shared_ptr
<
OpDef
>
op
,
Span
<
ValueRef
>
inputs
,
size_t
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
inputs_storage
.
push_back
(
inputs
[
i
]);
}
auto
apply_functor
=
[](
std
::
shared_ptr
<
OpDef
>
op
,
SmallVector
<
ValueRef
>
inputs
,
size_t
)
{
auto
outputs
=
imperative
::
apply
(
*
op
,
inputs
);
auto
outputs
=
imperative
::
apply
(
*
op
,
inputs
);
return
SmallVector
<
ValueRef
>
(
outputs
.
begin
(),
outputs
.
end
());
return
SmallVector
<
ValueRef
>
(
outputs
.
begin
(),
outputs
.
end
());
};
};
...
@@ -93,7 +60,7 @@ ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) {
...
@@ -93,7 +60,7 @@ ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) {
HostStorage
::
make
(
host_value
.
storage
()),
HostStorage
::
make
(
host_value
.
storage
()),
DeviceStorage
::
make
(
device_value
.
storage
()))[
0
];
DeviceStorage
::
make
(
device_value
.
storage
()))[
0
];
};
};
auto
outputs
=
graph
.
apply
(
inputs
_storage
,
apply_functor
,
make_const
);
auto
outputs
=
graph
.
apply
(
inputs
,
apply_functor
,
make_const
);
return
ValueRefList
{
outputs
.
begin
(),
outputs
.
end
()};
return
ValueRefList
{
outputs
.
begin
(),
outputs
.
end
()};
}
}
...
...
imperative/src/impl/interpreter/interpreter_impl.cpp
浏览文件 @
177001d5
...
@@ -331,6 +331,7 @@ void ChannelImpl::dispatch_kernel(
...
@@ -331,6 +331,7 @@ void ChannelImpl::dispatch_kernel(
cmd
.
inputs
=
std
::
move
(
input_infos
);
cmd
.
inputs
=
std
::
move
(
input_infos
);
cmd
.
outputs
.
reserve
(
output_descs
.
size
());
cmd
.
outputs
.
reserve
(
output_descs
.
size
());
outputs
->
reserve
(
output_descs
.
size
());
outputs
->
reserve
(
output_descs
.
size
());
for
(
int
i
=
0
;
i
<
output_descs
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
output_descs
.
size
();
++
i
)
{
auto
&&
desc
=
output_descs
[
i
];
auto
&&
desc
=
output_descs
[
i
];
auto
info
=
alloc
();
auto
info
=
alloc
();
...
@@ -730,7 +731,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
...
@@ -730,7 +731,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
input_descs
.
push_back
({{{},
input
->
dtype
()},
input
->
comp_node
()});
input_descs
.
push_back
({{{},
input
->
dtype
()},
input
->
comp_node
()});
}
}
auto
forward_graph
=
OpDef
::
make_forward_graph
(
def
,
input_descs
);
auto
forward_graph
=
OpDef
::
make_forward_graph
(
def
,
input_descs
);
auto
outputs
=
forward_graph
.
apply
(
inputs
,
apply_functor
,
const_functor
);
auto
outputs
=
forward_graph
.
apply
<
TensorPtr
>
(
inputs
,
apply_functor
,
const_functor
);
return
outputs
;
return
outputs
;
}
}
return
OpDef
::
apply_on_physical_tensor
(
def
,
inputs
);
return
OpDef
::
apply_on_physical_tensor
(
def
,
inputs
);
...
...
imperative/src/impl/ops/elemwise.cpp
浏览文件 @
177001d5
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/utility.h"
...
@@ -101,7 +102,7 @@ void apply_on_device_tensornd(
...
@@ -101,7 +102,7 @@ void apply_on_device_tensornd(
const
OpDef
&
def
,
const
SmallVector
<
DeviceTensorND
>&
inputs
,
const
OpDef
&
def
,
const
SmallVector
<
DeviceTensorND
>&
inputs
,
SmallVector
<
DeviceTensorND
>*
outputs
)
{
SmallVector
<
DeviceTensorND
>*
outputs
)
{
auto
&&
op_def
=
def
.
cast_final_safe
<
Elemwise
>
();
auto
&&
op_def
=
def
.
cast_final_safe
<
Elemwise
>
();
auto
trait
=
megdnn
::
Elemwise
::
ModeTrait
::
from_mode
(
op_def
.
mode
);
auto
&&
trait
=
megdnn
::
Elemwise
::
ModeTrait
::
from_mode
(
op_def
.
mode
);
mgb_assert
(
mgb_assert
(
inputs
.
size
()
==
trait
.
arity
,
"%s expects %u inputs; got %zu actually"
,
inputs
.
size
()
==
trait
.
arity
,
"%s expects %u inputs; got %zu actually"
,
trait
.
name
,
trait
.
arity
,
inputs
.
size
());
trait
.
name
,
trait
.
arity
,
inputs
.
size
());
...
...
imperative/src/impl/subgraph_detail.cpp
浏览文件 @
177001d5
...
@@ -36,7 +36,7 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
...
@@ -36,7 +36,7 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
.
node
();
.
node
();
};
};
auto
subgraph
=
def
.
trait
()
->
make_forward_graph
(
def
,
input_descs
);
auto
subgraph
=
def
.
trait
()
->
make_forward_graph
(
def
,
input_descs
);
auto
outputs
=
subgraph
.
apply
(
inputs
,
apply_functor
,
const_functor
);
auto
outputs
=
subgraph
.
apply
<
VarNode
*>
(
inputs
,
apply_functor
,
const_functor
);
return
outputs
;
return
outputs
;
}
}
...
@@ -56,7 +56,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
...
@@ -56,7 +56,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
value
->
layout
(),
value
->
comp_node
(),
value
->
layout
(),
value
->
comp_node
(),
value
->
get_value
().
proxy_to_default_cpu
()};
value
->
get_value
().
proxy_to_default_cpu
()};
};
};
auto
outputs
=
subgraph
.
apply
(
inputs
,
apply_functor
,
const_functor
);
auto
outputs
=
subgraph
.
apply
<
LogicalTensorDesc
>
(
inputs
,
apply_functor
,
const_functor
);
return
{
outputs
,
all_validated
};
return
{
outputs
,
all_validated
};
}
}
...
@@ -72,7 +73,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
...
@@ -72,7 +73,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return
OpDef
::
apply_on_physical_tensor
(
*
op
,
inputs
);
return
OpDef
::
apply_on_physical_tensor
(
*
op
,
inputs
);
};
};
auto
const_functor
=
[
&
](
const
TensorPtr
&
value
)
{
return
value
;
};
auto
const_functor
=
[
&
](
const
TensorPtr
&
value
)
{
return
value
;
};
auto
outputs
=
subgraph
.
apply
(
inputs
,
apply_functor
,
const_functor
);
auto
outputs
=
subgraph
.
apply
<
TensorPtr
>
(
inputs
,
apply_functor
,
const_functor
);
return
outputs
;
return
outputs
;
}
}
...
@@ -94,7 +95,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
...
@@ -94,7 +95,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
};
};
GradContext
<
var_t
>
grad_context
{
accum_grad
};
GradContext
<
var_t
>
grad_context
{
accum_grad
};
auto
input_vars
=
builder
.
write_inputs
(
inputs
);
auto
input_vars
=
builder
.
write_inputs
(
inputs
);
auto
outputs
=
forward_graph
.
apply
(
auto
outputs
=
forward_graph
.
apply
<
var_t
>
(
input_vars
,
std
::
bind
(
&
decltype
(
builder
)
::
write_expr
,
&
builder
,
_1
,
_2
,
_3
),
input_vars
,
std
::
bind
(
&
decltype
(
builder
)
::
write_expr
,
&
builder
,
_1
,
_2
,
_3
),
[
&
](
TensorPtr
constant
)
{
[
&
](
TensorPtr
constant
)
{
return
builder
.
write_constant
(
return
builder
.
write_constant
(
...
@@ -102,7 +103,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
...
@@ -102,7 +103,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
});
});
size_t
nr_outputs
=
outputs
.
size
();
size_t
nr_outputs
=
outputs
.
size
();
auto
apply_mask
=
[](
auto
&&
values
,
SmallVector
<
bool
>
mask
)
{
auto
apply_mask
=
[](
auto
&&
values
,
SmallVector
<
bool
>
mask
)
{
mgb_assert
(
mask
.
size
()
==
values
.
size
()
,
""
);
mgb_assert
(
mask
.
size
()
==
values
.
size
());
std
::
decay_t
<
decltype
(
values
)
>
results
;
std
::
decay_t
<
decltype
(
values
)
>
results
;
for
(
size_t
i
=
0
;
i
<
mask
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
mask
.
size
();
++
i
)
{
if
(
mask
[
i
])
{
if
(
mask
[
i
])
{
...
@@ -143,7 +144,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
...
@@ -143,7 +144,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
return
builder
.
write_constant
(
return
builder
.
write_constant
(
constant
,
{
constant
->
layout
(),
constant
->
comp_node
()});
constant
,
{
constant
->
layout
(),
constant
->
comp_node
()});
};
};
return
bg
.
apply
(
grad_inputs
,
apply_functor
,
const_functor
);
return
bg
.
apply
<
var_t
>
(
grad_inputs
,
apply_functor
,
const_functor
);
});
});
builder
.
add_outputs
(
grad_context
.
get_grads
(
input_vars
));
builder
.
add_outputs
(
grad_context
.
get_grads
(
input_vars
));
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
...
...
imperative/src/impl/transformations/eval.cpp
浏览文件 @
177001d5
...
@@ -10,20 +10,19 @@
...
@@ -10,20 +10,19 @@
*/
*/
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/imperative/utils/stats.h"
namespace
mgb
{
namespace
mgb
{
namespace
imperative
{
namespace
imperative
{
DTypeValue
::
ref_t
Interpreter
Info
::
dtype
()
const
{
DTypeValue
::
ref_t
Interpreter
Value
::
dtype
()
const
{
if
(
!
m_dtype
)
{
if
(
!
m_dtype
)
{
m_dtype
=
DTypeValue
::
make
(
handle
()
->
channel
()
->
get_dtype
(
handle
()
->
handle
()));
m_dtype
=
DTypeValue
::
make
(
handle
()
->
channel
()
->
get_dtype
(
handle
()
->
handle
()));
}
}
return
m_dtype
;
return
m_dtype
;
}
}
CompNodeValue
::
ref_t
Interpreter
Info
::
comp_node
()
const
{
CompNodeValue
::
ref_t
Interpreter
Value
::
comp_node
()
const
{
if
(
!
m_comp_node
)
{
if
(
!
m_comp_node
)
{
m_comp_node
=
CompNodeValue
::
make
(
m_comp_node
=
CompNodeValue
::
make
(
handle
()
->
channel
()
->
get_device
(
handle
()
->
handle
()));
handle
()
->
channel
()
->
get_device
(
handle
()
->
handle
()));
...
@@ -31,7 +30,7 @@ CompNodeValue::ref_t InterpreterInfo::comp_node() const {
...
@@ -31,7 +30,7 @@ CompNodeValue::ref_t InterpreterInfo::comp_node() const {
return
m_comp_node
;
return
m_comp_node
;
}
}
ShapeValue
::
ref_t
Interpreter
Info
::
shape
()
const
{
ShapeValue
::
ref_t
Interpreter
Value
::
shape
()
const
{
if
(
!
m_shape
)
{
if
(
!
m_shape
)
{
m_shape
=
ShapeValue
::
make
(
m_shape
=
ShapeValue
::
make
(
ValueShape
::
from
(
handle
()
->
channel
()
->
get_shape
(
handle
()
->
handle
())));
ValueShape
::
from
(
handle
()
->
channel
()
->
get_shape
(
handle
()
->
handle
())));
...
@@ -51,21 +50,22 @@ ValueRefList InterpreterTransformation::apply_op(
...
@@ -51,21 +50,22 @@ ValueRefList InterpreterTransformation::apply_op(
}
}
}};
}};
for
(
auto
input
:
inputs
)
{
for
(
auto
input
:
inputs
)
{
input_handles
.
push_back
(
input
.
cast
<
InterpreterValue
>
(
).
handle
()
->
handle
());
input_handles
.
push_back
(
input
.
cast
(
m_value_type
).
handle
()
->
handle
());
}
}
output_handles
=
output_handles
=
m_channel
->
apply_op
(
apply_op
.
op
().
shared_from_this
(),
input_handles
);
m_channel
->
apply_op
(
apply_op
.
op
().
shared_from_this
(),
input_handles
);
ValueRefList
outputs
(
output_handles
.
size
());
ValueRefList
outputs
(
output_handles
.
size
());
for
(
size_t
i
=
0
;
i
<
output_handles
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
output_handles
.
size
();
++
i
)
{
outputs
[
i
]
=
InterpreterValue
::
make
(
share_handle
(
output_handles
[
i
]));
outputs
[
i
]
=
m_value_type
.
make
(
share_handle
(
output_handles
[
i
]));
output_handles
[
i
]
=
nullptr
;
output_handles
[
i
]
=
nullptr
;
}
}
output_handles
.
clear
();
return
outputs
;
return
outputs
;
}
}
ValueRefList
InterpreterTransformation
::
apply_get_attr
(
ValueRefList
InterpreterTransformation
::
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
)
{
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
)
{
auto
&
input
=
inputs
.
item
().
cast
<
InterpreterValue
>
(
);
auto
&
input
=
inputs
.
item
().
cast
(
m_value_type
);
ValueRef
output
;
ValueRef
output
;
switch
(
get_attr
.
attr
())
{
switch
(
get_attr
.
attr
())
{
case
GetAttr
::
DType
:
case
GetAttr
::
DType
:
...
@@ -98,10 +98,10 @@ ValueRefList InterpreterTransformation::apply_create_tensor(
...
@@ -98,10 +98,10 @@ ValueRefList InterpreterTransformation::apply_create_tensor(
if
(
!
args
.
device
)
{
if
(
!
args
.
device
)
{
// implies H2D
// implies H2D
mgb_assert
(
args
.
host
,
"neither host and device value is valid"
);
mgb_assert
(
args
.
host
,
"neither host and device value is valid"
);
return
{
InterpreterValue
::
make
(
share_handle
(
return
{
m_value_type
.
make
(
share_handle
(
m_channel
->
put
(
*
args
.
host
,
args
.
kind
==
CreateTensor
::
Unique
)))};
m_channel
->
put
(
*
args
.
host
,
args
.
kind
==
CreateTensor
::
Unique
)))};
}
else
{
}
else
{
return
{
InterpreterValue
::
make
(
share_handle
(
m_channel
->
put
(
return
{
m_value_type
.
make
(
share_handle
(
m_channel
->
put
(
*
args
.
device
,
args
.
host
?
*
args
.
host
:
HostTensorND
())))};
*
args
.
device
,
args
.
host
?
*
args
.
host
:
HostTensorND
())))};
}
}
}
}
...
@@ -119,7 +119,7 @@ ValueRefList InterpreterTransformation::apply_transformation(
...
@@ -119,7 +119,7 @@ ValueRefList InterpreterTransformation::apply_transformation(
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
return
apply_create_tensor
(
*
create_tensor
,
inputs
);
return
apply_create_tensor
(
*
create_tensor
,
inputs
);
}
else
if
(
auto
*
dtr_command
=
op
.
as
<
DTRCommand
>
())
{
}
else
if
(
auto
*
dtr_command
=
op
.
as
<
DTRCommand
>
())
{
auto
handle
=
inputs
[
0
].
cast
<
InterpreterValue
>
(
).
handle
()
->
handle
();
auto
handle
=
inputs
[
0
].
cast
(
m_value_type
).
handle
()
->
handle
();
switch
(
dtr_command
->
kind
())
{
switch
(
dtr_command
->
kind
())
{
case
DTRCommand
::
Drop
:
case
DTRCommand
::
Drop
:
m_channel
->
drop
(
handle
);
m_channel
->
drop
(
handle
);
...
@@ -129,10 +129,10 @@ ValueRefList InterpreterTransformation::apply_transformation(
...
@@ -129,10 +129,10 @@ ValueRefList InterpreterTransformation::apply_transformation(
}
}
return
{};
return
{};
}
else
if
(
auto
*
rename_value
=
op
.
as
<
RenameValue
>
())
{
}
else
if
(
auto
*
rename_value
=
op
.
as
<
RenameValue
>
())
{
auto
&
input
=
inputs
[
0
].
cast
<
InterpreterValue
>
(
);
auto
&
input
=
inputs
[
0
].
cast
(
m_value_type
);
return
{
InterpreterValue
::
make
(
input
.
handle
(),
rename_value
->
name
())};
return
{
m_value_type
.
make
(
input
.
handle
(),
rename_value
->
name
())};
}
else
if
(
op
.
is
<
GetName
>
())
{
}
else
if
(
op
.
is
<
GetName
>
())
{
auto
name
=
inputs
[
0
].
cast
<
InterpreterValue
>
(
).
name
();
auto
name
=
inputs
[
0
].
cast
(
m_value_type
).
name
();
if
(
!
name
.
empty
())
{
if
(
!
name
.
empty
())
{
return
{
StringValue
::
make
(
name
)};
return
{
StringValue
::
make
(
name
)};
}
else
{
}
else
{
...
...
imperative/src/impl/transformations/grad.cpp
浏览文件 @
177001d5
...
@@ -68,7 +68,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
...
@@ -68,7 +68,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
size_t
count
=
std
::
count_if
(
size_t
count
=
std
::
count_if
(
save_for_backward
.
begin
(),
save_for_backward
.
end
(),
ranges
::
identity
{});
save_for_backward
.
begin
(),
save_for_backward
.
end
(),
ranges
::
identity
{});
if
(
!
backward_graph
->
precomp
.
empty
())
{
if
(
!
backward_graph
->
precomp
.
empty
())
{
ValueRefList
inputs_and_outputs
(
inputs
.
size
()
+
outputs
.
size
());
SmallVector
<
ValueRef
>
inputs_and_outputs
(
inputs
.
size
()
+
outputs
.
size
());
auto
it
=
inputs_and_outputs
.
begin
();
auto
it
=
inputs_and_outputs
.
begin
();
for
(
auto
&&
input
:
inputs
)
{
for
(
auto
&&
input
:
inputs
)
{
*
it
++
=
input
;
*
it
++
=
input
;
...
@@ -94,7 +94,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
...
@@ -94,7 +94,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
}
}
}
}
void
BackwardGraphWithClosure
::
operator
()(
void
BackwardGraphWithClosure
::
operator
()(
ValueRefList
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
)
{
Span
<
ValueRef
>
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
)
{
ValueRef
args
[
closure
.
size
()
+
grads
.
size
()];
ValueRef
args
[
closure
.
size
()
+
grads
.
size
()];
size_t
nargs
=
0
;
size_t
nargs
=
0
;
for
(
auto
&&
value
:
closure
)
{
for
(
auto
&&
value
:
closure
)
{
...
@@ -114,7 +114,9 @@ void BackwardGraphWithClosure::operator()(
...
@@ -114,7 +114,9 @@ void BackwardGraphWithClosure::operator()(
if
(
null_grad
)
{
if
(
null_grad
)
{
return
;
return
;
}
}
auto
igrads
=
imperative
::
apply
(
backward_graph
->
backward
,
Span
(
args
,
nargs
));
auto
igrads_
=
imperative
::
apply
(
backward_graph
->
backward
,
Span
(
args
,
nargs
));
SmallVector
<
ValueRef
>
igrads
=
{
igrads_
.
begin
(),
igrads_
.
end
()};
igrads_
.
clear
();
auto
&&
iter
=
igrads
.
begin
();
auto
&&
iter
=
igrads
.
begin
();
for
(
auto
[
i
,
p
]
:
ranges
::
views
::
enumerate
(
backward_graph
->
input_has_grad
))
{
for
(
auto
[
i
,
p
]
:
ranges
::
views
::
enumerate
(
backward_graph
->
input_has_grad
))
{
if
(
p
)
{
if
(
p
)
{
...
@@ -125,7 +127,7 @@ void BackwardGraphWithClosure::operator()(
...
@@ -125,7 +127,7 @@ void BackwardGraphWithClosure::operator()(
}
}
void
CustomBackward
::
operator
()(
void
CustomBackward
::
operator
()(
ValueRefList
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
)
{
Span
<
ValueRef
>
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
)
{
size_t
nargs
=
grads
.
size
();
size_t
nargs
=
grads
.
size
();
ValueRef
args
[
nargs
];
ValueRef
args
[
nargs
];
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
...
@@ -206,7 +208,7 @@ void GradKey::backward() {
...
@@ -206,7 +208,7 @@ void GradKey::backward() {
mgb_throw
(
AssertionError
,
"invalid backward"
);
mgb_throw
(
AssertionError
,
"invalid backward"
);
}
else
{
}
else
{
mgb_assert
(
grad_fn
->
m_slots
.
size
()
>
0
);
mgb_assert
(
grad_fn
->
m_slots
.
size
()
>
0
);
ValueRefList
grads
(
grad_fn
->
m_slots
.
size
());
SmallVector
<
ValueRef
>
grads
(
grad_fn
->
m_slots
.
size
());
auto
iter
=
grads
.
begin
();
auto
iter
=
grads
.
begin
();
for
(
auto
&&
slot
:
grad_fn
->
m_slots
)
{
for
(
auto
&&
slot
:
grad_fn
->
m_slots
)
{
*
iter
++
=
slot
.
m_grad
;
*
iter
++
=
slot
.
m_grad
;
...
@@ -231,11 +233,9 @@ void GradKey::backward() {
...
@@ -231,11 +233,9 @@ void GradKey::backward() {
GradValue
::
ref_t
GradKey
::
attach
(
GradValue
::
ref_t
GradKey
::
attach
(
ValueRef
tensor
,
std
::
function
<
void
(
ValueRef
)
>
callback
)
{
ValueRef
tensor
,
std
::
function
<
void
(
ValueRef
)
>
callback
)
{
auto
grad_value
=
tensor
.
as_ref
<
GradValue
>
();
auto
grad_value
=
tensor
.
as_ref
(
m_value_type
);
if
(
grad_value
&&
grad_value
->
has_key
(
shared_from_this
()))
{
if
(
grad_value
)
{
mgb_assert
(
mgb_assert
(
!
tensor
.
cast
(
m_value_type
).
slot
()
->
callback
,
"callback exists"
);
!
tensor
.
cast
<
GradValue
>
().
slot_for
(
shared_from_this
())
->
callback
,
"callback exists"
);
}
else
{
}
else
{
GradSlotPtr
grad_slot
;
GradSlotPtr
grad_slot
;
auto
&
grad_fn
=
grad_slot
.
m_fn
;
auto
&
grad_fn
=
grad_slot
.
m_fn
;
...
@@ -243,9 +243,9 @@ GradValue::ref_t GradKey::attach(
...
@@ -243,9 +243,9 @@ GradValue::ref_t GradKey::attach(
grad_fn
->
m_key
=
shared_from_this
();
grad_fn
->
m_key
=
shared_from_this
();
grad_fn
->
m_slots
.
resize
(
1
);
grad_fn
->
m_slots
.
resize
(
1
);
grad_slot
.
m_index
=
0
;
grad_slot
.
m_index
=
0
;
grad_value
=
GradValue
::
make
(
tensor
,
shared_from_this
(),
grad_slot
);
grad_value
=
m_value_type
.
make
(
tensor
,
shared_from_this
(),
grad_slot
);
}
}
grad_value
->
slot
_for
(
shared_from_this
()
).
m_fn
->
m_slots
[
0
].
callback
=
callback
;
grad_value
->
slot
(
).
m_fn
->
m_slots
[
0
].
callback
=
callback
;
return
grad_value
;
return
grad_value
;
}
}
...
@@ -263,7 +263,7 @@ void GradKey::freeze() {
...
@@ -263,7 +263,7 @@ void GradKey::freeze() {
ValueRefList
GradTransformation
::
apply_transformation
(
ValueRefList
GradTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
auto
fallback
=
[
&
]
{
auto
fallback
=
[
&
]
{
ValueRefList
unwrapped_inputs
(
inputs
.
size
());
SmallVector
<
ValueRef
>
unwrapped_inputs
(
inputs
.
size
());
{
{
// overhead
// overhead
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
...
@@ -367,7 +367,7 @@ ValueRefList GradTransformation::apply_transformation(
...
@@ -367,7 +367,7 @@ ValueRefList GradTransformation::apply_transformation(
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
backward
.
input_has_grad
(
i
)
&&
require_grads
[
i
])
{
if
(
backward
.
input_has_grad
(
i
)
&&
require_grads
[
i
])
{
auto
&
input_grad_slot
=
auto
&
input_grad_slot
=
inputs
[
i
].
cast
<
GradValue
>
().
slot_for
(
m_key
);
inputs
[
i
].
cast
(
m_value_type
).
slot
(
);
grad_fn
->
m_dests
.
emplace_back
(
input_grad_slot
);
grad_fn
->
m_dests
.
emplace_back
(
input_grad_slot
);
grad_fn
->
m_dests
.
back
().
m_producer_record
.
insert_after
(
grad_fn
->
m_dests
.
back
().
m_producer_record
.
insert_after
(
input_grad_slot
->
m_producer_head
);
input_grad_slot
->
m_producer_head
);
...
@@ -378,7 +378,7 @@ ValueRefList GradTransformation::apply_transformation(
...
@@ -378,7 +378,7 @@ ValueRefList GradTransformation::apply_transformation(
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
if
(
backward
.
output_requires_grad
(
i
))
{
if
(
backward
.
output_requires_grad
(
i
))
{
// little overhead: Value::make
// little overhead: Value::make
auto
grad_value
=
GradValue
::
make
(
outputs
[
i
],
m_key
,
GradSlotPtr
{
grad_fn
,
i
});
auto
grad_value
=
m_value_type
.
make
(
outputs
[
i
],
m_key
,
GradSlotPtr
{
grad_fn
,
i
});
outputs
[
i
]
=
record_grad
(
grad_value
);
outputs
[
i
]
=
record_grad
(
grad_value
);
}
}
}
}
...
@@ -435,7 +435,10 @@ ValueRefList GradTransformation::apply_transformation(
...
@@ -435,7 +435,10 @@ ValueRefList GradTransformation::apply_transformation(
backward
.
m_input_has_grad
=
SmallVector
(
nr_inputs
,
true
);
backward
.
m_input_has_grad
=
SmallVector
(
nr_inputs
,
true
);
backward
.
m_output_attrs
=
backward
.
m_output_attrs
=
SmallVector
(
nr_outputs
,
CustomBackward
::
OutputAttr
{
true
,
true
});
SmallVector
(
nr_outputs
,
CustomBackward
::
OutputAttr
{
true
,
true
});
backward
.
m_backward
=
set_grad
->
grad_fn
();
backward
.
m_backward
=
[
fn
=
set_grad
->
grad_fn
()](
Span
<
ValueRef
>
inputs
)
{
auto
result
=
fn
(
inputs
);
return
SmallVector
<
ValueRef
>
(
result
.
begin
(),
result
.
end
());
};
ValueRefList
outputs
(
nr_outputs
);
ValueRefList
outputs
(
nr_outputs
);
grad_fn
->
m_key
=
m_key
;
grad_fn
->
m_key
=
m_key
;
grad_fn
->
m_slots
.
resize
(
nr_outputs
);
grad_fn
->
m_slots
.
resize
(
nr_outputs
);
...
@@ -454,10 +457,10 @@ ValueRefList GradTransformation::apply_transformation(
...
@@ -454,10 +457,10 @@ ValueRefList GradTransformation::apply_transformation(
auto
&
output
=
outputs_
[
i
];
auto
&
output
=
outputs_
[
i
];
auto
grad_value
=
as_grad_value
(
output
);
auto
grad_value
=
as_grad_value
(
output
);
if
(
grad_value
)
{
if
(
grad_value
)
{
grad_value
=
GradValue
::
make
(
grad_value
=
m_value_type
.
make
(
grad_value
->
m_value
,
m_key
,
GradSlotPtr
(
grad_fn
,
i
));
grad_value
->
m_value
,
m_key
,
GradSlotPtr
(
grad_fn
,
i
));
}
else
{
}
else
{
grad_value
=
GradValue
::
make
(
output
,
m_key
,
GradSlotPtr
(
grad_fn
,
i
));
grad_value
=
m_value_type
.
make
(
output
,
m_key
,
GradSlotPtr
(
grad_fn
,
i
));
}
}
outputs
[
i
]
=
record_grad
(
grad_value
);
outputs
[
i
]
=
record_grad
(
grad_value
);
}
}
...
@@ -485,8 +488,7 @@ ValueRefList GradTransformation::apply_transformation(
...
@@ -485,8 +488,7 @@ ValueRefList GradTransformation::apply_transformation(
mgb_assert
(
inputs
.
size
()
==
1
);
mgb_assert
(
inputs
.
size
()
==
1
);
if
(
auto
&&
grad_value
=
as_grad_value
(
inputs
[
0
]))
{
if
(
auto
&&
grad_value
=
as_grad_value
(
inputs
[
0
]))
{
auto
output
=
imperative
::
apply
(
op
,
grad_value
->
m_value
)[
0
];
auto
output
=
imperative
::
apply
(
op
,
grad_value
->
m_value
)[
0
];
auto
grad_output
=
GradValue
::
make
(
auto
grad_output
=
m_value_type
.
make
(
output
,
m_key
,
grad_value
->
slot
());
output
,
grad_value
->
key
(),
grad_value
->
slot_for
(
m_key
));
return
{
record_grad
(
grad_output
)};
return
{
record_grad
(
grad_output
)};
}
else
{
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
return
imperative
::
apply
(
op
,
inputs
);
...
@@ -502,7 +504,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
...
@@ -502,7 +504,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
std
::
vector
<
GradSlotPtr
>
y_slots
;
std
::
vector
<
GradSlotPtr
>
y_slots
;
for
(
auto
&&
y
:
ys
)
{
for
(
auto
&&
y
:
ys
)
{
if
(
auto
&&
grad_value
=
as_grad_value
(
y
))
{
if
(
auto
&&
grad_value
=
as_grad_value
(
y
))
{
y_slots
.
push_back
(
grad_value
->
slot
_for
(
grad_key
));
y_slots
.
push_back
(
grad_value
->
slot
(
));
}
else
{
}
else
{
y_slots
.
emplace_back
();
y_slots
.
emplace_back
();
}
}
...
...
imperative/src/impl/transformations/lazy.cpp
浏览文件 @
177001d5
...
@@ -32,7 +32,7 @@ ValueRefList LazyEvalTransformation::apply_transformation(
...
@@ -32,7 +32,7 @@ ValueRefList LazyEvalTransformation::apply_transformation(
bool
require_link
=
mm_io_ops
.
count
(
op_val
->
op
().
dyn_typeinfo
());
bool
require_link
=
mm_io_ops
.
count
(
op_val
->
op
().
dyn_typeinfo
());
VarNodeArray
input_nodes
;
VarNodeArray
input_nodes
;
for
(
auto
&&
input
:
inputs
)
{
for
(
auto
&&
input
:
inputs
)
{
if
(
auto
*
input_node
=
input
.
as
<
LazyEvalValue
>
(
))
{
if
(
auto
*
input_node
=
input
.
as
(
m_value_type
))
{
input_nodes
.
push_back
(
input_node
->
node
());
input_nodes
.
push_back
(
input_node
->
node
());
}
else
{
}
else
{
// ImmutableTensor has empty shape issues
// ImmutableTensor has empty shape issues
...
@@ -112,7 +112,7 @@ ValueRefList LazyEvalTransformation::apply_transformation(
...
@@ -112,7 +112,7 @@ ValueRefList LazyEvalTransformation::apply_transformation(
return
{
record_var
(
node
)};
return
{
record_var
(
node
)};
}
}
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
if
(
auto
*
lazy_val
=
inputs
.
item
().
as
<
LazyEvalValue
>
(
))
{
if
(
auto
*
lazy_val
=
inputs
.
item
().
as
(
m_value_type
))
{
switch
(
get_attr
->
attr
())
{
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
DType
:
case
GetAttr
::
DType
:
return
{
DTypeValue
::
make
(
lazy_val
->
node
()
->
dtype
())};
return
{
DTypeValue
::
make
(
lazy_val
->
node
()
->
dtype
())};
...
@@ -167,14 +167,14 @@ ValueRefList LazyEvalTransformation::apply_transformation(
...
@@ -167,14 +167,14 @@ ValueRefList LazyEvalTransformation::apply_transformation(
return
imperative
::
apply
(
op
,
inputs
);
return
imperative
::
apply
(
op
,
inputs
);
}
}
}
else
if
(
auto
*
rename_value
=
op
.
as
<
RenameValue
>
())
{
}
else
if
(
auto
*
rename_value
=
op
.
as
<
RenameValue
>
())
{
if
(
auto
*
lazy_val
=
inputs
.
item
().
as
<
LazyEvalValue
>
(
))
{
if
(
auto
*
lazy_val
=
inputs
.
item
().
as
(
m_value_type
))
{
return
{
record_var
(
return
{
record_var
(
lazy_val
->
node
(),
lazy_val
->
bound_data
(),
rename_value
->
name
())};
lazy_val
->
node
(),
lazy_val
->
bound_data
(),
rename_value
->
name
())};
}
else
{
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
return
imperative
::
apply
(
op
,
inputs
);
}
}
}
else
if
(
op
.
is
<
GetName
>
())
{
}
else
if
(
op
.
is
<
GetName
>
())
{
if
(
auto
*
lazy_val
=
inputs
.
item
().
as
<
LazyEvalValue
>
(
))
{
if
(
auto
*
lazy_val
=
inputs
.
item
().
as
(
m_value_type
))
{
auto
name
=
lazy_val
->
name
();
auto
name
=
lazy_val
->
name
();
if
(
!
name
.
empty
())
{
if
(
!
name
.
empty
())
{
return
{
StringValue
::
make
(
lazy_val
->
name
())};
return
{
StringValue
::
make
(
lazy_val
->
name
())};
...
@@ -255,7 +255,7 @@ void LazyEvalTransformation::on_unregister() noexcept {
...
@@ -255,7 +255,7 @@ void LazyEvalTransformation::on_unregister() noexcept {
DeviceStorage
::
make
(
data
.
storage
()))[
0
]);
DeviceStorage
::
make
(
data
.
storage
()))[
0
]);
}
}
for
(
auto
&&
lazy_val
:
lazy_vals
)
{
for
(
auto
&&
lazy_val
:
lazy_vals
)
{
if
(
lazy_val
.
is
<
LazyEvalValue
>
(
))
{
if
(
lazy_val
.
is
(
m_value_type
))
{
std
::
string
repr
=
std
::
string
repr
=
ssprintf
(
"lazy eval failed for %s"
,
lazy_val
->
to_string
().
c_str
());
ssprintf
(
"lazy eval failed for %s"
,
lazy_val
->
to_string
().
c_str
());
mgb_log_debug
(
"%s"
,
repr
.
c_str
());
mgb_log_debug
(
"%s"
,
repr
.
c_str
());
...
...
imperative/src/impl/transformations/scalar.cpp
浏览文件 @
177001d5
...
@@ -20,7 +20,8 @@ namespace imperative {
...
@@ -20,7 +20,8 @@ namespace imperative {
namespace
{
namespace
{
using
ScalarRule
=
ValueRefList
(
*
)(
const
OpDef
&
,
Span
<
ValueRef
>
,
Span
<
bool
>
);
using
ScalarRule
=
ValueRefList
(
*
)(
const
OpDef
&
,
Span
<
ValueRef
>
,
Span
<
bool
>
,
const
Type
<
ScalarValue
>&
);
static
std
::
unordered_map
<
Typeinfo
*
,
ScalarRule
>
scalar_rules
;
static
std
::
unordered_map
<
Typeinfo
*
,
ScalarRule
>
scalar_rules
;
ValueRef
make_scalar_shape
(
CompNode
device
)
{
ValueRef
make_scalar_shape
(
CompNode
device
)
{
...
@@ -41,17 +42,22 @@ bool is_scalar_shape(ValueRef shape) {
...
@@ -41,17 +42,22 @@ bool is_scalar_shape(ValueRef shape) {
return
*
shape_of_shape
==
ValueShape
{
0
};
return
*
shape_of_shape
==
ValueShape
{
0
};
}
}
template
<
typename
T
,
ValueRefList
(
*
rule
)(
const
T
&
,
Span
<
ValueRef
>,
Span
<
bool
>
)
>
template
<
typename
T
,
ValueRefList
(
*
rule
)(
const
T
&
,
Span
<
ValueRef
>,
Span
<
bool
>
,
const
Type
<
ScalarValue
>&
)
>
void
register_scalar_rule
()
{
void
register_scalar_rule
()
{
scalar_rules
[
T
::
typeinfo
()]
=
[](
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
,
scalar_rules
[
T
::
typeinfo
()]
=
[](
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
Span
<
bool
>
inputs_mask
,
return
(
*
rule
)(
def
.
cast_final_safe
<
T
>
(),
inputs
,
inputs_mask
);
const
Type
<
ScalarValue
>&
value_type
)
{
return
(
*
rule
)(
def
.
cast_final_safe
<
T
>
(),
inputs
,
inputs_mask
,
value_type
);
};
};
}
}
template
<
typename
TOpDef
,
size_t
nr_inputs
>
template
<
typename
TOpDef
,
size_t
nr_inputs
>
ValueRefList
elemwise_rule
(
ValueRefList
elemwise_rule
(
const
TOpDef
&
op_def
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
const
TOpDef
&
op_def
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
if
constexpr
(
nr_inputs
!=
0
)
{
if
constexpr
(
nr_inputs
!=
0
)
{
mgb_assert
(
inputs
.
size
()
==
inputs
.
size
(),
"inputs size mismatch"
);
mgb_assert
(
inputs
.
size
()
==
inputs
.
size
(),
"inputs size mismatch"
);
}
}
...
@@ -63,27 +69,29 @@ ValueRefList elemwise_rule(
...
@@ -63,27 +69,29 @@ ValueRefList elemwise_rule(
}
}
auto
outputs
=
imperative
::
apply
(
op_def
,
inputs
);
auto
outputs
=
imperative
::
apply
(
op_def
,
inputs
);
if
(
all_scalar
)
{
if
(
all_scalar
)
{
outputs
[
0
]
=
ScalarValue
::
make
(
outputs
[
0
]);
outputs
[
0
]
=
scalar_type
.
make
(
outputs
[
0
]);
}
}
return
outputs
;
return
outputs
;
}
}
ValueRefList
remove_axis_rule
(
ValueRefList
remove_axis_rule
(
const
RemoveAxis
&
remove_axis
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
const
RemoveAxis
&
remove_axis
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
mgb_assert
(
!
inputs_mask
.
item
());
mgb_assert
(
!
inputs_mask
.
item
());
bool
is_scalar
=
inputs
.
item
().
shape
()
->
ndim
==
remove_axis
.
axis
.
size
();
bool
is_scalar
=
inputs
.
item
().
shape
()
->
ndim
==
remove_axis
.
axis
.
size
();
if
(
is_scalar
&&
remove_axis
.
axis
.
size
()
==
1
)
{
if
(
is_scalar
&&
remove_axis
.
axis
.
size
()
==
1
)
{
return
{
ScalarValue
::
make
(
inputs
.
item
())};
return
{
scalar_type
.
make
(
inputs
.
item
())};
}
}
auto
outputs
=
imperative
::
apply
(
remove_axis
,
inputs
);
auto
outputs
=
imperative
::
apply
(
remove_axis
,
inputs
);
if
(
is_scalar
)
{
if
(
is_scalar
)
{
outputs
[
0
]
=
ScalarValue
::
make
(
outputs
[
0
]);
outputs
[
0
]
=
scalar_type
.
make
(
outputs
[
0
]);
}
}
return
outputs
;
return
outputs
;
}
}
ValueRefList
reduce_rule
(
ValueRefList
reduce_rule
(
const
Reduce
&
reduce
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
const
Reduce
&
reduce
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
if
(
inputs
.
size
()
==
1
)
{
if
(
inputs
.
size
()
==
1
)
{
return
imperative
::
apply
(
reduce
,
inputs
);
return
imperative
::
apply
(
reduce
,
inputs
);
}
}
...
@@ -91,7 +99,7 @@ ValueRefList reduce_rule(
...
@@ -91,7 +99,7 @@ ValueRefList reduce_rule(
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
if
(
is_scalar
)
{
if
(
is_scalar
)
{
CompNode
device
=
*
inputs
[
0
].
device
();
CompNode
device
=
*
inputs
[
0
].
device
();
return
{
ScalarValue
::
make
(
return
{
scalar_type
.
make
(
imperative
::
apply
(
reduce
,
inputs
[
0
],
make_scalar_shape
(
device
))[
0
])};
imperative
::
apply
(
reduce
,
inputs
[
0
],
make_scalar_shape
(
device
))[
0
])};
}
}
return
imperative
::
apply
(
reduce
,
inputs
);
return
imperative
::
apply
(
reduce
,
inputs
);
...
@@ -99,7 +107,7 @@ ValueRefList reduce_rule(
...
@@ -99,7 +107,7 @@ ValueRefList reduce_rule(
ValueRefList
collective_comm_rule
(
ValueRefList
collective_comm_rule
(
const
CollectiveComm
&
collective_comm
,
Span
<
ValueRef
>
inputs
,
const
CollectiveComm
&
collective_comm
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
mgb_assert
(
inputs
.
size
()
==
1
);
static
std
::
unordered_set
<
CollectiveComm
::
Mode
>
modes
=
{
static
std
::
unordered_set
<
CollectiveComm
::
Mode
>
modes
=
{
CollectiveComm
::
Mode
::
ALL_REDUCE_MAX
,
CollectiveComm
::
Mode
::
ALL_REDUCE_MIN
,
CollectiveComm
::
Mode
::
ALL_REDUCE_MAX
,
CollectiveComm
::
Mode
::
ALL_REDUCE_MIN
,
...
@@ -110,7 +118,7 @@ ValueRefList collective_comm_rule(
...
@@ -110,7 +118,7 @@ ValueRefList collective_comm_rule(
return
imperative
::
apply
(
collective_comm
,
inputs
);
return
imperative
::
apply
(
collective_comm
,
inputs
);
}
}
if
(
inputs_mask
.
item
())
{
if
(
inputs_mask
.
item
())
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
collective_comm
,
inputs
[
0
])[
0
])};
return
{
scalar_type
.
make
(
imperative
::
apply
(
collective_comm
,
inputs
[
0
])[
0
])};
}
else
{
}
else
{
return
imperative
::
apply
(
collective_comm
,
inputs
);
return
imperative
::
apply
(
collective_comm
,
inputs
);
}
}
...
@@ -118,24 +126,27 @@ ValueRefList collective_comm_rule(
...
@@ -118,24 +126,27 @@ ValueRefList collective_comm_rule(
ValueRefList
param_pack_split_rule
(
ValueRefList
param_pack_split_rule
(
const
ParamPackSplit
&
param_pack_split
,
Span
<
ValueRef
>
inputs
,
const
ParamPackSplit
&
param_pack_split
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
auto
outputs
=
imperative
::
apply
(
param_pack_split
,
inputs
);
auto
outputs
=
imperative
::
apply
(
param_pack_split
,
inputs
);
size_t
nr_outputs
=
outputs
.
size
();
size_t
nr_outputs
=
outputs
.
size
();
mgb_assert
(
nr_outputs
==
param_pack_split
.
shapes
.
size
());
mgb_assert
(
nr_outputs
==
param_pack_split
.
shapes
.
size
());
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
if
(
param_pack_split
.
shapes
[
i
].
empty
())
{
if
(
param_pack_split
.
shapes
[
i
].
empty
())
{
outputs
[
i
]
=
ScalarValue
::
make
(
outputs
[
i
]);
outputs
[
i
]
=
scalar_type
.
make
(
outputs
[
i
]);
}
}
}
}
return
outputs
;
return
outputs
;
}
}
ValueRefList
dot_rule
(
const
Dot
&
dot
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
ValueRefList
dot_rule
(
return
{
ScalarValue
::
make
(
imperative
::
apply
(
dot
,
inputs
)[
0
])};
const
Dot
&
dot
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
return
{
scalar_type
.
make
(
imperative
::
apply
(
dot
,
inputs
)[
0
])};
}
}
ValueRefList
add_axis_rule
(
ValueRefList
add_axis_rule
(
const
AddAxis
&
add_axis
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
const
AddAxis
&
add_axis
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
mgb_assert
(
inputs
.
size
()
==
1
);
if
(
inputs_mask
.
item
())
{
if
(
inputs_mask
.
item
())
{
mgb_assert
(
add_axis
.
axis
[
0
]
==
0
);
mgb_assert
(
add_axis
.
axis
[
0
]
==
0
);
...
@@ -151,7 +162,8 @@ ValueRefList add_axis_rule(
...
@@ -151,7 +162,8 @@ ValueRefList add_axis_rule(
}
}
ValueRefList
remote_recv_rule
(
ValueRefList
remote_recv_rule
(
const
RemoteRecv
&
remote_recv
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
const
RemoteRecv
&
remote_recv
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
if
(
remote_recv
.
shape
.
empty
())
{
if
(
remote_recv
.
shape
.
empty
())
{
std
::
vector
<
int32_t
>
shape
=
{
1
};
std
::
vector
<
int32_t
>
shape
=
{
1
};
auto
remote_recv_no_scalar
=
RemoteRecv
::
make
(
auto
remote_recv_no_scalar
=
RemoteRecv
::
make
(
...
@@ -167,20 +179,21 @@ ValueRefList remote_recv_rule(
...
@@ -167,20 +179,21 @@ ValueRefList remote_recv_rule(
ValueRefList
check_no_finite_rule
(
ValueRefList
check_no_finite_rule
(
const
CheckNonFinite
&
check_no_finite
,
Span
<
ValueRef
>
inputs
,
const
CheckNonFinite
&
check_no_finite
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
auto
outputs
=
imperative
::
apply
(
check_no_finite
,
inputs
);
auto
outputs
=
imperative
::
apply
(
check_no_finite
,
inputs
);
mgb_assert
(
outputs
.
size
()
==
inputs
.
size
()
+
1
,
"output size mismatch"
);
mgb_assert
(
outputs
.
size
()
==
inputs
.
size
()
+
1
,
"output size mismatch"
);
outputs
.
back
()
=
ScalarValue
::
make
(
outputs
.
back
());
outputs
.
back
()
=
scalar_type
.
make
(
outputs
.
back
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
inputs_mask
[
i
])
{
if
(
inputs_mask
[
i
])
{
outputs
[
i
]
=
ScalarValue
::
make
(
outputs
[
i
]);
outputs
[
i
]
=
scalar_type
.
make
(
outputs
[
i
]);
}
}
}
}
return
outputs
;
return
outputs
;
}
}
ValueRefList
subtensor_rule
(
ValueRefList
subtensor_rule
(
const
Subtensor
&
subtensor
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
const
Subtensor
&
subtensor
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
mgb_assert
(
inputs
.
size
()
>=
1
);
mgb_assert
(
inputs
.
size
()
>=
1
);
auto
input
=
inputs
[
0
];
auto
input
=
inputs
[
0
];
bool
is_scalar
;
bool
is_scalar
;
...
@@ -199,14 +212,14 @@ ValueRefList subtensor_rule(
...
@@ -199,14 +212,14 @@ ValueRefList subtensor_rule(
}
}
auto
outputs
=
imperative
::
apply
(
subtensor
,
inputs
);
auto
outputs
=
imperative
::
apply
(
subtensor
,
inputs
);
if
(
is_scalar
)
{
if
(
is_scalar
)
{
outputs
[
0
]
=
ScalarValue
::
make
(
outputs
[
0
]);
outputs
[
0
]
=
scalar_type
.
make
(
outputs
[
0
]);
}
}
return
outputs
;
return
outputs
;
}
}
ValueRefList
get_var_shape_rule
(
ValueRefList
get_var_shape_rule
(
const
GetVarShape
&
get_var_shape
,
Span
<
ValueRef
>
inputs
,
const
GetVarShape
&
get_var_shape
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
Span
<
bool
>
inputs_mask
)
{
const
Type
<
ScalarValue
>&
scalar_type
)
{
bool
all_scalar
=
true
;
bool
all_scalar
=
true
;
mgb_assert
(
inputs
.
size
()
>=
1
);
mgb_assert
(
inputs
.
size
()
>=
1
);
for
(
auto
&&
input_mask
:
inputs_mask
)
{
for
(
auto
&&
input_mask
:
inputs_mask
)
{
...
@@ -228,11 +241,12 @@ ValueRefList get_var_shape_rule(
...
@@ -228,11 +241,12 @@ ValueRefList get_var_shape_rule(
}
}
ValueRefList
reshape_rule
(
ValueRefList
reshape_rule
(
const
Reshape
&
reshape
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
const
Reshape
&
reshape
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
mgb_assert
(
inputs
.
size
()
==
2
);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
if
(
is_scalar
)
{
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
return
{
scalar_type
.
make
(
imperative
::
apply
(
reshape
,
inputs
[
0
],
make_scalar_shape
(
*
inputs
[
0
].
device
()))[
0
])};
reshape
,
inputs
[
0
],
make_scalar_shape
(
*
inputs
[
0
].
device
()))[
0
])};
}
else
{
}
else
{
return
imperative
::
apply
(
reshape
,
inputs
);
return
imperative
::
apply
(
reshape
,
inputs
);
...
@@ -240,11 +254,12 @@ ValueRefList reshape_rule(
...
@@ -240,11 +254,12 @@ ValueRefList reshape_rule(
}
}
ValueRefList
broadcast_rule
(
ValueRefList
broadcast_rule
(
const
Broadcast
&
broadcast
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
const
Broadcast
&
broadcast
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
mgb_assert
(
inputs
.
size
()
==
2
);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
if
(
is_scalar
)
{
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
return
{
scalar_type
.
make
(
imperative
::
apply
(
broadcast
,
inputs
[
0
],
make_scalar_shape
(
*
inputs
[
0
].
device
()))[
0
])};
broadcast
,
inputs
[
0
],
make_scalar_shape
(
*
inputs
[
0
].
device
()))[
0
])};
}
else
{
}
else
{
return
imperative
::
apply
(
broadcast
,
inputs
);
return
imperative
::
apply
(
broadcast
,
inputs
);
...
@@ -299,11 +314,11 @@ struct ScalarRuleRegistry {
...
@@ -299,11 +314,11 @@ struct ScalarRuleRegistry {
ValueRefList
ScalarTransformation
::
apply_get_attr
(
ValueRefList
ScalarTransformation
::
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
)
{
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
)
{
auto
&&
input
=
inputs
.
item
();
auto
&&
input
=
inputs
.
item
();
bool
is_scalar
=
input
.
is
<
ScalarValue
>
(
);
bool
is_scalar
=
input
.
is
(
m_value_type
);
if
(
!
is_scalar
)
{
if
(
!
is_scalar
)
{
return
imperative
::
apply
(
get_attr
,
input
);
return
imperative
::
apply
(
get_attr
,
input
);
}
}
auto
unwrapped_input
=
input
.
cast
<
ScalarValue
>
(
).
value
();
auto
unwrapped_input
=
input
.
cast
(
m_value_type
).
value
();
if
(
get_attr
.
attr
()
==
GetAttr
::
Shape
)
{
if
(
get_attr
.
attr
()
==
GetAttr
::
Shape
)
{
if
(
!
m_empty_shape
)
{
if
(
!
m_empty_shape
)
{
m_empty_shape
=
ShapeValue
::
make
();
m_empty_shape
=
ShapeValue
::
make
();
...
@@ -352,7 +367,7 @@ ValueRefList ScalarTransformation::apply_transformation(
...
@@ -352,7 +367,7 @@ ValueRefList ScalarTransformation::apply_transformation(
ValueRefList
unwrapped_inputs
(
nr_inputs
);
ValueRefList
unwrapped_inputs
(
nr_inputs
);
SmallVector
<
bool
>
inputs_mask
(
nr_inputs
);
SmallVector
<
bool
>
inputs_mask
(
nr_inputs
);
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
auto
&&
scalar_value
=
inputs
[
i
].
as_ref
<
ScalarValue
>
(
))
{
if
(
auto
&&
scalar_value
=
inputs
[
i
].
as_ref
(
m_value_type
))
{
unwrapped_inputs
[
i
]
=
scalar_value
->
value
();
unwrapped_inputs
[
i
]
=
scalar_value
->
value
();
inputs_mask
[
i
]
=
true
;
inputs_mask
[
i
]
=
true
;
}
else
{
}
else
{
...
@@ -364,7 +379,8 @@ ValueRefList ScalarTransformation::apply_transformation(
...
@@ -364,7 +379,8 @@ ValueRefList ScalarTransformation::apply_transformation(
if
(
auto
apply_op
=
op
.
as
<
ApplyOp
>
())
{
if
(
auto
apply_op
=
op
.
as
<
ApplyOp
>
())
{
auto
iter
=
scalar_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
auto
iter
=
scalar_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
if
(
iter
!=
scalar_rules
.
end
())
{
if
(
iter
!=
scalar_rules
.
end
())
{
return
iter
->
second
(
apply_op
->
op
(),
unwrapped_inputs
,
inputs_mask
);
return
iter
->
second
(
apply_op
->
op
(),
unwrapped_inputs
,
inputs_mask
,
m_value_type
);
}
else
{
}
else
{
// TODO: repeat op
// TODO: repeat op
return
fallback
();
return
fallback
();
...
@@ -375,7 +391,7 @@ ValueRefList ScalarTransformation::apply_transformation(
...
@@ -375,7 +391,7 @@ ValueRefList ScalarTransformation::apply_transformation(
CreateTensor
scalar_op
(
CreateTensor
scalar_op
(
create_tensor
->
kind
(),
create_tensor
->
device
(),
create_tensor
->
kind
(),
create_tensor
->
device
(),
create_tensor
->
dtype
(),
scalar_shape
);
create_tensor
->
dtype
(),
scalar_shape
);
return
{
ScalarValue
::
make
(
imperative
::
apply
(
scalar_op
,
inputs
)[
0
])};
return
{
m_value_type
.
make
(
imperative
::
apply
(
scalar_op
,
inputs
)[
0
])};
}
else
{
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
return
imperative
::
apply
(
op
,
inputs
);
}
}
...
@@ -387,7 +403,7 @@ ValueRefList ScalarTransformation::apply_transformation(
...
@@ -387,7 +403,7 @@ ValueRefList ScalarTransformation::apply_transformation(
bool
is_scalar
=
inputs_mask
[
0
];
bool
is_scalar
=
inputs_mask
[
0
];
auto
outputs
=
fallback
();
auto
outputs
=
fallback
();
if
(
is_scalar
)
{
if
(
is_scalar
)
{
outputs
[
0
]
=
ScalarValue
::
make
(
outputs
[
0
]);
outputs
[
0
]
=
m_value_type
.
make
(
outputs
[
0
]);
}
}
return
outputs
;
return
outputs
;
}
else
{
}
else
{
...
...
imperative/src/impl/transformations/trace.cpp
浏览文件 @
177001d5
...
@@ -160,7 +160,7 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -160,7 +160,7 @@ ValueRefList TracingTransformation::apply_transformation(
SmallVector
<
TracingValue
::
ref_t
>
wrapped_inputs
;
SmallVector
<
TracingValue
::
ref_t
>
wrapped_inputs
;
SmallVector
<
size_t
>
input_ids
;
SmallVector
<
size_t
>
input_ids
;
for
(
auto
input
:
inputs
)
{
for
(
auto
input
:
inputs
)
{
auto
tracing_value
=
input
.
as_ref
<
TracingValue
>
(
);
auto
tracing_value
=
input
.
as_ref
(
m_value_type
);
if
(
!
tracing_value
)
{
if
(
!
tracing_value
)
{
tracing_value
=
tracing_value
=
record_var
(
input
,
m_capture_as_const
,
VarKind
::
External
);
record_var
(
input
,
m_capture_as_const
,
VarKind
::
External
);
...
@@ -208,7 +208,7 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -208,7 +208,7 @@ ValueRefList TracingTransformation::apply_transformation(
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
auto
unwrapped_input
=
unwrap_var
(
inputs
[
0
]);
auto
unwrapped_input
=
unwrap_var
(
inputs
[
0
]);
auto
outputs
=
imperative
::
apply
(
op
,
unwrapped_input
);
auto
outputs
=
imperative
::
apply
(
op
,
unwrapped_input
);
if
(
auto
*
tracing_value
=
inputs
[
0
].
as
<
TracingValue
>
(
))
{
if
(
auto
*
tracing_value
=
inputs
[
0
].
as
(
m_value_type
))
{
auto
&
var_info
=
m_vars
[
tracing_value
->
id
()];
auto
&
var_info
=
m_vars
[
tracing_value
->
id
()];
switch
(
get_attr
->
attr
())
{
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
Shape
:
case
GetAttr
::
Shape
:
...
@@ -228,7 +228,7 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -228,7 +228,7 @@ ValueRefList TracingTransformation::apply_transformation(
}
else
if
(
auto
*
trace_mark_var
=
op
.
as
<
TraceMarkVar
>
())
{
}
else
if
(
auto
*
trace_mark_var
=
op
.
as
<
TraceMarkVar
>
())
{
mgb_assert
(
inputs
.
size
()
==
1
,
"TraceMarkVar expects exactly one input"
);
mgb_assert
(
inputs
.
size
()
==
1
,
"TraceMarkVar expects exactly one input"
);
auto
input
=
inputs
[
0
];
auto
input
=
inputs
[
0
];
auto
tracing_var
=
input
.
as_ref
<
TracingValue
>
(
);
auto
tracing_var
=
input
.
as_ref
(
m_value_type
);
if
(
!
tracing_var
)
{
if
(
!
tracing_var
)
{
bool
is_input
=
trace_mark_var
->
mark
().
substr
(
0
,
4
)
==
"arg_"
||
bool
is_input
=
trace_mark_var
->
mark
().
substr
(
0
,
4
)
==
"arg_"
||
trace_mark_var
->
mark
().
substr
(
0
,
6
)
==
"kwarg_"
;
trace_mark_var
->
mark
().
substr
(
0
,
6
)
==
"kwarg_"
;
...
@@ -247,7 +247,7 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -247,7 +247,7 @@ ValueRefList TracingTransformation::apply_transformation(
}
else
if
(
auto
*
trace_name_var
=
op
.
as
<
RenameValue
>
())
{
}
else
if
(
auto
*
trace_name_var
=
op
.
as
<
RenameValue
>
())
{
mgb_assert
(
inputs
.
size
()
==
1
,
"RenameValue expects exactly one input"
);
mgb_assert
(
inputs
.
size
()
==
1
,
"RenameValue expects exactly one input"
);
auto
input
=
inputs
[
0
];
auto
input
=
inputs
[
0
];
auto
tracing_var
=
input
.
as_ref
<
TracingValue
>
(
);
auto
tracing_var
=
input
.
as_ref
(
m_value_type
);
if
(
!
tracing_var
)
{
if
(
!
tracing_var
)
{
tracing_var
=
record_var
(
input
,
m_capture_as_const
,
VarKind
::
External
);
tracing_var
=
record_var
(
input
,
m_capture_as_const
,
VarKind
::
External
);
}
else
{
}
else
{
...
@@ -260,7 +260,7 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -260,7 +260,7 @@ ValueRefList TracingTransformation::apply_transformation(
}
else
if
(
op
.
is
<
GetName
>
())
{
}
else
if
(
op
.
is
<
GetName
>
())
{
mgb_assert
(
inputs
.
size
()
==
1
,
"GetName expects exactly one input"
);
mgb_assert
(
inputs
.
size
()
==
1
,
"GetName expects exactly one input"
);
auto
input
=
inputs
[
0
];
auto
input
=
inputs
[
0
];
if
(
auto
tracing_var
=
input
.
as_ref
<
TracingValue
>
(
))
{
if
(
auto
tracing_var
=
input
.
as_ref
(
m_value_type
))
{
auto
name
=
m_vars
[
tracing_var
->
id
()].
name
;
auto
name
=
m_vars
[
tracing_var
->
id
()].
name
;
if
(
!
name
.
empty
())
{
if
(
!
name
.
empty
())
{
return
{
StringValue
::
make
(
name
)};
return
{
StringValue
::
make
(
name
)};
...
@@ -425,26 +425,12 @@ void CompiledTransformation::compile() {
...
@@ -425,26 +425,12 @@ void CompiledTransformation::compile() {
}
}
auto
&
node
=
var_accessors
[
input
].
node
;
auto
&
node
=
var_accessors
[
input
].
node
;
if
(
input_vars
.
empty
()
&&
require_link
&&
mm_io_link
.
node
())
{
if
(
input_vars
.
empty
()
&&
require_link
&&
mm_io_link
.
node
())
{
/*mgb_assert(
!input_vars.empty(),
"io-mm operator should have at least one input");*/
auto
comp_node
=
mm_io_link
.
node
()
->
comp_node
();
auto
comp_node
=
mm_io_link
.
node
()
->
comp_node
();
// auto comp_node = input_vars[0]->comp_node();
node
=
opr
::
VirtualDep
::
make
({
SymbolVar
(
node
),
mm_io_link
},
comp_node
)
node
=
opr
::
VirtualDep
::
make
({
SymbolVar
(
node
),
mm_io_link
},
comp_node
)
.
node
();
.
node
();
}
}
input_vars
.
push_back
(
node
);
input_vars
.
push_back
(
node
);
}
}
/*if (require_link && mm_io_link.node()) {
mgb_assert(
!input_vars.empty(),
"io-mm operator should have at least one input");
auto comp_node = mm_io_link.node()->comp_node();
// auto comp_node = input_vars[0]->comp_node();
input_vars[0] = opr::VirtualDep::make(
{SymbolVar(input_vars[0]), mm_io_link}, comp_node)
.node();
}*/
VarNodeArray
output_vars
;
VarNodeArray
output_vars
;
if
(
item
.
op
)
{
if
(
item
.
op
)
{
output_vars
=
OpDef
::
apply_on_var_node
(
*
item
.
op
,
input_vars
);
output_vars
=
OpDef
::
apply_on_var_node
(
*
item
.
op
,
input_vars
);
...
@@ -520,7 +506,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
...
@@ -520,7 +506,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
switch
(
var
.
kind
)
{
switch
(
var
.
kind
)
{
case
VarKind
::
External
:
{
case
VarKind
::
External
:
{
trace_assert
(
trace_assert
(
!
value
.
is
<
TracedValue
>
(
),
"expect external node, got internal"
);
!
value
.
is
(
m_value_type
),
"expect external node, got internal"
);
if
(
var
.
bound_data
)
{
if
(
var
.
bound_data
)
{
assert_tensor_equal
(
var
.
bound_data
,
value
);
assert_tensor_equal
(
var
.
bound_data
,
value
);
}
else
{
}
else
{
...
@@ -545,8 +531,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
...
@@ -545,8 +531,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
}
}
case
VarKind
::
Internal
:
{
case
VarKind
::
Internal
:
{
trace_assert
(
trace_assert
(
value
.
is
<
TracedValue
>
(
),
"expect internal node, got external"
);
value
.
is
(
m_value_type
),
"expect internal node, got external"
);
auto
&
traced_value
=
value
.
cast
<
TracedValue
>
(
);
auto
&
traced_value
=
value
.
cast
(
m_value_type
);
trace_assert
(
traced_value
.
id
()
==
id
,
"input id mismatch"
);
trace_assert
(
traced_value
.
id
()
==
id
,
"input id mismatch"
);
break
;
break
;
}
}
...
@@ -559,7 +545,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
...
@@ -559,7 +545,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
}
}
auto
CompiledTransformation
::
trace_output
(
size_t
id
)
->
TracedValue
::
ref_t
{
auto
CompiledTransformation
::
trace_output
(
size_t
id
)
->
TracedValue
::
ref_t
{
auto
traced_value
=
TracedValue
::
make
(
id
,
&
m_vars
[
id
],
&
m_var_accessors
[
id
]);
auto
traced_value
=
m_value_type
.
make
(
id
,
&
m_vars
[
id
],
&
m_var_accessors
[
id
]);
m_weak_values
.
push_back
(
traced_value
);
m_weak_values
.
push_back
(
traced_value
);
return
traced_value
;
return
traced_value
;
}
}
...
@@ -569,7 +555,7 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() {
...
@@ -569,7 +555,7 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() {
return
m_seq
[
m_pc
++
];
return
m_seq
[
m_pc
++
];
}
}
ShapeValue
::
ref_t
CompiledTransformation
::
Traced
Info
::
shape
()
const
{
ShapeValue
::
ref_t
CompiledTransformation
::
Traced
Value
::
shape
()
const
{
if
(
!
m_shape
)
{
if
(
!
m_shape
)
{
trace_assert
(
m_accessor
->
shape_getter
,
"shape unreadable"
);
trace_assert
(
m_accessor
->
shape_getter
,
"shape unreadable"
);
m_shape
=
ShapeValue
::
make
(
ValueShape
::
from
(
m_accessor
->
shape_getter
()));
m_shape
=
ShapeValue
::
make
(
ValueShape
::
from
(
m_accessor
->
shape_getter
()));
...
@@ -577,14 +563,14 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
...
@@ -577,14 +563,14 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
return
m_shape
;
return
m_shape
;
}
}
DTypeValue
::
ref_t
CompiledTransformation
::
Traced
Info
::
dtype
()
const
{
DTypeValue
::
ref_t
CompiledTransformation
::
Traced
Value
::
dtype
()
const
{
return
m_var
->
dtype
;
return
m_var
->
dtype
;
}
}
CompNodeValue
::
ref_t
CompiledTransformation
::
Traced
Info
::
comp_node
()
const
{
CompNodeValue
::
ref_t
CompiledTransformation
::
Traced
Value
::
comp_node
()
const
{
return
m_var
->
device
;
return
m_var
->
device
;
}
}
auto
CompiledTransformation
::
Traced
Info
::
accessor
()
const
->
const
VarAccessor
&
{
auto
CompiledTransformation
::
Traced
Value
::
accessor
()
const
->
const
VarAccessor
&
{
return
*
m_accessor
;
return
*
m_accessor
;
}
}
...
@@ -605,7 +591,7 @@ ValueRefList CompiledTransformation::apply_op(
...
@@ -605,7 +591,7 @@ ValueRefList CompiledTransformation::apply_op(
ValueRefList
CompiledTransformation
::
apply_get_attr
(
ValueRefList
CompiledTransformation
::
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
)
{
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
traced_value
=
inputs
[
0
].
as
<
TracedValue
>
(
))
{
if
(
auto
*
traced_value
=
inputs
[
0
].
as
(
m_value_type
))
{
ValueRef
output
;
ValueRef
output
;
auto
&
var_accessor
=
traced_value
->
accessor
();
auto
&
var_accessor
=
traced_value
->
accessor
();
switch
(
get_attr
.
attr
())
{
switch
(
get_attr
.
attr
())
{
...
@@ -718,15 +704,11 @@ void CompiledTransformation::on_unregister() noexcept {
...
@@ -718,15 +704,11 @@ void CompiledTransformation::on_unregister() noexcept {
void
CompiledTransformation
::
execute
()
{
void
CompiledTransformation
::
execute
()
{
mgb_assert
(
m_executable
!=
nullptr
);
mgb_assert
(
m_executable
!=
nullptr
);
m_graph_executor
=
std
::
thread
([
&
]
{
{
try
{
MGB_LOCK_GUARD
(
m_mutex
);
m_executable
->
execute
();
m_graph_status
=
1
;
m_executable
->
wait
();
}
}
catch
(...)
{
m_cv
.
notify_all
();
auto
exc
=
std
::
current_exception
();
set_exception
(
exc
);
}
});
}
}
void
CompiledTransformation
::
wait
()
{
void
CompiledTransformation
::
wait
()
{
...
@@ -735,8 +717,9 @@ void CompiledTransformation::wait() {
...
@@ -735,8 +717,9 @@ void CompiledTransformation::wait() {
}
catch
(...)
{
}
catch
(...)
{
}
}
mgb_assert
(
m_executable
!=
nullptr
);
mgb_assert
(
m_executable
!=
nullptr
);
m_graph_executor
.
join
();
std
::
unique_lock
lock
{
m_mutex
};
m_graph_executor
=
{};
m_cv
.
wait
(
lock
,
[
&
]
{
return
m_graph_status
==
0
;
});
lock
.
unlock
();
for
(
auto
&&
box
:
m_boxes
)
{
for
(
auto
&&
box
:
m_boxes
)
{
box
->
reset
();
box
->
reset
();
}
}
...
...
imperative/src/impl/value.cpp
浏览文件 @
177001d5
...
@@ -25,16 +25,16 @@ ValueRef::storage_t& ValueRef::storage() const {
...
@@ -25,16 +25,16 @@ ValueRef::storage_t& ValueRef::storage() const {
return
m_storage
;
return
m_storage
;
}
}
const
Value
*
ValueRef
::
as
(
size_t
typecod
e
)
const
{
const
Value
*
ValueRef
::
as
(
const
IType
&
typ
e
)
const
{
auto
&&
storage
=
this
->
storage
();
auto
&&
storage
=
this
->
storage
();
if
(
storage
->
m_typecode
!=
typecod
e
)
{
if
(
storage
->
type
()
!=
typ
e
)
{
return
nullptr
;
return
nullptr
;
}
}
return
static_cast
<
Value
*>
(
storage
.
get
());
return
static_cast
<
Value
*>
(
storage
.
get
());
}
}
bool
ValueRef
::
is
(
size_t
typecod
e
)
const
{
bool
ValueRef
::
is
(
const
IType
&
typ
e
)
const
{
return
this
->
storage
()
->
m_typecode
==
typecod
e
;
return
this
->
storage
()
->
type
()
==
typ
e
;
}
}
TypedValueRef
<
DeviceValue
>
ValueRef
::
dev_tensor
()
const
{
TypedValueRef
<
DeviceValue
>
ValueRef
::
dev_tensor
()
const
{
...
@@ -106,9 +106,7 @@ std::string ValueRef::raw_type() const {
...
@@ -106,9 +106,7 @@ std::string ValueRef::raw_type() const {
if
(
!
m_storage
)
{
if
(
!
m_storage
)
{
return
"null"
;
return
"null"
;
}
}
auto
&
types
=
Value
::
registered_types
();
return
m_storage
->
type
().
name
();
mgb_assert
(
types
.
size
()
>
m_storage
->
m_typecode
);
return
types
[
m_storage
->
m_typecode
].
name
();
}
}
bool
ValueRef
::
watching
()
const
{
bool
ValueRef
::
watching
()
const
{
...
@@ -137,7 +135,7 @@ ValueRef ValueWeakRef::lock() {
...
@@ -137,7 +135,7 @@ ValueRef ValueWeakRef::lock() {
return
{
strong_storage
};
return
{
strong_storage
};
}
}
Value
::
Value
(
size_t
typecode
)
:
m_typecode
{
typecode
}
{
Value
::
Value
(
)
{
m_id
=
nr_values
++
;
m_id
=
nr_values
++
;
}
}
...
@@ -147,17 +145,6 @@ Value::~Value() {
...
@@ -147,17 +145,6 @@ Value::~Value() {
}
}
}
}
size_t
Value
::
register_type
(
std
::
type_index
type
)
{
auto
&
types
=
const_cast
<
std
::
vector
<
std
::
type_index
>&>
(
registered_types
());
types
.
push_back
(
type
);
return
types
.
size
()
-
1
;
}
const
std
::
vector
<
std
::
type_index
>&
Value
::
registered_types
()
{
static
std
::
vector
<
std
::
type_index
>
sm_registered_types
;
return
sm_registered_types
;
}
void
Value
::
register_value
(
ValueRef
value
)
{
void
Value
::
register_value
(
ValueRef
value
)
{
registered_values
[
value
.
id
()]
=
ValueWeakRef
(
value
);
registered_values
[
value
.
id
()]
=
ValueWeakRef
(
value
);
}
}
...
@@ -188,7 +175,7 @@ std::vector<ValueRef> Value::end_record_values() {
...
@@ -188,7 +175,7 @@ std::vector<ValueRef> Value::end_record_values() {
}
}
void
Value
::
try_rethrow
()
{
void
Value
::
try_rethrow
()
{
if
(
m_typecode
==
ErrorValue
::
TYPE_CODE
)
{
if
(
type
()
==
PrimitiveType
<
ErrorValue
>::
instance
)
{
auto
message
=
static_cast
<
ErrorValue
*>
(
this
)
->
message
();
auto
message
=
static_cast
<
ErrorValue
*>
(
this
)
->
message
();
mgb_throw
(
MegBrainError
,
"invalid value: %s"
,
message
.
c_str
());
mgb_throw
(
MegBrainError
,
"invalid value: %s"
,
message
.
c_str
());
}
}
...
@@ -198,13 +185,9 @@ inline void ValueRefList::init(size_t nr_elems) {
...
@@ -198,13 +185,9 @@ inline void ValueRefList::init(size_t nr_elems) {
m_size
=
nr_elems
;
m_size
=
nr_elems
;
if
(
m_size
>
0
)
{
if
(
m_size
>
0
)
{
if
(
m_size
==
1
)
{
if
(
m_size
==
1
)
{
m_data
=
inline_storage
();
m_data
=
new
(
inline_storage
())
ValueRef
();
}
else
{
}
else
{
auto
&
context
=
Transformation
::
get_context
();
m_data
=
new
ValueRef
[
m_size
];
m_data
=
context
.
allocator
.
allocate
(
m_size
);
}
for
(
size_t
i
=
0
;
i
<
m_size
;
++
i
)
{
new
(
m_data
+
i
)
ValueRef
();
}
}
}
else
{
}
else
{
m_data
=
nullptr
;
m_data
=
nullptr
;
...
@@ -215,9 +198,6 @@ ValueRefList::ValueRefList(size_t nr_elems) {
...
@@ -215,9 +198,6 @@ ValueRefList::ValueRefList(size_t nr_elems) {
init
(
nr_elems
);
init
(
nr_elems
);
}
}
/*ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
: ValueRefList(values.begin(), values.end()) {}*/
ValueRefList
::
ValueRefList
(
const
ValueRefList
&
rhs
)
ValueRefList
::
ValueRefList
(
const
ValueRefList
&
rhs
)
:
ValueRefList
(
rhs
.
cbegin
(),
rhs
.
cend
())
{}
:
ValueRefList
(
rhs
.
cbegin
(),
rhs
.
cend
())
{}
...
@@ -271,14 +251,12 @@ ValueRefList::~ValueRefList() {
...
@@ -271,14 +251,12 @@ ValueRefList::~ValueRefList() {
}
}
void
ValueRefList
::
clear
()
{
void
ValueRefList
::
clear
()
{
for
(
size_t
i
=
0
;
i
<
m_size
;
++
i
)
{
m_data
[
i
].
~
ValueRef
();
}
if
(
m_data
)
{
if
(
m_data
)
{
if
(
m_size
!=
1
)
{
if
(
m_size
!=
1
)
{
Transformation
::
get_context
().
allocator
.
deallocate
(
m_data
,
m_size
)
;
delete
[]
m_data
;
}
else
{
}
else
{
mgb_assert
(
m_data
==
inline_storage
());
mgb_assert
(
m_data
==
inline_storage
());
m_data
->~
ValueRef
();
}
}
}
}
m_data
=
nullptr
;
m_data
=
nullptr
;
...
...
imperative/src/include/megbrain/imperative/basic_values.h
浏览文件 @
177001d5
...
@@ -25,79 +25,68 @@ class GradKey;
...
@@ -25,79 +25,68 @@ class GradKey;
using
GenericFunction
=
std
::
function
<
ValueRefList
(
Span
<
ValueRef
>
)
>
;
using
GenericFunction
=
std
::
function
<
ValueRefList
(
Span
<
ValueRef
>
)
>
;
class
ShapeValue
final
class
ShapeValue
final
:
public
PrimitiveValue
<
ShapeValue
,
ValueShape
>
{
:
public
MixinValueImpl
<
ShapeValue
,
ValueKind
::
Primitive
,
ValueShape
>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
class
CompNodeValue
final
class
CompNodeValue
final
:
public
PrimitiveValue
<
CompNodeValue
,
CompNode
>
{
:
public
MixinValueImpl
<
CompNodeValue
,
ValueKind
::
Primitive
,
CompNode
>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
// TODO: override factory method
class
Boolean
{
class
BoolValue
final
:
public
ValueImpl
<
BoolValue
,
ValueKind
::
Primitive
>
{
private:
private:
std
::
optional
<
bool
>
m_value
;
bool
m_value
;
public:
public:
Bool
Value
(
bool
value
)
:
m_value
{
value
}
{}
Bool
ean
()
=
default
;
operator
bool
()
const
{
return
*
m_value
;
}
Boolean
(
bool
value
)
:
m_value
(
value
)
{
}
std
::
string
to_string
()
const
override
;
operator
bool
()
const
{
return
m_value
;
}
};
void
clear
()
override
{
m_value
.
reset
();
}
// TODO: override factory method
class
BoolValue
final
:
public
PrimitiveValue
<
BoolValue
,
Boolean
>
{
public:
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
};
};
class
HostStorage
final
class
HostStorage
final
:
public
PrimitiveValue
<
HostStorage
,
HostTensorStorage
>
{
:
public
MixinValueImpl
<
HostStorage
,
ValueKind
::
Primitive
,
HostTensorStorage
>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
class
DeviceStorage
final
class
DeviceStorage
final
:
public
PrimitiveValue
<
DeviceStorage
,
DeviceTensorStorage
>
{
:
public
MixinValueImpl
<
DeviceStorage
,
ValueKind
::
Primitive
,
DeviceTensorStorage
>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
/**
class
HostTensor
{
* \brief like HostTensorND mixin, but allow scalar value
*
*/
class
HostValue
final
:
public
ValueImpl
<
HostValue
,
ValueKind
::
Primitive
>
{
private:
private:
DType
m_dtype
;
DType
m_dtype
;
ValueShape
m_shape
;
ValueShape
m_shape
;
HostTensorStorage
m_storage
;
HostTensorStorage
m_storage
;
public:
public:
HostValue
(
DType
dtype
,
ValueShape
shape
,
HostTensorStorage
storage
)
HostTensor
()
=
default
;
HostTensor
(
DType
dtype
,
ValueShape
shape
,
HostTensorStorage
storage
)
:
m_dtype
(
dtype
),
m_shape
(
shape
),
m_storage
(
storage
)
{}
:
m_dtype
(
dtype
),
m_shape
(
shape
),
m_storage
(
storage
)
{}
Host
Value
(
HostTensorND
value
)
Host
Tensor
(
HostTensorND
value
)
:
Host
Value
(
:
Host
Tensor
(
value
.
dtype
(),
ValueShape
::
from
(
value
.
shape
()),
value
.
storage
())
{
value
.
dtype
(),
ValueShape
::
from
(
value
.
shape
()),
value
.
storage
())
{
}
}
std
::
string
to_string
()
const
override
;
void
clear
()
override
{
m_dtype
=
{};
m_shape
=
{};
m_storage
=
{};
}
DType
dtype
()
const
{
return
m_dtype
;
}
DType
dtype
()
const
{
return
m_dtype
;
}
const
ValueShape
&
shape
()
const
{
return
m_shape
;
}
const
ValueShape
&
shape
()
const
{
return
m_shape
;
}
CompNode
device
()
const
{
return
m_storage
.
comp_node
();
}
CompNode
device
()
const
{
return
m_storage
.
comp_node
();
}
...
@@ -112,31 +101,31 @@ public:
...
@@ -112,31 +101,31 @@ public:
};
};
/**
/**
* \brief like
Device
TensorND mixin, but allow scalar value
* \brief like
Host
TensorND mixin, but allow scalar value
*
*
*/
*/
class
DeviceValue
final
:
public
ValueImpl
<
DeviceValue
,
ValueKind
::
Primitive
>
{
class
HostValue
final
:
public
PrimitiveValue
<
HostValue
,
HostTensor
>
{
public:
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
};
class
DeviceTensor
{
private:
private:
DType
m_dtype
;
DType
m_dtype
;
ValueShape
m_shape
;
ValueShape
m_shape
;
DeviceTensorStorage
m_storage
;
DeviceTensorStorage
m_storage
;
public:
public:
DeviceValue
(
DType
dtype
,
ValueShape
shape
,
DeviceTensorStorage
storage
)
DeviceTensor
()
=
default
;
DeviceTensor
(
DType
dtype
,
ValueShape
shape
,
DeviceTensorStorage
storage
)
:
m_dtype
(
dtype
),
m_shape
(
shape
),
m_storage
(
std
::
move
(
storage
))
{}
:
m_dtype
(
dtype
),
m_shape
(
shape
),
m_storage
(
std
::
move
(
storage
))
{}
Device
Value
(
const
DeviceTensorND
&
value
)
Device
Tensor
(
const
DeviceTensorND
&
value
)
:
Device
Value
(
:
Device
Tensor
(
value
.
dtype
(),
ValueShape
::
from
(
value
.
shape
()),
value
.
storage
())
{
value
.
dtype
(),
ValueShape
::
from
(
value
.
shape
()),
value
.
storage
())
{
}
}
std
::
string
to_string
()
const
override
;
void
clear
()
override
{
m_dtype
=
{};
m_shape
=
{};
m_storage
=
{};
}
DType
dtype
()
const
{
return
m_dtype
;
}
DType
dtype
()
const
{
return
m_dtype
;
}
const
ValueShape
&
shape
()
const
{
return
m_shape
;
}
const
ValueShape
&
shape
()
const
{
return
m_shape
;
}
CompNode
device
()
const
{
return
m_storage
.
comp_node
();
}
CompNode
device
()
const
{
return
m_storage
.
comp_node
();
}
...
@@ -145,26 +134,34 @@ public:
...
@@ -145,26 +134,34 @@ public:
DeviceTensorND
as_nd
(
bool
allow_scalar
=
false
)
const
;
DeviceTensorND
as_nd
(
bool
allow_scalar
=
false
)
const
;
};
};
class
FunctionValue
final
/**
:
public
MixinValueImpl
<
FunctionValue
,
ValueKind
::
Primitive
,
GenericFunction
>
{
* \brief like DeviceTensorND mixin, but allow scalar value
*
*/
class
DeviceValue
final
:
public
PrimitiveValue
<
DeviceValue
,
DeviceTensor
>
{
public:
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
};
class
FunctionValue
final
:
public
PrimitiveValue
<
FunctionValue
,
GenericFunction
>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
class
DTypeValue
final
class
DTypeValue
final
:
public
PrimitiveValue
<
DTypeValue
,
DType
>
{
:
public
MixinValueImpl
<
DTypeValue
,
ValueKind
::
Primitive
,
DType
>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
class
StringValue
final
class
StringValue
final
:
public
PrimitiveValue
<
StringValue
,
std
::
string
>
{
:
public
MixinValueImpl
<
StringValue
,
ValueKind
::
Primitive
,
std
::
string
>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
...
@@ -180,10 +177,9 @@ public:
...
@@ -180,10 +177,9 @@ public:
std
::
string
message
()
const
{
return
m_message
;
}
std
::
string
message
()
const
{
return
m_message
;
}
};
};
class
ErrorValue
final
class
ErrorValue
final
:
public
PrimitiveValue
<
ErrorValue
,
Error
>
{
:
public
MixinValueImpl
<
ErrorValue
,
ValueKind
::
Primitive
,
Error
>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
...
...
imperative/src/include/megbrain/imperative/subgraph.h
浏览文件 @
177001d5
...
@@ -57,7 +57,7 @@ struct Subgraph {
...
@@ -57,7 +57,7 @@ struct Subgraph {
SmallVector
<
expr_t
>
exprs
;
SmallVector
<
expr_t
>
exprs
;
template
<
typename
T
,
typename
F
,
typename
C
>
template
<
typename
T
,
typename
F
,
typename
C
>
SmallVector
<
T
>
apply
(
S
mallVector
<
T
>
input_vars
,
F
&&
f
,
C
&&
c
)
const
{
SmallVector
<
T
>
apply
(
S
pan
<
T
>
input_vars
,
F
&&
f
,
C
&&
c
)
const
{
std
::
unordered_map
<
size_t
,
T
>
idx2var
;
std
::
unordered_map
<
size_t
,
T
>
idx2var
;
mgb_assert
(
inputs
.
size
()
==
input_vars
.
size
(),
"input size mismatch"
);
mgb_assert
(
inputs
.
size
()
==
input_vars
.
size
(),
"input size mismatch"
);
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
...
@@ -71,8 +71,7 @@ struct Subgraph {
...
@@ -71,8 +71,7 @@ struct Subgraph {
for
(
auto
idx
:
expr
.
inputs
)
{
for
(
auto
idx
:
expr
.
inputs
)
{
expr_inputs
.
push_back
(
idx2var
[
idx
]);
expr_inputs
.
push_back
(
idx2var
[
idx
]);
}
}
SmallVector
<
T
>
expr_outputs
=
SmallVector
<
T
>
expr_outputs
=
f
(
expr
.
op
,
expr_inputs
,
expr
.
outputs
.
size
());
f
(
expr
.
op
,
std
::
move
(
expr_inputs
),
expr
.
outputs
.
size
());
mgb_assert
(
mgb_assert
(
expr_outputs
.
size
()
==
expr
.
outputs
.
size
(),
"output size mismatch"
);
expr_outputs
.
size
()
==
expr
.
outputs
.
size
(),
"output size mismatch"
);
for
(
size_t
i
=
0
;
i
<
expr_outputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
expr_outputs
.
size
();
++
i
)
{
...
@@ -102,9 +101,9 @@ struct EncodedSubgraph {
...
@@ -102,9 +101,9 @@ struct EncodedSubgraph {
SmallVector
<
bool
>
input_mask
;
SmallVector
<
bool
>
input_mask
;
SmallVector
<
bool
>
output_mask
;
SmallVector
<
bool
>
output_mask
;
template
<
typename
T
Container
>
template
<
typename
T
>
TContainer
encode_inputs
(
TContainer
inputs
)
const
{
SmallVector
<
T
>
encode_inputs
(
Span
<
T
>
inputs
)
const
{
TContainer
encoded_inputs
;
SmallVector
<
T
>
encoded_inputs
;
size_t
index
=
0
;
size_t
index
=
0
;
for
(
auto
&&
input
:
inputs
)
{
for
(
auto
&&
input
:
inputs
)
{
mgb_assert
(
index
<
input_mask
.
size
(),
"index out of range"
);
mgb_assert
(
index
<
input_mask
.
size
(),
"index out of range"
);
...
@@ -116,9 +115,9 @@ struct EncodedSubgraph {
...
@@ -116,9 +115,9 @@ struct EncodedSubgraph {
return
encoded_inputs
;
return
encoded_inputs
;
}
}
template
<
typename
T
Container
>
template
<
typename
T
>
TContainer
encode_outputs
(
TContainer
outputs
)
const
{
SmallVector
<
T
>
encode_outputs
(
Span
<
T
>
outputs
)
const
{
TContainer
encoded_outputs
;
SmallVector
<
T
>
encoded_outputs
;
size_t
index
=
0
;
size_t
index
=
0
;
for
(
auto
&&
output
:
outputs
)
{
for
(
auto
&&
output
:
outputs
)
{
mgb_assert
(
index
<
output_mask
.
size
(),
"index out of range"
);
mgb_assert
(
index
<
output_mask
.
size
(),
"index out of range"
);
...
@@ -130,9 +129,9 @@ struct EncodedSubgraph {
...
@@ -130,9 +129,9 @@ struct EncodedSubgraph {
return
encoded_outputs
;
return
encoded_outputs
;
}
}
template
<
typename
T
Container
>
template
<
typename
T
>
TContainer
decode_outputs
(
TContainer
outputs
)
const
{
SmallVector
<
T
>
decode_outputs
(
Span
<
T
>
outputs
)
const
{
TContainer
decoded_outputs
;
SmallVector
<
T
>
decoded_outputs
;
size_t
index
=
0
;
size_t
index
=
0
;
for
(
size_t
i
=
0
;
i
<
output_mask
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
output_mask
.
size
();
i
++
)
{
mgb_assert
(
index
<
output_mask
.
size
(),
"index out of range"
);
mgb_assert
(
index
<
output_mask
.
size
(),
"index out of range"
);
...
@@ -150,8 +149,8 @@ struct EncodedSubgraph {
...
@@ -150,8 +149,8 @@ struct EncodedSubgraph {
EncodedSubgraph
result
;
EncodedSubgraph
result
;
result
.
input_mask
=
graph
.
gen_input_mask
();
result
.
input_mask
=
graph
.
gen_input_mask
();
result
.
output_mask
=
graph
.
gen_output_mask
();
result
.
output_mask
=
graph
.
gen_output_mask
();
graph
.
inputs
=
result
.
encode_inputs
(
graph
.
inputs
);
graph
.
inputs
=
result
.
encode_inputs
<
Subgraph
::
var_t
>
(
graph
.
inputs
);
graph
.
outputs
=
result
.
encode_outputs
(
graph
.
outputs
);
graph
.
outputs
=
result
.
encode_outputs
<
Subgraph
::
var_t
>
(
graph
.
outputs
);
result
.
graph
=
graph
;
result
.
graph
=
graph
;
return
result
;
return
result
;
}
}
...
@@ -179,11 +178,11 @@ struct EncodedSubgraph {
...
@@ -179,11 +178,11 @@ struct EncodedSubgraph {
}
}
template
<
typename
T
,
typename
F
,
typename
C
>
template
<
typename
T
,
typename
F
,
typename
C
>
SmallVector
<
T
>
apply
(
S
mallVector
<
T
>
input_vars
,
F
&&
f
,
C
&&
c
)
const
{
SmallVector
<
T
>
apply
(
S
pan
<
T
>
input_vars
,
F
&&
f
,
C
&&
c
)
const
{
auto
encoded_inputs
=
encode_inputs
(
input_vars
);
auto
encoded_inputs
=
encode_inputs
<
T
>
(
input_vars
);
auto
encoded_outputs
=
auto
encoded_outputs
=
graph
.
apply
(
encoded_inputs
,
std
::
forward
<
F
>
(
f
),
std
::
forward
<
C
>
(
c
));
graph
.
apply
<
T
>
(
encoded_inputs
,
std
::
forward
<
F
>
(
f
),
std
::
forward
<
C
>
(
c
));
return
decode_outputs
(
encoded_outputs
);
return
decode_outputs
<
T
>
(
encoded_outputs
);
}
}
std
::
string
repr
()
const
;
std
::
string
repr
()
const
;
...
@@ -280,4 +279,4 @@ public:
...
@@ -280,4 +279,4 @@ public:
};
};
}
// namespace imperative
}
// namespace imperative
}
// namespace mgb
}
// namespace mgb
\ No newline at end of file
imperative/src/include/megbrain/imperative/transformations/eval.h
浏览文件 @
177001d5
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
namespace
mgb
::
imperative
{
namespace
mgb
::
imperative
{
struct
InterpreterInfo
{
class
InterpreterValue
final
:
public
ObjectValue
<
InterpreterValue
>
{
public:
public:
using
Handle
=
interpreter
::
Interpreter
::
Handle
;
using
Handle
=
interpreter
::
Interpreter
::
Handle
;
using
Channel
=
interpreter
::
Interpreter
::
Channel
;
using
Channel
=
interpreter
::
Interpreter
::
Channel
;
...
@@ -46,8 +46,7 @@ private:
...
@@ -46,8 +46,7 @@ private:
mutable
ShapeValue
::
ref_t
m_shape
;
mutable
ShapeValue
::
ref_t
m_shape
;
public:
public:
InterpreterInfo
()
=
default
;
InterpreterValue
(
LocalPtr
<
RAIIHandle
>
handle
,
std
::
string
name
=
{})
InterpreterInfo
(
LocalPtr
<
RAIIHandle
>
handle
,
std
::
string
name
=
{})
:
m_handle
(
handle
),
m_name
(
name
)
{}
:
m_handle
(
handle
),
m_name
(
name
)
{}
const
LocalPtr
<
RAIIHandle
>&
handle
()
const
{
return
m_handle
;
}
const
LocalPtr
<
RAIIHandle
>&
handle
()
const
{
return
m_handle
;
}
...
@@ -57,18 +56,14 @@ public:
...
@@ -57,18 +56,14 @@ public:
ShapeValue
::
ref_t
shape
()
const
;
ShapeValue
::
ref_t
shape
()
const
;
std
::
string
name
()
const
{
return
m_name
;
}
std
::
string
name
()
const
{
return
m_name
;
}
};
class
InterpreterValue
final
:
public
MixinValueImpl
<
InterpreterValue
,
ValueKind
::
Object
,
InterpreterInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
std
::
string
to_string
()
const
override
{
return
ssprintf
(
return
ssprintf
(
"Handle{ptr=%p, name=%s}"
,
handle
().
get
(),
"Handle{ptr=%p, name=%s}"
,
handle
().
get
(),
imperative
::
quoted
(
name
()).
c_str
());
imperative
::
quoted
(
name
()).
c_str
());
}
}
void
clear
()
override
{
m_handle
=
{};
}
};
};
/**
/**
...
@@ -82,11 +77,12 @@ class InterpreterTransformation final : public Transformation {
...
@@ -82,11 +77,12 @@ class InterpreterTransformation final : public Transformation {
public:
public:
using
Interpreter
=
interpreter
::
Interpreter
;
using
Interpreter
=
interpreter
::
Interpreter
;
using
Handle
=
Interpreter
::
Handle
;
using
Handle
=
Interpreter
::
Handle
;
using
SharedHandle
=
LocalPtr
<
Interpreter
Info
::
RAIIHandle
>
;
using
SharedHandle
=
LocalPtr
<
Interpreter
Value
::
RAIIHandle
>
;
using
Channel
=
Interpreter
::
Channel
;
using
Channel
=
Interpreter
::
Channel
;
private:
private:
std
::
shared_ptr
<
Channel
>
m_channel
;
std
::
shared_ptr
<
Channel
>
m_channel
;
ObjectType
<
InterpreterValue
>
m_value_type
{
"InterpreterValue"
};
public:
public:
explicit
InterpreterTransformation
(
std
::
shared_ptr
<
Channel
>
channel
)
explicit
InterpreterTransformation
(
std
::
shared_ptr
<
Channel
>
channel
)
...
@@ -105,7 +101,7 @@ public:
...
@@ -105,7 +101,7 @@ public:
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
<
InterpreterValue
>
(
));
mgb_assert
(
!
value
.
is
(
m_value_type
));
return
value
;
return
value
;
}
}
...
...
imperative/src/include/megbrain/imperative/transformations/grad.h
浏览文件 @
177001d5
...
@@ -34,7 +34,8 @@ struct BackwardGraphWithClosure {
...
@@ -34,7 +34,8 @@ struct BackwardGraphWithClosure {
std
::
shared_ptr
<
OptimizedBackwardGraphResult
>
backward_graph
,
std
::
shared_ptr
<
OptimizedBackwardGraphResult
>
backward_graph
,
std
::
shared_ptr
<
OpDef
>
op
,
Span
<
ValueRef
>
inputs
,
Span
<
ValueRef
>
outputs
);
std
::
shared_ptr
<
OpDef
>
op
,
Span
<
ValueRef
>
inputs
,
Span
<
ValueRef
>
outputs
);
void
operator
()(
ValueRefList
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
);
void
operator
()(
Span
<
ValueRef
>
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
);
bool
input_has_grad
(
size_t
i
)
{
return
backward_graph
->
input_has_grad
[
i
];
}
bool
input_has_grad
(
size_t
i
)
{
return
backward_graph
->
input_has_grad
[
i
];
}
...
@@ -51,7 +52,7 @@ struct CustomBackward;
...
@@ -51,7 +52,7 @@ struct CustomBackward;
using
GradRuleFn
=
std
::
function
<
ValueRefList
(
Span
<
ValueRef
>
inputs
,
CustomBackward
&
)
>
;
using
GradRuleFn
=
std
::
function
<
ValueRefList
(
Span
<
ValueRef
>
inputs
,
CustomBackward
&
)
>
;
struct
CustomBackward
{
struct
CustomBackward
{
using
BackwardFn
=
std
::
function
<
ValueRefList
(
Span
<
ValueRef
>
)
>
;
using
BackwardFn
=
std
::
function
<
SmallVector
<
ValueRef
>
(
Span
<
ValueRef
>
)
>
;
using
BackwardRule
=
std
::
function
<
std
::
optional
<
ValueRefList
>
(
using
BackwardRule
=
std
::
function
<
std
::
optional
<
ValueRefList
>
(
const
OpDef
&
,
Span
<
ValueRef
>
,
Span
<
bool
>
,
CustomBackward
&
)
>
;
const
OpDef
&
,
Span
<
ValueRef
>
,
Span
<
bool
>
,
CustomBackward
&
)
>
;
BackwardFn
m_backward
;
BackwardFn
m_backward
;
...
@@ -62,7 +63,8 @@ struct CustomBackward {
...
@@ -62,7 +63,8 @@ struct CustomBackward {
SmallVector
<
OutputAttr
>
m_output_attrs
;
SmallVector
<
OutputAttr
>
m_output_attrs
;
public:
public:
void
operator
()(
ValueRefList
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
);
void
operator
()(
Span
<
ValueRef
>
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
);
bool
input_has_grad
(
size_t
i
)
{
return
m_input_has_grad
[
i
];
}
bool
input_has_grad
(
size_t
i
)
{
return
m_input_has_grad
[
i
];
}
bool
output_requires_grad
(
size_t
i
)
{
return
m_output_attrs
[
i
].
requires_grad
;
}
bool
output_requires_grad
(
size_t
i
)
{
return
m_output_attrs
[
i
].
requires_grad
;
}
...
@@ -175,7 +177,7 @@ inline GradSlot* GradSlotPtr::operator->() const {
...
@@ -175,7 +177,7 @@ inline GradSlot* GradSlotPtr::operator->() const {
return
&
m_fn
->
m_slots
[
m_index
];
return
&
m_fn
->
m_slots
[
m_index
];
}
}
class
GradValue
final
:
public
ValueImpl
<
GradValue
,
ValueKind
::
Object
>
{
class
GradValue
final
:
public
ObjectValue
<
GradValue
>
{
private:
private:
ValueRef
m_value
;
ValueRef
m_value
;
std
::
shared_ptr
<
GradKey
>
m_key
;
std
::
shared_ptr
<
GradKey
>
m_key
;
...
@@ -187,14 +189,9 @@ public:
...
@@ -187,14 +189,9 @@ public:
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
bool
has_key
(
const
std
::
shared_ptr
<
GradKey
>&
key
)
const
{
return
m_key
==
key
;
}
const
GradSlotPtr
&
slot
()
const
{
return
m_slot
;
}
const
GradSlotPtr
&
slot_for
(
std
::
shared_ptr
<
GradKey
>
key
)
const
{
// std::shared_ptr<GradKey> key() const { return m_key; }
mgb_assert
(
m_key
==
key
);
return
m_slot
;
}
std
::
shared_ptr
<
GradKey
>
key
()
const
{
return
m_key
;
}
void
clear
()
override
{
void
clear
()
override
{
m_slot
=
{};
m_slot
=
{};
...
@@ -216,9 +213,12 @@ private:
...
@@ -216,9 +213,12 @@ private:
std
::
vector
<
std
::
pair
<
LocalWeakPtr
<
GradFn
>
,
std
::
shared_ptr
<
OpDef
>>>
m_tape
;
std
::
vector
<
std
::
pair
<
LocalWeakPtr
<
GradFn
>
,
std
::
shared_ptr
<
OpDef
>>>
m_tape
;
std
::
vector
<
std
::
pair
<
LocalPtr
<
GradFn
>
,
std
::
shared_ptr
<
OpDef
>>>
m_frozen_tape
;
std
::
vector
<
std
::
pair
<
LocalPtr
<
GradFn
>
,
std
::
shared_ptr
<
OpDef
>>>
m_frozen_tape
;
bool
m_frozen
=
false
;
bool
m_frozen
=
false
;
const
Type
<
GradValue
>&
m_value_type
;
public:
public:
GradKey
()
{
m_tape
.
reserve
(
4
*
1024
);
}
GradKey
(
const
Type
<
GradValue
>&
value_type
)
:
m_value_type
(
value_type
)
{
m_tape
.
reserve
(
4
*
1024
);
}
void
backward
();
void
backward
();
GradValue
::
ref_t
attach
(
ValueRef
tensor
,
std
::
function
<
void
(
ValueRef
)
>
callback
);
GradValue
::
ref_t
attach
(
ValueRef
tensor
,
std
::
function
<
void
(
ValueRef
)
>
callback
);
...
@@ -230,10 +230,9 @@ public:
...
@@ -230,10 +230,9 @@ public:
};
};
class
GradKeyValue
final
class
GradKeyValue
final
:
public
MixinValueImpl
<
:
public
PrimitiveValue
<
GradKeyValue
,
std
::
shared_ptr
<
GradKey
>>
{
GradKeyValue
,
ValueKind
::
Primitive
,
std
::
shared_ptr
<
GradKey
>>
{
public:
public:
using
MixinValueImpl
::
MixinValueImpl
;
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
{
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"GradKey{%s}"
,
(
*
this
)
->
name
().
c_str
());
return
ssprintf
(
"GradKey{%s}"
,
(
*
this
)
->
name
().
c_str
());
...
@@ -242,26 +241,20 @@ public:
...
@@ -242,26 +241,20 @@ public:
class
GradTransformation
final
:
public
Transformation
{
class
GradTransformation
final
:
public
Transformation
{
private:
private:
ObjectType
<
GradValue
>
m_value_type
{
"GradValue"
};
std
::
shared_ptr
<
GradKey
>
m_key
;
std
::
shared_ptr
<
GradKey
>
m_key
;
std
::
vector
<
GradValue
::
weak_ref_t
>
m_weak_values
;
std
::
vector
<
GradValue
::
weak_ref_t
>
m_weak_values
;
size_t
m_suppressed
=
0
;
size_t
m_suppressed
=
0
;
public:
public:
GradTransformation
(
std
::
shared_ptr
<
GradKey
>
key
)
:
m_key
(
key
)
{
}
GradTransformation
(
)
{
m_key
=
std
::
make_shared
<
GradKey
>
(
m_value_type
);
}
auto
record_grad
(
GradValue
::
ref_t
tensor
)
{
auto
record_grad
(
GradValue
::
ref_t
tensor
)
{
m_weak_values
.
push_back
(
tensor
);
m_weak_values
.
push_back
(
tensor
);
return
tensor
;
return
tensor
;
}
}
bool
is_grad_value
(
const
ValueRef
&
value
)
{
bool
is_grad_value
(
const
ValueRef
&
value
)
{
return
value
.
is
(
m_value_type
);
}
if
(
auto
*
grad_value
=
value
.
as
<
GradValue
>
())
{
if
(
grad_value
->
has_key
(
m_key
))
{
return
true
;
}
}
return
false
;
}
/**
/**
* \brief test whether value is related to this GradTransformation
* \brief test whether value is related to this GradTransformation
...
@@ -273,13 +266,7 @@ public:
...
@@ -273,13 +266,7 @@ public:
* \return GradValue::ref_t
* \return GradValue::ref_t
*/
*/
const
GradValue
::
ref_t
&
as_grad_value
(
const
ValueRef
&
value
)
{
const
GradValue
::
ref_t
&
as_grad_value
(
const
ValueRef
&
value
)
{
auto
&&
grad_value
=
value
.
as_ref
<
GradValue
>
();
return
value
.
as_ref
(
m_value_type
);
if
(
grad_value
)
{
if
(
grad_value
->
has_key
(
m_key
))
{
return
grad_value
;
}
}
return
GradValue
::
ref_t
::
nil
;
}
}
bool
has_key
(
std
::
shared_ptr
<
GradKey
>
key
)
{
bool
has_key
(
std
::
shared_ptr
<
GradKey
>
key
)
{
...
@@ -299,6 +286,8 @@ public:
...
@@ -299,6 +286,8 @@ public:
return
value
;
return
value
;
}
}
const
std
::
shared_ptr
<
GradKey
>&
key
()
const
{
return
m_key
;
}
std
::
string
name
()
const
override
{
return
"GradTransformation"
;
}
std
::
string
name
()
const
override
{
return
"GradTransformation"
;
}
GenericFunction
make_backward_closure
(
Span
<
ValueRef
>
ys
);
GenericFunction
make_backward_closure
(
Span
<
ValueRef
>
ys
);
...
...
imperative/src/include/megbrain/imperative/transformations/lazy.h
浏览文件 @
177001d5
...
@@ -22,32 +22,27 @@
...
@@ -22,32 +22,27 @@
namespace
mgb
::
imperative
{
namespace
mgb
::
imperative
{
class
LazyEval
Info
{
class
LazyEval
Value
final
:
public
ObjectValue
<
LazyEvalValue
>
{
private:
private:
VarNode
*
m_node
=
nullptr
;
VarNode
*
m_node
=
nullptr
;
ValueRef
m_bound_data
;
ValueRef
m_bound_data
;
std
::
string
m_name
;
std
::
string
m_name
;
public:
public:
LazyEvalInfo
()
=
default
;
LazyEvalValue
(
VarNode
*
node
,
ValueRef
bound_data
,
std
::
string
name
)
LazyEvalInfo
(
VarNode
*
node
,
ValueRef
bound_data
,
std
::
string
name
)
:
m_node
(
node
),
m_bound_data
(
bound_data
),
m_name
(
name
)
{}
:
m_node
(
node
),
m_bound_data
(
bound_data
),
m_name
(
name
)
{}
VarNode
*
node
()
const
{
return
m_node
;
}
VarNode
*
node
()
const
{
return
m_node
;
}
ValueRef
bound_data
()
const
{
return
m_bound_data
;
}
ValueRef
bound_data
()
const
{
return
m_bound_data
;
}
std
::
string
name
()
const
{
return
m_name
;
}
std
::
string
name
()
const
{
return
m_name
;
}
};
class
LazyEvalValue
final
:
public
MixinValueImpl
<
LazyEvalValue
,
ValueKind
::
Object
,
LazyEvalInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
std
::
string
to_string
()
const
override
{
return
ssprintf
(
return
ssprintf
(
"LazyEvalValue{node=%p, name=%s}"
,
node
(),
node
()
->
name
().
c_str
());
"LazyEvalValue{node=%p, name=%s}"
,
node
(),
node
()
->
name
().
c_str
());
}
}
void
clear
()
override
{}
};
};
/**
/**
...
@@ -67,6 +62,7 @@ private:
...
@@ -67,6 +62,7 @@ private:
std
::
vector
<
LazyEvalValue
::
weak_ref_t
>
m_weak_vars
;
std
::
vector
<
LazyEvalValue
::
weak_ref_t
>
m_weak_vars
;
SymbolVar
m_io_link
=
nullptr
;
SymbolVar
m_io_link
=
nullptr
;
std
::
exception_ptr
m_graph_exc
;
std
::
exception_ptr
m_graph_exc
;
ObjectType
<
LazyEvalValue
>
m_value_type
{
"LazyEvalValue"
};
public:
public:
LazyEvalTransformation
(
bool
no_exec
)
:
m_no_exec
(
no_exec
)
{
LazyEvalTransformation
(
bool
no_exec
)
:
m_no_exec
(
no_exec
)
{
...
@@ -75,7 +71,7 @@ public:
...
@@ -75,7 +71,7 @@ public:
LazyEvalValue
::
ref_t
record_var
(
LazyEvalValue
::
ref_t
record_var
(
VarNode
*
node
,
ValueRef
bound_data
=
{},
std
::
string
name
=
{})
{
VarNode
*
node
,
ValueRef
bound_data
=
{},
std
::
string
name
=
{})
{
auto
lazy_eval_val
=
LazyEvalValue
::
make
(
node
,
bound_data
,
name
);
auto
lazy_eval_val
=
m_value_type
.
make
(
node
,
bound_data
,
name
);
m_weak_vars
.
push_back
(
lazy_eval_val
);
m_weak_vars
.
push_back
(
lazy_eval_val
);
return
lazy_eval_val
;
return
lazy_eval_val
;
}
}
...
@@ -86,7 +82,7 @@ public:
...
@@ -86,7 +82,7 @@ public:
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
<
LazyEvalValue
>
(
));
mgb_assert
(
!
value
.
is
(
m_value_type
));
return
value
;
return
value
;
}
}
...
...
imperative/src/include/megbrain/imperative/transformations/scalar.h
浏览文件 @
177001d5
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
namespace
mgb
::
imperative
{
namespace
mgb
::
imperative
{
class
ScalarValue
final
:
public
ValueImpl
<
ScalarValue
,
ValueKind
::
Object
>
{
class
ScalarValue
final
:
public
ObjectValue
<
ScalarValue
>
{
private:
private:
ValueRef
m_value
;
ValueRef
m_value
;
...
@@ -47,17 +47,21 @@ public:
...
@@ -47,17 +47,21 @@ public:
class
ScalarTransformation
final
:
public
Transformation
{
class
ScalarTransformation
final
:
public
Transformation
{
private:
private:
ShapeValue
::
ref_t
m_empty_shape
;
// []
ShapeValue
::
ref_t
m_empty_shape
;
// []
ObjectType
<
ScalarValue
>
m_value_type
{
"ScalarValue"
};
public:
public:
ValueRefList
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply_transformation
(
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
<
ScalarValue
>
(
));
mgb_assert
(
!
value
.
is
(
m_value_type
));
return
value
;
return
value
;
}
}
std
::
string
name
()
const
override
{
return
"ScalarTransformation"
;
}
std
::
string
name
()
const
override
{
return
"ScalarTransformation"
;
}
const
Type
<
ScalarValue
>&
value_type
()
const
{
return
m_value_type
;
}
};
};
}
// namespace mgb::imperative
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/transformations/symbol.h
浏览文件 @
177001d5
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
namespace
mgb
::
imperative
{
namespace
mgb
::
imperative
{
class
SymbolValue
final
:
public
ValueImpl
<
SymbolValue
,
ValueKind
::
Object
>
{
class
SymbolValue
final
:
public
ObjectValue
<
SymbolValue
>
{
private:
private:
VarNode
*
m_node
=
nullptr
;
VarNode
*
m_node
=
nullptr
;
...
@@ -47,6 +47,7 @@ public:
...
@@ -47,6 +47,7 @@ public:
class
SymbolTransformation
final
:
public
Transformation
{
class
SymbolTransformation
final
:
public
Transformation
{
private:
private:
ComputingGraph
*
m_graph
=
nullptr
;
ComputingGraph
*
m_graph
=
nullptr
;
ObjectType
<
SymbolValue
>
m_value_type
{
"SymbolValue"
};
public:
public:
SymbolTransformation
(
ComputingGraph
*
graph
)
:
m_graph
(
graph
)
{}
SymbolTransformation
(
ComputingGraph
*
graph
)
:
m_graph
(
graph
)
{}
...
@@ -55,12 +56,12 @@ public:
...
@@ -55,12 +56,12 @@ public:
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
SmallVector
<
VarNode
*>
input_nodes
;
SmallVector
<
VarNode
*>
input_nodes
;
for
(
auto
&&
input
:
inputs
)
{
for
(
auto
&&
input
:
inputs
)
{
input_nodes
.
push_back
(
input
.
cast
<
SymbolValue
>
(
).
node
());
input_nodes
.
push_back
(
input
.
cast
(
m_value_type
).
node
());
}
}
auto
output_nodes
=
OpDef
::
apply_on_var_node
(
apply_op
->
op
(),
input_nodes
);
auto
output_nodes
=
OpDef
::
apply_on_var_node
(
apply_op
->
op
(),
input_nodes
);
ValueRefList
outputs
(
output_nodes
.
size
());
ValueRefList
outputs
(
output_nodes
.
size
());
for
(
size_t
i
=
0
;
i
<
output_nodes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
output_nodes
.
size
();
++
i
)
{
outputs
[
i
]
=
SymbolValue
::
make
(
output_nodes
[
i
]);
outputs
[
i
]
=
m_value_type
.
make
(
output_nodes
[
i
]);
}
}
return
outputs
;
return
outputs
;
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
...
@@ -69,9 +70,9 @@ public:
...
@@ -69,9 +70,9 @@ public:
args
.
kind
==
CreateTensor
::
Const
,
args
.
kind
==
CreateTensor
::
Const
,
"only const value is allowed here"
);
"only const value is allowed here"
);
auto
*
node
=
opr
::
ImmutableTensor
::
make
(
*
m_graph
,
*
args
.
host
,
{}).
node
();
auto
*
node
=
opr
::
ImmutableTensor
::
make
(
*
m_graph
,
*
args
.
host
,
{}).
node
();
return
{
SymbolValue
::
make
(
node
)};
return
{
m_value_type
.
make
(
node
)};
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
auto
*
node
=
inputs
.
as_array
<
1
>
()[
0
].
cast
<
SymbolValue
>
(
).
node
();
auto
*
node
=
inputs
.
item
().
cast
(
m_value_type
).
node
();
switch
(
get_attr
->
attr
())
{
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
DType
:
case
GetAttr
::
DType
:
return
{
DTypeValue
::
make
(
node
->
dtype
())};
return
{
DTypeValue
::
make
(
node
->
dtype
())};
...
@@ -121,11 +122,13 @@ public:
...
@@ -121,11 +122,13 @@ public:
}
}
ValueRef
unwrap
(
ValueRef
value
)
override
{
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
<
SymbolValue
>
(
),
"SymbolValue doesn't support unwrap"
);
mgb_assert
(
!
value
.
is
(
m_value_type
),
"SymbolValue doesn't support unwrap"
);
return
value
;
return
value
;
}
}
std
::
string
name
()
const
override
{
return
"SymbolTransformation"
;
}
std
::
string
name
()
const
override
{
return
"SymbolTransformation"
;
}
const
Type
<
SymbolValue
>&
value_type
()
const
{
return
m_value_type
;
}
};
};
}
// namespace mgb::imperative
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/transformations/trace.h
浏览文件 @
177001d5
...
@@ -100,22 +100,15 @@ public:
...
@@ -100,22 +100,15 @@ public:
}
}
};
};
class
Tracing
Info
{
class
Tracing
Value
final
:
public
ObjectValue
<
TracingValue
>
{
private:
private:
ValueRef
m_value
=
{};
ValueRef
m_value
=
{};
size_t
m_id
=
0
;
size_t
m_id
=
0
;
public:
public:
TracingInfo
()
=
default
;
TracingValue
(
ValueRef
value
,
size_t
id
)
:
m_value
(
value
),
m_id
(
id
)
{}
TracingInfo
(
ValueRef
value
,
size_t
id
)
:
m_value
(
value
),
m_id
(
id
)
{}
ValueRef
value
()
const
{
return
m_value
;
}
ValueRef
value
()
const
{
return
m_value
;
}
size_t
id
()
const
{
return
m_id
;
}
size_t
id
()
const
{
return
m_id
;
}
};
class
TracingValue
final
:
public
MixinValueImpl
<
TracingValue
,
ValueKind
::
Object
,
TracingInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
std
::
string
to_string
()
const
override
{
return
ssprintf
(
return
ssprintf
(
...
@@ -126,6 +119,8 @@ public:
...
@@ -126,6 +119,8 @@ public:
void
on_watch
()
override
{
value
().
watch
();
}
void
on_watch
()
override
{
value
().
watch
();
}
void
on_unwatch
()
override
{
value
().
unwatch
();
}
void
on_unwatch
()
override
{
value
().
unwatch
();
}
void
clear
()
override
{
m_value
=
{};
}
};
};
/**
/**
...
@@ -146,6 +141,7 @@ private:
...
@@ -146,6 +141,7 @@ private:
std
::
vector
<
TracingValue
::
weak_ref_t
>
m_weak_vars
;
std
::
vector
<
TracingValue
::
weak_ref_t
>
m_weak_vars
;
bool
m_capture_as_const
=
false
;
bool
m_capture_as_const
=
false
;
bool
m_record_input_shapes
=
false
;
bool
m_record_input_shapes
=
false
;
ObjectType
<
TracingValue
>
m_value_type
{
"TracingValue"
};
public:
public:
TracingTransformation
(
bool
capture_as_const
,
bool
record_input_shapes
)
TracingTransformation
(
bool
capture_as_const
,
bool
record_input_shapes
)
...
@@ -162,7 +158,7 @@ public:
...
@@ -162,7 +158,7 @@ public:
*/
*/
TypedValueRef
<
TracingValue
>
record_var
(
ValueRef
value
,
bool
capture
,
VarKind
kind
)
{
TypedValueRef
<
TracingValue
>
record_var
(
ValueRef
value
,
bool
capture
,
VarKind
kind
)
{
size_t
id
=
m_vars
.
size
();
size_t
id
=
m_vars
.
size
();
auto
wrapped_value
=
TracingValue
::
make
(
value
,
id
);
auto
wrapped_value
=
m_value_type
.
make
(
value
,
id
);
m_vars
.
push_back
({
id
,
value
.
dtype
(),
value
.
device
()});
m_vars
.
push_back
({
id
,
value
.
dtype
(),
value
.
device
()});
auto
&
var
=
m_vars
.
back
();
auto
&
var
=
m_vars
.
back
();
if
(
capture
)
{
if
(
capture
)
{
...
@@ -179,7 +175,7 @@ public:
...
@@ -179,7 +175,7 @@ public:
return
wrapped_value
;
return
wrapped_value
;
}
}
ValueRef
unwrap_var
(
ValueRef
value
)
{
ValueRef
unwrap_var
(
ValueRef
value
)
{
if
(
auto
*
tracing_value
=
value
.
as
<
TracingValue
>
(
))
{
if
(
auto
*
tracing_value
=
value
.
as
(
m_value_type
))
{
return
tracing_value
->
value
();
return
tracing_value
->
value
();
}
}
return
value
;
return
value
;
...
@@ -189,7 +185,7 @@ public:
...
@@ -189,7 +185,7 @@ public:
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
ValueRef
unwrap
(
ValueRef
value
)
override
{
if
(
auto
*
tracing_value
=
value
.
as
<
TracingValue
>
(
))
{
if
(
auto
*
tracing_value
=
value
.
as
(
m_value_type
))
{
return
tracing_value
->
value
();
return
tracing_value
->
value
();
}
}
return
value
;
return
value
;
...
@@ -234,7 +230,7 @@ public:
...
@@ -234,7 +230,7 @@ public:
std
::
function
<
void
(
std
::
exception_ptr
)
>
exc_setter
;
std
::
function
<
void
(
std
::
exception_ptr
)
>
exc_setter
;
};
};
class
Traced
Info
{
class
Traced
Value
final
:
public
ObjectValue
<
TracedValue
>
{
private:
private:
size_t
m_id
=
0
;
size_t
m_id
=
0
;
VarInfo
*
m_var
=
nullptr
;
VarInfo
*
m_var
=
nullptr
;
...
@@ -244,8 +240,7 @@ public:
...
@@ -244,8 +240,7 @@ public:
mutable
CompNodeValue
::
ref_t
m_comp_node
;
mutable
CompNodeValue
::
ref_t
m_comp_node
;
public:
public:
TracedInfo
()
=
default
;
TracedValue
(
size_t
id
,
VarInfo
*
var
,
VarAccessor
*
accessor
)
TracedInfo
(
size_t
id
,
VarInfo
*
var
,
VarAccessor
*
accessor
)
:
m_id
(
id
),
m_var
(
var
),
m_accessor
(
accessor
)
{}
:
m_id
(
id
),
m_var
(
var
),
m_accessor
(
accessor
)
{}
size_t
id
()
const
{
return
m_id
;
}
size_t
id
()
const
{
return
m_id
;
}
ShapeValue
::
ref_t
shape
()
const
;
ShapeValue
::
ref_t
shape
()
const
;
...
@@ -256,16 +251,12 @@ public:
...
@@ -256,16 +251,12 @@ public:
void
set_exception
(
std
::
exception_ptr
exc
)
const
{
void
set_exception
(
std
::
exception_ptr
exc
)
const
{
m_accessor
->
exc_setter
(
exc
);
m_accessor
->
exc_setter
(
exc
);
}
}
};
class
TracedValue
final
:
public
MixinValueImpl
<
TracedValue
,
ValueKind
::
Object
,
TracedInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"TracedValue{
\"
id
\"
=%zu}"
,
id
());
return
ssprintf
(
"TracedValue{
\"
id
\"
=%zu}"
,
id
());
}
}
void
clear
()
override
{}
};
};
private:
private:
...
@@ -280,9 +271,12 @@ private:
...
@@ -280,9 +271,12 @@ private:
std
::
function
<
bool
(
ValueRef
,
ValueRef
)
>
m_value_comparator
;
std
::
function
<
bool
(
ValueRef
,
ValueRef
)
>
m_value_comparator
;
bool
m_input_shape_static
;
bool
m_input_shape_static
;
std
::
mutex
m_mutex
;
std
::
mutex
m_mutex
;
std
::
condition_variable
m_cv
;
std
::
exception_ptr
m_graph_exc
;
std
::
exception_ptr
m_graph_exc
;
int
m_graph_status
=
0
;
// 0 = stop, 1 = running, 2 = finalizing
std
::
vector
<
std
::
shared_ptr
<
BoxBase
>>
m_boxes
;
std
::
vector
<
std
::
shared_ptr
<
BoxBase
>>
m_boxes
;
ComputingGraph
::
OutputSpec
m_output_spec
;
ComputingGraph
::
OutputSpec
m_output_spec
;
ObjectType
<
TracedValue
>
m_value_type
{
"TracedValue"
};
public:
public:
CompiledTransformation
(
TraceResult
result
,
bool
input_shape_static
)
CompiledTransformation
(
TraceResult
result
,
bool
input_shape_static
)
...
@@ -292,6 +286,27 @@ public:
...
@@ -292,6 +286,27 @@ public:
m_graph
=
ComputingGraph
::
make
();
m_graph
=
ComputingGraph
::
make
();
options
().
no_force_inplace
=
true
;
options
().
no_force_inplace
=
true
;
options
().
async_exec_level
=
0b100
;
options
().
async_exec_level
=
0b100
;
m_graph_executor
=
std
::
thread
([
&
]
{
while
(
true
)
{
std
::
unique_lock
lock
{
m_mutex
};
m_cv
.
wait
(
lock
,
[
&
]
{
return
m_graph_status
!=
0
;
});
lock
.
unlock
();
if
(
m_graph_status
==
2
)
{
break
;
}
try
{
m_executable
->
execute
();
m_executable
->
wait
();
}
catch
(...)
{
auto
exc
=
std
::
current_exception
();
set_exception
(
exc
);
}
lock
.
lock
();
m_graph_status
=
0
;
lock
.
unlock
();
m_cv
.
notify_all
();
}
});
}
}
ComputingGraph
&
graph
()
{
return
*
m_graph
;
}
ComputingGraph
&
graph
()
{
return
*
m_graph
;
}
...
@@ -350,7 +365,7 @@ public:
...
@@ -350,7 +365,7 @@ public:
void
on_unregister
()
noexcept
override
;
void
on_unregister
()
noexcept
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
<
TracedValue
>
(
));
mgb_assert
(
!
value
.
is
(
m_value_type
));
return
value
;
return
value
;
}
}
...
@@ -368,6 +383,15 @@ public:
...
@@ -368,6 +383,15 @@ public:
m_boxes
.
push_back
(
box
);
m_boxes
.
push_back
(
box
);
return
box
;
return
box
;
}
}
~
CompiledTransformation
()
{
{
MGB_LOCK_GUARD
(
m_mutex
);
m_graph_status
=
2
;
}
m_cv
.
notify_all
();
m_graph_executor
.
join
();
}
};
};
}
// namespace mgb::imperative
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/utils/allocator.h
浏览文件 @
177001d5
...
@@ -11,7 +11,9 @@
...
@@ -11,7 +11,9 @@
#pragma once
#pragma once
#include <optional>
#include <typeindex>
#include <typeindex>
#include <vector>
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/metahelper.h"
...
...
imperative/src/include/megbrain/imperative/utils/span.h
浏览文件 @
177001d5
...
@@ -34,7 +34,7 @@ public:
...
@@ -34,7 +34,7 @@ public:
Span
(
const
T
*
begin
,
const
T
*
end
)
:
m_begin
{
begin
},
m_end
{
end
}
{}
Span
(
const
T
*
begin
,
const
T
*
end
)
:
m_begin
{
begin
},
m_end
{
end
}
{}
Span
(
const
T
*
begin
,
size_t
size
)
:
Span
(
begin
,
begin
+
size
)
{}
Span
(
const
T
*
begin
,
size_t
size
)
:
Span
(
begin
,
begin
+
size
)
{}
template
<
typename
TContainer
>
template
<
typename
TContainer
>
Span
(
TContainer
&
container
)
:
Span
(
container
.
data
(),
container
.
size
())
{}
Span
(
const
TContainer
&
container
)
:
Span
(
container
.
data
(),
container
.
size
())
{}
const
T
*
begin
()
const
{
return
m_begin
;
}
const
T
*
begin
()
const
{
return
m_begin
;
}
const
T
*
end
()
const
{
return
m_end
;
}
const
T
*
end
()
const
{
return
m_end
;
}
const
T
*
data
()
const
{
return
m_begin
;
}
const
T
*
data
()
const
{
return
m_begin
;
}
...
...
imperative/src/include/megbrain/imperative/utils/stats.h
浏览文件 @
177001d5
...
@@ -2,7 +2,10 @@
...
@@ -2,7 +2,10 @@
#include <chrono>
#include <chrono>
#include <iostream>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <vector>
#include <vector>
namespace
mgb
{
namespace
mgb
{
...
@@ -18,7 +21,7 @@ public:
...
@@ -18,7 +21,7 @@ public:
private:
private:
clock_t
::
duration
m_duration
=
clock_t
::
duration
{
0
};
clock_t
::
duration
m_duration
=
clock_t
::
duration
{
0
};
size_t
m_timing
=
0
;
size_t
m_timing
=
0
;
const
char
*
m_name
=
nullptr
;
std
::
string
m_name
;
uint64_t
m_count
=
0
;
uint64_t
m_count
=
0
;
size_t
m_enabled
=
1
;
size_t
m_enabled
=
1
;
bool
m_default_enabled
=
true
;
bool
m_default_enabled
=
true
;
...
@@ -42,7 +45,8 @@ private:
...
@@ -42,7 +45,8 @@ private:
}
}
if
(
timer
.
m_enabled
)
{
if
(
timer
.
m_enabled
)
{
if
(
!--
timer
.
m_timing
)
{
if
(
!--
timer
.
m_timing
)
{
timer
.
m_duration
+=
(
clock_t
::
now
()
-
start
);
auto
duration
=
(
clock_t
::
now
()
-
start
);
timer
.
m_duration
+=
duration
;
}
}
timer
.
m_count
++
;
timer
.
m_count
++
;
}
}
...
@@ -67,13 +71,10 @@ private:
...
@@ -67,13 +71,10 @@ private:
}
}
};
};
using
TimeScope
=
TimeScopeRecursive
;
public:
public:
Timer
(
const
char
*
name
,
bool
default_enabled
);
Timer
(
std
::
string
name
,
bool
default_enabled
=
true
);
const
char
*
name
()
{
return
m_name
;
}
std
::
string
name
()
{
return
m_name
;
}
auto
time_scope
()
{
return
TimeScope
(
*
this
);
}
auto
time_scope_recursive
()
{
return
TimeScopeRecursive
(
*
this
);
};
auto
time_scope_recursive
()
{
return
TimeScopeRecursive
(
*
this
);
};
auto
enable_scope
()
{
return
EnableScope
(
*
this
);
}
auto
enable_scope
()
{
return
EnableScope
(
*
this
);
}
void
reset
()
{
void
reset
()
{
...
@@ -88,7 +89,14 @@ public:
...
@@ -88,7 +89,14 @@ public:
}
// namespace stats
}
// namespace stats
struct
Stats
{
struct
Stats
{
static
inline
std
::
vector
<
stats
::
Timer
*>
sm_timers
;
struct
TimerNode
{
std
::
map
<
std
::
string
,
std
::
unique_ptr
<
TimerNode
>>
children
;
stats
::
Timer
*
timer
=
nullptr
;
TimerNode
()
{}
};
static
inline
TimerNode
sm_root
;
// register your timers here
// register your timers here
// for example:
// for example:
...
@@ -97,33 +105,84 @@ struct Stats {
...
@@ -97,33 +105,84 @@ struct Stats {
//
//
// then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code
// then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code
static
void
print
()
{
static
std
::
pair
<
long
,
long
>
print_node
(
std
::
vector
<
const
char
*>
unused_timers
;
std
::
string
name
,
TimerNode
&
node
,
size_t
indent
=
0
)
{
auto
print_indent
=
[
&
]
{
for
(
auto
*
timer
:
sm_timers
)
{
for
(
size_t
i
=
0
;
i
<
indent
;
++
i
)
{
if
(
timer
->
count
()
==
0
)
{
printf
(
" "
);
unused_timers
.
push_back
(
timer
->
name
());
}
}
else
{
};
printf
(
"%s costs %ld ns, happens %ld times
\n
"
,
timer
->
name
(),
long
ns
=
0
,
count
=
0
;
timer
->
get
().
count
(),
timer
->
count
());
if
(
auto
*
timer
=
node
.
timer
)
{
print_indent
();
printf
(
"%s costs %'ld ns, hits %'ld times
\n
"
,
name
.
c_str
(),
(
long
)
timer
->
get
().
count
(),
(
long
)
timer
->
count
());
ns
=
timer
->
get
().
count
();
count
=
timer
->
count
();
}
if
(
!
node
.
children
.
empty
())
{
bool
collect_children
=
node
.
timer
==
nullptr
;
if
(
collect_children
)
{
print_indent
();
printf
(
"%s:
\n
"
,
name
.
c_str
());
}
long
ns
=
0
,
count
=
0
;
for
(
auto
&&
child
:
node
.
children
)
{
auto
[
child_ns
,
child_count
]
=
print_node
(
child
.
first
,
*
child
.
second
,
indent
+
4
);
if
(
collect_children
)
{
ns
+=
child_ns
;
count
+=
child_count
;
}
}
if
(
collect_children
)
{
print_indent
();
printf
(
"total costs %'ld ns, hits %'ld times
\n
"
,
ns
,
count
);
}
}
}
}
return
{
ns
,
count
};
}
if
(
!
unused_timers
.
empty
())
{
static
void
print
()
{
printf
(
"%zu timers unused
\n
"
,
unused_timers
.
size
());
for
(
auto
&&
child
:
sm_root
.
children
)
{
print_node
(
child
.
first
,
*
child
.
second
);
}
}
}
}
static
void
reset
()
{
static
void
reset
()
{
for
(
auto
*
timer
:
sm_timers
)
{
auto
reset_node
=
[](
TimerNode
&
node
,
auto
&&
reset_node
)
->
void
{
timer
->
reset
();
if
(
auto
*
timer
=
node
.
timer
)
{
}
timer
->
reset
();
}
for
(
auto
&&
child
:
node
.
children
)
{
reset_node
(
*
child
.
second
,
reset_node
);
}
};
reset_node
(
sm_root
,
reset_node
);
}
}
};
};
inline
stats
::
Timer
::
Timer
(
const
char
*
name
,
bool
default_enabled
)
inline
stats
::
Timer
::
Timer
(
std
::
string
name
,
bool
default_enabled
)
:
m_name
(
name
),
m_default_enabled
(
default_enabled
)
{
:
m_name
(
name
),
m_default_enabled
(
default_enabled
)
{
Stats
::
sm_timers
.
push_back
(
this
);
std
::
vector
<
std
::
string
>
terms
;
Stats
::
TimerNode
*
node
=
&
Stats
::
sm_root
;
while
(
true
)
{
auto
pos
=
name
.
find
(
"."
);
if
(
pos
==
std
::
string
::
npos
)
{
auto
&
child
=
node
->
children
[
name
];
child
=
std
::
make_unique
<
Stats
::
TimerNode
>
();
node
=
child
.
get
();
node
->
timer
=
this
;
break
;
}
else
{
auto
&
child
=
node
->
children
[
name
.
substr
(
0
,
pos
)];
if
(
!
child
)
{
child
=
std
::
make_unique
<
Stats
::
TimerNode
>
();
}
node
=
child
.
get
();
name
=
name
.
substr
(
pos
+
1
);
}
}
}
}
#if MGE_ENABLE_STATS
#if MGE_ENABLE_STATS
...
...
imperative/src/include/megbrain/imperative/value.h
浏览文件 @
177001d5
...
@@ -50,18 +50,70 @@ class Operator;
...
@@ -50,18 +50,70 @@ class Operator;
class
ValueRefList
;
class
ValueRefList
;
/**
* \brief base class of all value types
*/
class
IType
:
public
NonCopyableObj
{
private:
std
::
string
m_name
;
// TODO: count values, or make an linkedlist
public:
IType
(
std
::
string
name
)
:
m_name
(
std
::
move
(
name
))
{}
const
std
::
string
&
name
()
const
{
return
m_name
;
}
bool
operator
==
(
const
IType
&
rhs
)
const
{
return
this
==
&
rhs
;
}
bool
operator
!=
(
const
IType
&
rhs
)
const
{
return
this
!=
&
rhs
;
}
};
/**
* \brief type of values.
*
* \tparam T ctype of value
*/
template
<
typename
T
>
class
Type
:
public
IType
{
protected:
Type
(
std
::
string
name
)
:
IType
(
std
::
move
(
name
))
{}
// TODO: each type owns an allocator
public:
/**
* \brief helper function for construct a value
*
* \tparam TArgs types of arguments
* \param args arguments
* \return TypedValueRef<T> reference of value
*/
template
<
typename
...
TArgs
>
TypedValueRef
<
T
>
make
(
TArgs
&&
...
args
)
const
;
};
/**
* \brief type of primitive values.
*
* \tparam T ctype of value
*/
template
<
typename
T
>
template
<
typename
T
>
class
Type
{
class
PrimitiveType
:
public
Type
<
T
>
{
private:
private:
const
size_t
m_code
=
T
::
TYPE_CODE
;
PrimitiveType
()
;
public:
public:
inline
size_t
code
()
const
{
return
m_code
;
}
static
inline
PrimitiveType
instance
;
};
};
enum
class
ValueKind
{
/**
Primitive
,
* \brief type of object values.
Object
,
*
* \tparam T ctype of value
*/
template
<
typename
T
>
class
ObjectType
:
public
Type
<
T
>
{
public:
ObjectType
(
std
::
string
name
)
:
Type
<
T
>
(
name
)
{}
};
};
/**
/**
...
@@ -71,9 +123,8 @@ enum class ValueKind {
...
@@ -71,9 +123,8 @@ enum class ValueKind {
* and only the tail node is valid. ValueRef stores a value node, and it may be
* and only the tail node is valid. ValueRef stores a value node, and it may be
* an invalid internal node. When you dereference it, it will check its successor,
* an invalid internal node. When you dereference it, it will check its successor,
* automatically find the tail node and return. This list would be modified to reduce
* automatically find the tail node and return. This list would be modified to reduce
* list length by change value's successor, but a ValueRef always has steady m_storage
* list length by change value's successor, but a steady id was kept in ValueRef
* when not explicitly modified.
* so we can use it for identify a ValueRef ( hash / equility / id ).
* So we use m_storage to identify a ValueRef ( hash / equility / id ).
*/
*/
class
ValueRef
{
class
ValueRef
{
public:
public:
...
@@ -93,9 +144,7 @@ private:
...
@@ -93,9 +144,7 @@ private:
*/
*/
storage_t
&
storage
()
const
;
storage_t
&
storage
()
const
;
const
Value
*
as
(
size_t
typecode
)
const
;
const
Value
*
as
(
const
IType
&
type
)
const
;
bool
is
(
size_t
typecode
)
const
;
public:
public:
ValueRef
()
=
default
;
ValueRef
()
=
default
;
...
@@ -103,45 +152,76 @@ public:
...
@@ -103,45 +152,76 @@ public:
/**
/**
* \brief whether value is instance of target type or not
* \brief whether value is instance of target type or not
*
*
* \
tparam TValu
e target type
* \
param typ
e target type
* \return true if type of value is
TValu
e
* \return true if type of value is
instance of typ
e
* \return false if empty or type of value is not
TValu
e
* \return false if empty or type of value is not
instance of typ
e
*/
*/
template
<
typename
TValue
>
bool
is
(
const
IType
&
type
)
const
;
inline
bool
is
(
Type
<
TValue
>
type
=
{})
const
;
/**
/**
* \brief try cast value as target type
* \brief try cast value as target type
*
*
* \tparam
TValu
e target type
* \tparam
typ
e target type
* \return TValue* raw pointer if success, otherwise nullptr
* \return TValue* raw pointer if success, otherwise nullptr
*/
*/
template
<
typename
TValue
>
template
<
typename
TValue
>
inline
const
TValue
*
as
(
Type
<
TValue
>
type
=
{}
)
const
;
inline
const
TValue
*
as
(
const
Type
<
TValue
>&
type
)
const
;
/**
/**
* \brief cast value to target type
* \brief cast value to target type
*
*
* \
tparam TValu
e target type
* \
param typ
e target type
* \return TValue& reference of value
* \return TValue& reference of value
*/
*/
template
<
typename
TValue
>
template
<
typename
TValue
>
inline
const
TValue
&
cast
(
Type
<
TValue
>
type
=
{}
)
const
;
inline
const
TValue
&
cast
(
const
Type
<
TValue
>&
type
)
const
;
/**
/**
* \brief like as(), but returns TypedValueRef instead
* \brief like as(), but returns TypedValueRef instead
*
*
* \
tparam TValu
e target type
* \
param typ
e target type
* \return TypedValueRef<TValue> reference if success, otherwise empty reference
* \return TypedValueRef<TValue> reference if success, otherwise empty reference
*/
*/
template
<
typename
TValue
>
template
<
typename
TValue
>
inline
const
TypedValueRef
<
TValue
>&
as_ref
(
Type
<
TValue
>
type
=
{})
const
;
inline
const
TypedValueRef
<
TValue
>&
as_ref
(
const
Type
<
TValue
>&
type
)
const
;
/**
* \brief like cast(), but allow empty value and returns TypedValueRef instead
*
* \param type target type
* \return TypedValueRef<TValue> reference if success, otherwise empty reference
*/
template
<
typename
TValue
>
inline
const
TypedValueRef
<
TValue
>&
cast_ref
(
const
Type
<
TValue
>&
type
)
const
;
template
<
typename
TValue
>
inline
std
::
enable_if_t
<
TValue
::
is_primitive
,
bool
>
is
()
const
{
return
is
(
PrimitiveType
<
TValue
>::
instance
);
}
template
<
typename
TValue
>
inline
std
::
enable_if_t
<
TValue
::
is_primitive
,
const
TValue
*>
as
()
const
{
return
as
(
PrimitiveType
<
TValue
>::
instance
);
}
template
<
typename
TValue
>
inline
std
::
enable_if_t
<
TValue
::
is_primitive
,
const
TValue
&>
cast
()
const
{
return
cast
(
PrimitiveType
<
TValue
>::
instance
);
}
template
<
typename
TValue
>
template
<
typename
TValue
>
inline
const
TypedValueRef
<
TValue
>&
cast_ref
(
Type
<
TValue
>
type
=
{})
const
;
inline
std
::
enable_if_t
<
TValue
::
is_primitive
,
const
TypedValueRef
<
TValue
>&>
as_ref
()
const
{
return
as_ref
(
PrimitiveType
<
TValue
>::
instance
);
}
template
<
typename
TValue
>
template
<
typename
TValue
>
void
on_cast_failure
()
const
;
inline
std
::
enable_if_t
<
TValue
::
is_primitive
,
const
TypedValueRef
<
TValue
>&>
cast_ref
()
const
{
return
cast_ref
(
PrimitiveType
<
TValue
>::
instance
);
}
void
on_cast_failure
(
const
IType
&
type
)
const
;
operator
bool
()
const
{
return
bool
(
m_storage
);
}
operator
bool
()
const
{
return
bool
(
m_storage
);
}
...
@@ -172,8 +252,6 @@ public:
...
@@ -172,8 +252,6 @@ public:
friend
class
ValueWeakRef
;
friend
class
ValueWeakRef
;
template
<
typename
>
template
<
typename
>
friend
class
TypedValueRef
;
friend
class
TypedValueRef
;
template
<
typename
,
ValueKind
>
friend
class
ValueImpl
;
friend
ValueRefList
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
);
friend
ValueRefList
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
);
};
};
...
@@ -195,7 +273,8 @@ protected:
...
@@ -195,7 +273,8 @@ protected:
public:
public:
ValueWeakRef
()
=
default
;
ValueWeakRef
()
=
default
;
ValueWeakRef
(
ValueRef
value
)
:
m_id
(
value
.
id
()),
m_storage
(
value
.
m_storage
)
{}
ValueWeakRef
(
const
ValueRef
&
value
)
:
m_id
(
value
.
id
()),
m_storage
(
value
.
m_storage
)
{}
/**
/**
* \brief try promote to ValueRef
* \brief try promote to ValueRef
...
@@ -218,19 +297,15 @@ public:
...
@@ -218,19 +297,15 @@ public:
class
Value
:
public
NonCopyableObj
{
class
Value
:
public
NonCopyableObj
{
private:
private:
uint64_t
m_id
=
std
::
numeric_limits
<
uint64_t
>::
max
();
uint64_t
m_id
=
std
::
numeric_limits
<
uint64_t
>::
max
();
size_t
m_typecode
=
0
;
const
IType
*
m_type
=
nullptr
;
ValueRef
m_successor
;
ValueRef
m_successor
;
size_t
m_watching
=
0
;
size_t
m_watching
=
0
;
protected:
protected:
Value
(
size_t
typecode
);
Value
();
public:
public:
size_t
typecode
()
const
{
return
m_typecode
;
}
const
IType
&
type
()
const
{
return
*
m_type
;
}
const
std
::
type_index
type
()
const
{
return
registered_types
()[
m_typecode
];
}
static
size_t
register_type
(
std
::
type_index
type
);
static
const
std
::
vector
<
std
::
type_index
>&
registered_types
();
static
void
register_value
(
ValueRef
value
);
static
void
register_value
(
ValueRef
value
);
static
ValueRef
get_value_by_id
(
uint64_t
id
);
static
ValueRef
get_value_by_id
(
uint64_t
id
);
...
@@ -251,11 +326,12 @@ public:
...
@@ -251,11 +326,12 @@ public:
friend
class
ValueRef
;
friend
class
ValueRef
;
friend
class
ValueWeakRef
;
friend
class
ValueWeakRef
;
template
<
typename
,
ValueKind
>
friend
class
ValueImpl
;
template
<
typename
T
>
template
<
typename
T
>
friend
class
TypedValueRef
;
friend
class
TypedValueRef
;
template
<
typename
T
>
friend
class
Type
;
~
Value
();
~
Value
();
private:
private:
...
@@ -267,30 +343,17 @@ private:
...
@@ -267,30 +343,17 @@ private:
*
*
* \tparam T type of value
* \tparam T type of value
*/
*/
template
<
typename
T
,
ValueKind
Kind
>
template
<
typename
T
>
class
ValueImpl
:
public
Value
{
class
ObjectValue
:
public
Value
{
protected:
protected:
ValueImpl
()
:
Value
(
TYPE_CODE
)
{}
ObjectValue
(
)
{}
public:
public:
using
ref_t
=
TypedValueRef
<
T
>
;
using
ref_t
=
TypedValueRef
<
T
>
;
using
weak_ref_t
=
TypedValueWeakRef
<
T
>
;
using
weak_ref_t
=
TypedValueWeakRef
<
T
>
;
static
inline
const
size_t
TYPE_CODE
=
[]
{
return
register_type
(
typeid
(
T
));
}();
static
constexpr
bool
is_primitive
=
false
;
static
constexpr
ValueKind
KIND
=
Kind
;
static
constexpr
bool
is_object
=
true
;
/**
* \brief helper function for construct a value
*
* \tparam TArgs types of arguments
* \param args arguments
* \return TypedValueRef<T> reference of value
*/
template
<
typename
...
TArgs
>
static
MGB_NOINLINE
TypedValueRef
<
T
>
make
(
TArgs
&&
...
args
)
{
static_assert
(
std
::
is_final_v
<
T
>
);
return
ValueRef
::
make
(
LocalPtr
<
Value
>::
make
<
T
>
(
std
::
forward
<
TArgs
&&>
(
args
)...));
}
};
};
/**
/**
...
@@ -299,74 +362,89 @@ public:
...
@@ -299,74 +362,89 @@ public:
* \tparam T type of value
* \tparam T type of value
* \tparam TMixin type of mixin class
* \tparam TMixin type of mixin class
*/
*/
template
<
typename
T
,
ValueKind
Kind
,
typename
TMixin
>
template
<
typename
T
,
typename
TMixin
>
class
MixinValueImpl
:
public
ValueImpl
<
T
,
Kind
>
,
public
TMixin
{
class
PrimitiveValue
:
public
Value
,
public
TMixin
{
public:
public:
using
ref_t
=
TypedValueRef
<
T
>
;
using
weak_ref_t
=
TypedValueWeakRef
<
T
>
;
using
TMixin
::
TMixin
;
using
TMixin
::
TMixin
;
MixinValueImpl
(
TMixin
mixin
)
:
TMixin
(
std
::
move
(
mixin
))
{}
PrimitiveValue
(
TMixin
&&
mixin
)
:
TMixin
(
std
::
move
(
mixin
))
{}
PrimitiveValue
(
const
TMixin
&
mixin
)
:
TMixin
(
mixin
)
{}
public:
public:
void
clear
()
override
final
{
((
TMixin
&
)
*
this
)
=
{};
}
void
clear
()
override
final
{
((
TMixin
&
)
*
this
)
=
{};
}
bool
eq
(
const
TMixin
&
value
)
const
{
return
((
const
TMixin
&
)
*
this
)
==
value
;
}
bool
eq
(
const
TMixin
&
value
)
const
{
return
((
const
TMixin
&
)
*
this
)
==
value
;
}
/**
* \brief helper function for construct a value
*
* \tparam TArgs types of arguments
* \param args arguments
* \return TypedValueRef<T> reference of value
*/
template
<
typename
...
TArgs
>
static
TypedValueRef
<
T
>
make
(
TArgs
&&
...
args
)
{
return
PrimitiveType
<
T
>::
instance
.
make
(
std
::
forward
<
TArgs
&&>
(
args
)...);
}
static
constexpr
bool
is_primitive
=
true
;
static
constexpr
bool
is_object
=
false
;
};
};
template
<
typename
T
>
PrimitiveType
<
T
>::
PrimitiveType
()
:
Type
<
T
>
(
typeid
(
T
).
name
())
{
static_assert
(
std
::
is_base_of_v
<
Value
,
T
>
);
static_assert
(
!
std
::
is_base_of_v
<
ObjectValue
<
T
>
,
T
>
);
}
inline
ValueRef
::
ValueRef
(
storage_t
storage
)
{
inline
ValueRef
::
ValueRef
(
storage_t
storage
)
{
// mgb_assert(storage);
m_storage
=
storage
;
m_storage
=
storage
;
m_id
=
m_storage
->
m_id
;
m_id
=
m_storage
->
m_id
;
}
}
template
<
typename
TValue
>
template
<
typename
TValue
>
inline
const
TValue
*
ValueRef
::
as
(
Type
<
TValue
>
type
)
const
{
inline
const
TValue
*
ValueRef
::
as
(
const
Type
<
TValue
>&
type
)
const
{
// auto _ = Stats::time_value_as.time_scope();
static_assert
(
std
::
is_base_of_v
<
Value
,
TValue
>
);
static_assert
(
std
::
is_base_of_v
<
Value
,
TValue
>
);
return
static_cast
<
const
TValue
*>
(
as
(
type
.
code
()
));
return
static_cast
<
const
TValue
*>
(
as
(
(
const
IType
&
)
type
));
}
}
template
<
typename
TValue
>
template
<
typename
TValue
>
inline
const
TValue
&
ValueRef
::
cast
(
Type
<
TValue
>
type
)
const
{
inline
const
TValue
&
ValueRef
::
cast
(
const
Type
<
TValue
>&
type
)
const
{
// auto _ = Stats::time_value_cast.time_scope();
auto
*
ptr
=
as
<
TValue
>
(
type
);
auto
*
ptr
=
as
<
TValue
>
(
type
);
if
(
mgb_unlikely
(
!
ptr
))
{
if
(
mgb_unlikely
(
!
ptr
))
{
on_cast_failure
<
TValue
>
(
);
on_cast_failure
(
type
);
}
}
return
static_cast
<
const
TValue
&>
(
*
ptr
);
return
static_cast
<
const
TValue
&>
(
*
ptr
);
}
}
template
<
typename
TValue
>
template
<
typename
TValue
>
inline
bool
ValueRef
::
is
(
Type
<
TValue
>
type
)
const
{
inline
const
TypedValueRef
<
TValue
>&
ValueRef
::
as_ref
(
const
Type
<
TValue
>&
type
)
const
{
// auto _ = Stats::time_value_is.time_scope();
if
(
!
is
(
type
))
{
return
is
(
type
.
code
());
}
template
<
typename
TValue
>
inline
const
TypedValueRef
<
TValue
>&
ValueRef
::
as_ref
(
Type
<
TValue
>
type
)
const
{
if
(
!
is
<
TValue
>
(
type
))
{
return
TypedValueRef
<
TValue
>::
nil
;
return
TypedValueRef
<
TValue
>::
nil
;
}
}
return
*
reinterpret_cast
<
const
TypedValueRef
<
TValue
>*>
(
this
);
return
*
reinterpret_cast
<
const
TypedValueRef
<
TValue
>*>
(
this
);
}
}
template
<
typename
TValue
>
template
<
typename
TValue
>
inline
const
TypedValueRef
<
TValue
>&
ValueRef
::
cast_ref
(
Type
<
TValue
>
type
)
const
{
inline
const
TypedValueRef
<
TValue
>&
ValueRef
::
cast_ref
(
const
Type
<
TValue
>&
type
)
const
{
if
(
!
m_storage
)
{
if
(
!
m_storage
)
{
return
TypedValueRef
<
TValue
>::
nil
;
return
TypedValueRef
<
TValue
>::
nil
;
}
}
if
(
mgb_unlikely
(
!
is
<
TValue
>
(
type
)))
{
if
(
mgb_unlikely
(
!
is
(
type
)))
{
on_cast_failure
<
TValue
>
(
);
on_cast_failure
(
type
);
}
}
return
*
reinterpret_cast
<
const
TypedValueRef
<
TValue
>*>
(
this
);
return
*
reinterpret_cast
<
const
TypedValueRef
<
TValue
>*>
(
this
);
}
}
template
<
typename
TValue
>
inline
void
ValueRef
::
on_cast_failure
(
const
IType
&
type
)
const
{
void
ValueRef
::
on_cast_failure
()
const
{
// if this is ErrorValue, rethrow directly
// if this is ErrorValue, rethrow directly
storage
()
->
try_rethrow
();
storage
()
->
try_rethrow
();
mgb_assert
(
mgb_assert
(
storage
()
->
m_typecode
!=
TValue
::
TYPE_CODE
,
"expect type %s, got %s"
,
storage
()
->
type
()
!=
type
,
"expect type %s, got %s"
,
type
.
name
().
c_str
()
,
t
ypeid
(
TValue
).
name
(),
t
o_string
().
c_str
());
to_string
().
c_str
());
}
}
/**
/**
...
@@ -382,26 +460,10 @@ private:
...
@@ -382,26 +460,10 @@ private:
public:
public:
TypedValueRef
()
=
default
;
TypedValueRef
()
=
default
;
const
T
&
operator
*
()
const
{
const
T
&
operator
*
()
const
{
if
constexpr
(
T
::
KIND
==
ValueKind
::
Object
)
{
mgb_assert
(
m_storage
,
"empty storage"
);
return
this
->
template
cast
<
T
>();
return
static_cast
<
const
T
&>
(
*
m_storage
);
}
else
if
constexpr
(
T
::
KIND
==
ValueKind
::
Primitive
)
{
if
(
!
m_storage
)
{
on_cast_failure
<
T
>
();
}
return
static_cast
<
const
T
&>
(
*
m_storage
);
}
else
{
static_assert
(
!
std
::
is_same_v
<
T
,
T
>
);
}
}
const
T
*
operator
->
()
const
{
if
constexpr
(
T
::
KIND
==
ValueKind
::
Object
)
{
return
this
->
template
as
<
T
>();
}
else
if
constexpr
(
T
::
KIND
==
ValueKind
::
Primitive
)
{
return
static_cast
<
const
T
*>
(
m_storage
.
get
());
}
else
{
static_assert
(
!
std
::
is_same_v
<
T
,
T
>
);
}
}
}
const
T
*
operator
->
()
const
{
return
static_cast
<
const
T
*>
(
m_storage
.
get
());
}
/**
/**
* \brief reset underlying value to another value
* \brief reset underlying value to another value
...
@@ -409,7 +471,7 @@ public:
...
@@ -409,7 +471,7 @@ public:
* \param successor new value
* \param successor new value
*/
*/
inline
void
reset
(
ValueRef
successor
)
{
inline
void
reset
(
ValueRef
successor
)
{
static_assert
(
T
::
KIND
==
ValueKind
::
Object
);
static_assert
(
std
::
is_base_of_v
<
ObjectValue
<
T
>
,
T
>
);
mgb_assert
(
m_storage
);
mgb_assert
(
m_storage
);
mgb_assert
(
!
m_storage
->
m_successor
);
mgb_assert
(
!
m_storage
->
m_successor
);
if
(
m_storage
->
m_watching
)
{
if
(
m_storage
->
m_watching
)
{
...
@@ -422,25 +484,19 @@ public:
...
@@ -422,25 +484,19 @@ public:
static
inline
const
TypedValueRef
nil
;
static
inline
const
TypedValueRef
nil
;
friend
class
ValueRef
;
friend
class
ValueRef
;
friend
class
Type
<
T
>
;
template
<
typename
,
ValueKind
>
friend
class
TypedValueWeakRef
<
T
>
;
friend
class
ValueImpl
;
};
};
template
<
typename
T
>
template
<
typename
T
>
class
TypedValueWeakRef
:
public
ValueWeakRef
{
class
TypedValueWeakRef
:
public
ValueWeakRef
{
private:
private:
TypedValueWeakRef
(
const
ValueRef
&
value
)
:
ValueWeakRef
(
value
)
{}
TypedValueWeakRef
(
const
ValueWeakRef
&
value
)
:
ValueWeakRef
(
value
)
{}
public:
public:
TypedValueWeakRef
(
ValueRef
value
)
:
ValueWeakRef
(
value
)
{}
TypedValueWeakRef
(
const
TypedValueRef
<
T
>&
value
)
:
ValueWeakRef
(
value
)
{}
TypedValueWeakRef
(
ValueWeakRef
value
)
:
ValueWeakRef
(
value
)
{}
TypedValueRef
<
T
>
lock
()
{
return
(
TypedValueRef
<
T
>
)
ValueWeakRef
::
lock
();
}
TypedValueRef
<
T
>
lock
()
{
auto
value
=
ValueWeakRef
::
lock
();
if
(
value
)
{
return
value
.
template
as_ref
<
T
>();
}
else
{
return
{};
}
}
};
};
// TODO: add proxy value type, which is meant to be reset in the end
// TODO: add proxy value type, which is meant to be reset in the end
...
@@ -509,10 +565,14 @@ inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_s
...
@@ -509,10 +565,14 @@ inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_s
m_data
[
0
]
=
std
::
move
(
item
);
m_data
[
0
]
=
std
::
move
(
item
);
}
}
/*class ValueRefList : public SmallVector<ValueRef, 1> {
template
<
typename
T
>
public:
template
<
typename
...
TArgs
>
using SmallVector::SmallVector;
TypedValueRef
<
T
>
Type
<
T
>::
make
(
TArgs
&&
...
args
)
const
{
};*/
static_assert
(
std
::
is_final_v
<
T
>
);
auto
storage
=
LocalPtr
<
Value
>::
make
<
T
>
(
std
::
forward
<
TArgs
&&>
(
args
)...);
storage
->
m_type
=
this
;
return
ValueRef
::
make
(
std
::
move
(
storage
));
}
}
// namespace imperative
}
// namespace imperative
}
// namespace mgb
}
// namespace mgb
...
...
imperative/src/test/backward_graph.cpp
浏览文件 @
177001d5
...
@@ -123,7 +123,7 @@ TEST(TestImperative, BackwardGraphBasic) {
...
@@ -123,7 +123,7 @@ TEST(TestImperative, BackwardGraphBasic) {
}
}
}
}
inputs
.
clear
();
inputs
.
clear
();
auto
input_grads
=
result
.
graph
.
apply
(
auto
input_grads
=
result
.
graph
.
apply
<
TensorPtr
>
(
backward_graph_inputs
,
apply_shared_on_physical_tensor
,
backward_graph_inputs
,
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
)
{
return
x
;
});
[
&
](
auto
&&
x
)
{
return
x
;
});
mgb_assert
(
input_grads
.
size
()
==
input_has_grad
.
size
());
mgb_assert
(
input_grads
.
size
()
==
input_has_grad
.
size
());
...
@@ -177,7 +177,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
...
@@ -177,7 +177,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
}
}
}
}
inputs
.
clear
();
inputs
.
clear
();
auto
input_grads
=
result
.
graph
.
apply
(
auto
input_grads
=
result
.
graph
.
apply
<
TensorPtr
>
(
backward_graph_inputs
,
apply_shared_on_physical_tensor
,
backward_graph_inputs
,
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
)
{
return
x
;
});
[
&
](
auto
&&
x
)
{
return
x
;
});
mgb_assert
(
input_grads
.
size
()
==
input_has_grad
.
size
());
mgb_assert
(
input_grads
.
size
()
==
input_has_grad
.
size
());
...
@@ -244,11 +244,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
...
@@ -244,11 +244,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
bg
,
{
a_tn
,
b_tn
},
{
c_tn
},
{
dc_tn
});
bg
,
{
a_tn
,
b_tn
},
{
c_tn
},
{
dc_tn
});
auto
grads
=
expand_grads
(
auto
grads
=
expand_grads
(
bg
.
output_mask
,
bg
.
output_mask
,
bg
.
graph
.
apply
(
bg
.
graph
.
apply
<
TensorPtr
>
(
backward_graph_inputs
,
apply_shared_on_physical_tensor
,
backward_graph_inputs
,
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
)
{
return
x
;
}));
[
&
](
auto
&&
x
)
{
return
x
;
}));
auto
precomp
=
obg
.
precomp
.
apply
(
auto
precomp
=
obg
.
precomp
.
apply
<
TensorPtr
>
(
SmallVector
<
TensorPtr
>
{
a_tn
,
b_tn
,
c_tn
},
apply_shared_on_physical_tensor
,
SmallVector
<
TensorPtr
>
{
a_tn
,
b_tn
,
c_tn
},
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
)
{
return
x
;
});
[
&
](
auto
&&
x
)
{
return
x
;
});
ASSERT_EQ
(
precomp
.
size
(),
2
);
ASSERT_EQ
(
precomp
.
size
(),
2
);
...
@@ -261,7 +261,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
...
@@ -261,7 +261,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
obg
,
precomp
,
{
a_tn
,
b_tn
},
{
c_tn
},
{
dc_tn
});
obg
,
precomp
,
{
a_tn
,
b_tn
},
{
c_tn
},
{
dc_tn
});
auto
grads2
=
expand_grads
(
auto
grads2
=
expand_grads
(
obg
.
input_has_grad
,
obg
.
input_has_grad
,
obg
.
backward
.
apply
(
obg
.
backward
.
apply
<
TensorPtr
>
(
backward_inputs
,
apply_shared_on_physical_tensor
,
backward_inputs
,
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
)
{
return
x
;
}));
[
&
](
auto
&&
x
)
{
return
x
;
}));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录