Commit 4fa61620
Authored on Jan 27, 2022 by Megvii Engine Team
perf(dispatch): improve performance of dispatch system

GitOrigin-RevId: 860028e1af63936e7b4edefbed90d8244e7cb8d2
Parent: ca001777
Showing 21 changed files with 442 additions and 185 deletions (+442, -185):

imperative/python/src/module_trace.h                                    +1    -0
imperative/python/src/tensor.cpp                                        +79   -42
imperative/python/src/transformation.h                                  +2    -1
imperative/src/impl/dispatch.cpp                                        +1    -5
imperative/src/impl/interpreter/interpreter_impl.cpp                    +1    -0
imperative/src/impl/transformation.cpp                                  +1    -0
imperative/src/impl/transformations/eval.cpp                            +6    -4
imperative/src/impl/transformations/grad.cpp                            +54   -46
imperative/src/impl/transformations/scalar.cpp                          +10   -5
imperative/src/impl/value.cpp                                           +2    -2
imperative/src/include/megbrain/imperative/basic_values.h               +26   -17
imperative/src/include/megbrain/imperative/dispatch.h                   +7    -2
imperative/src/include/megbrain/imperative/graph_cache.h                +5    -0
imperative/src/include/megbrain/imperative/transformations/eval.h       +1    -1
imperative/src/include/megbrain/imperative/transformations/grad.h       +42   -35
imperative/src/include/megbrain/imperative/transformations/lazy.h       +2    -1
imperative/src/include/megbrain/imperative/transformations/scalar.h     +1    -1
imperative/src/include/megbrain/imperative/transformations/symbol.h     +1    -1
imperative/src/include/megbrain/imperative/transformations/trace.h      +4    -2
imperative/src/include/megbrain/imperative/utils/stats.h                +140  -0
imperative/src/include/megbrain/imperative/value.h                      +56   -20
imperative/python/src/module_trace.h

@@ -13,6 +13,7 @@
 #include "megbrain/imperative/transformations/trace.h"
 #include "megbrain/imperative/utils/map.h"
+#include "megbrain/imperative/utils/stats.h"

 #include "./tensor.h"
imperative/python/src/tensor.cpp

@@ -21,6 +21,7 @@
 #include "megbrain/imperative/transformations/symbol.h"
 #include "megbrain/imperative/transformations/trace.h"
 #include "megbrain/imperative/utils/map.h"
+#include "megbrain/imperative/utils/stats.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/plugin/profiler.h"

@@ -52,8 +53,48 @@ namespace mgb::imperative::python {
 namespace {
 WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
+
+struct SymbolVarContext {
+    TransformationContext context;
+    cg::ComputingGraph* graph;
+
+    SymbolVarContext(cg::ComputingGraph* graph) : graph(graph) {
+        Transformation::swap_context(context);
+    }
+
+    void init() {
+        std::make_shared<SymbolTransformation>(graph)->register_at(
+                Transformation::top());
+        std::make_shared<ScalarTransformation>()->register_at(Transformation::top());
+    }
+
+    ~SymbolVarContext() { Transformation::swap_context(context); }
+};
+
+ValueRef symvar2val(py::handle py_symbol_var) {
+    auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
+    ValueRef value = SymbolValue::make(symbol_var->m_node);
+    if (symbol_var->is_scalar) {
+        value = ScalarValue::make(value);
+    }
+    return value;
+}
+
+py::object val2symvar(py::handle typeobj, ValueRef value) {
+    bool is_scalar = false;
+    if (auto* scalar_value = value.as<ScalarValue>()) {
+        value = scalar_value->value();
+        is_scalar = true;
+    }
+    auto* node = value.cast<SymbolValue>().node();
+    auto py_symbol_var =
+            typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
+    py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
+    return py_symbol_var;
+}
+
 }  // namespace

 interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
 PyTypeObject* py_tensor_type = nullptr;
 PyObject *cpp_use_symbolic_shape, *cpp_astensor1d;

@@ -91,36 +132,17 @@ PyObject* py_apply(
     if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
-        // swap to a special context to reuse scalar handle
-        TransformationContext symbol_var_context;
-        Transformation::swap_context(symbol_var_context);
-        CleanupGuard _{[&] { Transformation::swap_context(symbol_var_context); }};
-        auto* graph =
-                py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph();
-        std::make_shared<SymbolTransformation>(graph)->register_at(
-                Transformation::top());
-        std::make_shared<ScalarTransformation>()->register_at(Transformation::top());
+        SymbolVarContext context(
+                py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph());
+        context.init();
         for (size_t i = 0; i < nargs; ++i) {
-            auto* py_input = py::handle(args[i]).cast<PySymbolVar*>();
-            ValueRef input = SymbolValue::make(py_input->m_node);
-            if (py_input->is_scalar) {
-                input = ScalarValue::make(input);
-            }
-            tensors[i] = input;
+            tensors[i] = symvar2val(args[i]);
         }
         auto outputs = imperative::apply(*op, tensors);
         auto ret = pybind11::tuple(outputs.size());
         auto typeobj = py::handle(args[0]).get_type();
         for (size_t i = 0; i < outputs.size(); ++i) {
-            bool is_scalar = false;
-            if (auto* scalar_value = outputs[i].as<ScalarValue>()) {
-                outputs[i] = scalar_value->value();
-                is_scalar = true;
-            }
-            auto* node = outputs[i].cast<SymbolValue>().node();
-            ret[i] = typeobj(
-                    pybind11::cast(node, pybind11::return_value_policy::automatic));
-            py::handle(ret[i]).cast<PySymbolVar*>()->is_scalar = is_scalar;
+            ret[i] = val2symvar(typeobj, outputs[i]);
         }
         return ret.release().ptr();
     }

@@ -1537,17 +1559,29 @@ void init_tensor(py::module m) {
         }
     });
-    m.def("reduce_to_scalar", [](py::object op, py::object tensor) {
-        auto* tw = TensorWrapper::try_cast(tensor.ptr());
-        auto make_scalar_shape = [&](CompNode device) {
-            return imperative::apply(
-                    CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
-                    HostStorage::make(device))[0];
-        };
-        auto output = imperative::apply(
-                *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data(),
-                make_scalar_shape(tw->m_tensor->comp_node()))[0];
-        return TensorWrapper::make(py_tensor_type, output);
-    });
+    m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
+        auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
+            auto make_scalar_shape = [&](CompNode device) {
+                return imperative::apply(
+                        CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
+                        HostStorage::make(device))[0];
+            };
+            return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
+        };
+        if (py::isinstance<PySymbolVar>(tensor)) {
+            auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
+            SymbolVarContext context(graph);
+            context.init();
+            auto output = reduce_to_scalar(
+                    *op.cast<std::shared_ptr<OpDef>>(), symvar2val(tensor));
+            auto typeobj = tensor.get_type();
+            return val2symvar(typeobj, output);
+        } else {
+            auto* tw = TensorWrapper::try_cast(tensor.ptr());
+            auto output = reduce_to_scalar(
+                    *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
+            return TensorWrapper::make(py_tensor_type, output);
+        }
+    });
     m.def("name_tensor", [](std::string name, py::object tensor) {

@@ -1557,7 +1591,7 @@ void init_tensor(py::module m) {
     });
     m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
-        ValueRefList values(tensors.size());
+        SmallVector<ValueRef> values(tensors.size());
         for (size_t i = 0; i < tensors.size(); ++i) {
             values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
         }

@@ -1570,17 +1604,16 @@ void init_tensor(py::module m) {
     });
     m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
-        ValueRefList values(tensors.size());
+        SmallVector<ValueRef> values(tensors.size());
         for (size_t i = 0; i < tensors.size(); ++i) {
             values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
         }
-        auto outputs = imperative::apply(GetGradKey(), values);
-        if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) {
-            return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
-                    GradKeyWrapper::get(*grad_key_val)));
-        } else {
+        auto output = imperative::apply(GetGradKey(), values)[0];
+        if (!output) {
             return py::none();
         }
+        return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
+                GradKeyWrapper::get(output.cast<GradKeyValue>())));
     });
     m.def("set_grad", [](py::object py_key, py::function backward_fn,

@@ -1612,7 +1645,7 @@ void init_tensor(py::module m) {
             }
             return input_grads;
         };
-        ValueRefList values(inputs.size() + outputs.size());
+        SmallVector<ValueRef> values(inputs.size() + outputs.size());
         for (size_t i = 0; i < inputs.size(); ++i) {
             values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
         }

@@ -1669,6 +1702,10 @@ void init_tensor(py::module m) {
         return reprs;
     });
+    m.def("print_stats", [] { imperative::Stats::print(); });
+
+    m.def("reset_stats", [] { imperative::Stats::reset(); });
+
     py::register_exception<TraceError>(m, "TraceError");
 }
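Note (not part of the commit): the three helpers above replace the SymbolVar handling that py_apply previously inlined. A rough sketch of the call pattern they enable, where graph, opdef, typeobj and py_symbol_var_handle are assumed to exist:

    {
        SymbolVarContext context(graph);     // constructor swaps in a fresh TransformationContext
        context.init();                      // registers SymbolTransformation + ScalarTransformation
        ValueRef input = symvar2val(py_symbol_var_handle);
        auto output = imperative::apply(*opdef, input)[0];
        py::object result = val2symvar(typeobj, output);
    }                                        // destructor swaps the previous context back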
imperative/python/src/transformation.h

@@ -67,7 +67,8 @@ struct TransformationManager {
     }
 };

-class PyValue final : public MixinValueImpl<PyValue, pybind11::object> {
+class PyValue final
+        : public MixinValueImpl<PyValue, ValueKind::Object, pybind11::object> {
 public:
     using MixinValueImpl::MixinValueImpl;
imperative/src/impl/dispatch.cpp

@@ -14,13 +14,9 @@
 #include "megbrain/imperative/utils/debug.h"
 #include "megbrain/imperative/utils/helper.h"
 #include "megbrain/imperative/utils/map.h"
+#include "megbrain/imperative/utils/stats.h"

 namespace mgb {

-void imperative_log_profile_begin(const char* message);
-void imperative_log_profile(const char* message);
-void imperative_log_profile_end(const char* message);
-
 namespace imperative {

 namespace {
imperative/src/impl/interpreter/interpreter_impl.cpp

@@ -19,6 +19,7 @@
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/imperative/ops/utility.h"
+#include "megbrain/imperative/utils/stats.h"
 #include "megbrain/imperative/utils/to_string.h"

 #include "../blob_manager_impl.h"
imperative/src/impl/transformation.cpp

 #include "megbrain/imperative/transformation.h"
+#include "megbrain/imperative/utils/stats.h"

 namespace mgb {
 namespace imperative {
imperative/src/impl/transformations/eval.cpp

@@ -11,6 +11,7 @@
 #include "megbrain/imperative/transformations/eval.h"
 #include "megbrain/imperative/transformations/grad.h"
+#include "megbrain/imperative/utils/stats.h"

 namespace mgb {
 namespace imperative {

@@ -40,9 +41,6 @@ ShapeValue::ref_t InterpreterInfo::shape() const {
 ValueRefList InterpreterTransformation::apply_op(
         const ApplyOp& apply_op, Span<ValueRef> inputs) {
-    if (apply_op.op().same_type<FastpathCopy>()) {
-        return {inputs[0]};
-    }
     SmallVector<Handle> input_handles;
     SmallVector<Handle> output_handles;
     CleanupGuard _{[&] {

@@ -111,7 +109,11 @@ ValueRefList InterpreterTransformation::apply_create_tensor(
 ValueRefList InterpreterTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
     if (auto* op_val = op.as<ApplyOp>()) {
-        return apply_op(*op_val, inputs);
+        if (op_val->op().same_type<FastpathCopy>()) {
+            return inputs[0];
+        } else {
+            return apply_op(*op_val, inputs);
+        }
     } else if (auto* get_attr = op.as<GetAttr>()) {
         return apply_get_attr(*get_attr, inputs);
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
imperative/src/impl/transformations/grad.cpp

@@ -11,8 +11,11 @@
 #include "megbrain/imperative/transformations/grad.h"

 #include <variant>

+#include "megbrain/imperative/graph_cache.h"
+#include "megbrain/imperative/resource_manager.h"
+#include "megbrain/imperative/utils/stats.h"
+
 #include <range/v3/all.hpp>

@@ -20,20 +23,21 @@ namespace mgb {
 namespace imperative {

 static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_graph(
-        std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs,
+        const OpDef& op, Span<ValueRef> inputs, Span<ValueRef> outputs,
         Span<bool> inputs_require_grad) {
     // hash
     using OptimizedBackwardGraphCache = OpMethResultCache<
             std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
     thread_local auto& cache =
             *ResourceManager::create_local<OptimizedBackwardGraphCache>();
-    OptimizedBackwardGraphCache::key_t cache_key{op};
+    OptimizedBackwardGraphCache::key_t cache_key{op.shared_from_this()};
     SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
-    std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>();
+    cache_key.extra<0>() = inputs_require_grad.copy_into<SmallVector<bool>>();
     input_descs.resize(inputs.size());
+    // some overhead, consider simplify LogicalTensorDesc
     for (size_t i = 0; i < inputs.size(); ++i) {
-        input_descs[i].layout.dtype = inputs[i].dtype().cast<DTypeValue>();
-        input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>();
+        input_descs[i].layout.dtype = *inputs[i].dtype();
+        input_descs[i].comp_node = *inputs[i].device();
     }

     auto iter = cache.find(cache_key);

@@ -45,7 +49,7 @@ static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_gra
     SmallVector<bool> output_has_grad(outputs.size(), true);
     std::shared_ptr<OptimizedBackwardGraphResult> ret;
     auto bg = OpDef::make_backward_graph(
-            *op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
+            op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
     if (!bg.graph.empty()) {
         ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
     }

@@ -235,7 +239,7 @@ GradValue::ref_t GradKey::attach(
     } else {
         GradSlotPtr grad_slot;
         auto& grad_fn = grad_slot.m_fn;
-        grad_fn = std::make_shared<GradFn>();
+        grad_fn = LocalPtr<GradFn>::make();
         grad_fn->m_key = shared_from_this();
         grad_fn->m_slots.resize(1);
         grad_slot.m_index = 0;

@@ -260,17 +264,21 @@ ValueRefList GradTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
     auto fallback = [&] {
         ValueRefList unwrapped_inputs(inputs.size());
+        {
+            // overhead
             for (size_t i = 0; i < inputs.size(); ++i) {
-                if (auto grad_value = as_grad_value(inputs[i])) {
+                if (auto&& grad_value = as_grad_value(inputs[i])) {
                     unwrapped_inputs[i] = grad_value->m_value;
                 } else {
                     unwrapped_inputs[i] = inputs[i];
                 }
             }
+        }
         return imperative::apply(op, unwrapped_inputs);
     };
-    if (auto* get_attr = op.as<GetAttr>()) {
-        if (auto grad_value = as_grad_value(inputs.item())) {
+    if (op.is<GetAttr>()) {
+        // overhead
+        if (auto&& grad_value = as_grad_value(inputs.item())) {
             return imperative::apply(op, grad_value->m_value);
         } else {
             return imperative::apply(op, inputs);

@@ -281,28 +289,29 @@ ValueRefList GradTransformation::apply_transformation(
     }
     if (auto* op_val = op.as<ApplyOp>()) {
         size_t nr_require_grad = 0;
-        SmallVector<bool> require_grads;
-        for (auto&& input : inputs) {
-            if (is_grad_value(input)) {
+        SmallVector<bool> require_grads(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            if (is_grad_value(inputs[i])) {
                 nr_require_grad++;
-                require_grads.push_back(true);
+                require_grads[i] = true;
             } else {
-                require_grads.push_back(false);
+                require_grads[i] = false;
             }
         }
         if (nr_require_grad == 0) {
            return imperative::apply(op, inputs);
        }
-        ValueRefList captured_inputs(inputs.size());
+        SmallVector<ValueRef> captured_inputs(inputs.size());
         SmallVector<bool> inputs_require_grad(inputs.size());
         // capture value so that trace could assume input as same
-        auto capture_value = [](ValueRef value) {
+        auto capture_value = [](const ValueRef& value) {
             // TODO: fastpath copy shouldn't be an OpDef
-            return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
+            static auto fastpath_copy = FastpathCopy::make();
+            return imperative::apply(ApplyOp(*fastpath_copy), value)[0];
         };
         for (size_t i = 0; i < inputs.size(); ++i) {
             auto& input = inputs[i];
-            if (auto grad_value = as_grad_value(input)) {
+            if (auto&& grad_value = as_grad_value(input)) {
                 captured_inputs[i] = capture_value(grad_value->m_value);
                 inputs_require_grad[i] = true;
             } else {

@@ -310,32 +319,28 @@ ValueRefList GradTransformation::apply_transformation(
                 inputs_require_grad[i] = false;
             }
         }
-        decltype(std::declval<GradFn>().m_backward) backward_storage;
+        // copy grad_fn->m_backward is expensive
+        auto grad_fn = LocalPtr<GradFn>::make();
+        auto& backward_storage = grad_fn->m_backward;
         auto outputs = [&] {
             auto backward_rule =
                     CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
             if (backward_rule) {
                 CustomBackward backward;
                 auto optional_outputs = backward_rule(
-                        op_val->op(), {captured_inputs.data(), captured_inputs.size()},
-                        {inputs_require_grad.data(), inputs_require_grad.size()},
-                        backward);
+                        op_val->op(), captured_inputs, inputs_require_grad, backward);
                 if (optional_outputs) {
                     backward_storage = backward;
                     // backward by rule
                     return *optional_outputs;
                 }
             }
-            auto outputs = imperative::apply(
-                    op, {captured_inputs.begin(), captured_inputs.end()});
+            auto outputs = imperative::apply(op, captured_inputs);
             auto backward_graph = make_optimized_backward_graph(
-                    op.cast<ApplyOp>().op().shared_from_this(),
-                    {captured_inputs.begin(), captured_inputs.end()},
-                    {outputs.data(), outputs.size()},
-                    {inputs_require_grad.data(), inputs_require_grad.size()});
+                    op_val->op(), captured_inputs, outputs, inputs_require_grad);
             if (backward_graph) {
                 backward_storage = BackwardGraphWithClosure(
-                        backward_graph, op.cast<ApplyOp>().op().shared_from_this(),
+                        backward_graph, op_val->op().shared_from_this(),
                         {captured_inputs.begin(), captured_inputs.end()},
                         {outputs.data(), outputs.size()});
                 // backward by make_backward_graph

@@ -348,18 +353,17 @@ ValueRefList GradTransformation::apply_transformation(
         if (std::holds_alternative<std::monostate>(backward_storage)) {
             return outputs;
         }
-        auto grad_fn = std::make_shared<GradFn>();
         grad_fn->m_key = m_key;
         grad_fn->m_slots.resize(outputs.size());
-        grad_fn->m_backward = backward_storage;
         mgb_assert(!outputs.empty());
         grad_fn->m_dests.reserve(inputs.size());
         // clang-format off
-        std::visit([&](auto& backward) {
+        auto visitor = [&](auto& backward) {
             using T = std::decay_t<decltype(backward)>;
             if constexpr (std::is_same_v<T, std::monostate>) {
                 mgb_throw(AssertionError, "invalid backward");
             } else {
+                // little overhead
                 for (size_t i = 0; i < inputs.size(); ++i) {
                     if (backward.input_has_grad(i) && require_grads[i]) {
                         auto& input_grad_slot =

@@ -373,19 +377,23 @@ ValueRefList GradTransformation::apply_transformation(
                 }
                 for (size_t i = 0; i < outputs.size(); ++i) {
                     if (backward.output_requires_grad(i)) {
+                        // little overhead: Value::make
                         auto grad_value = GradValue::make(
                                 outputs[i], m_key, GradSlotPtr{grad_fn, i});
                         outputs[i] = record_grad(grad_value);
                     }
                 }
             }
-        },
-        grad_fn->m_backward);
+        };
+        // std::visit may be slightly slower than direct if
+        std::visit(visitor, backward_storage);
         // clang-format on
         mgb_assert(!grad_fn->m_slots.empty());
         m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
         return outputs;
     } else if (op.is<CreateTensor>()) {
         return imperative::apply(op, inputs);
-    } else if (auto* attach_grad = op.as<AttachGrad>()) {
+    }
+    if (auto* attach_grad = op.as<AttachGrad>()) {
         if (!has_key(attach_grad->key())) {
             return fallback();
         }

@@ -408,7 +416,7 @@ ValueRefList GradTransformation::apply_transformation(
         return {};
     } else if (auto* is_attached_to = op.as<IsAttachedTo>()) {
         if (has_key(is_attached_to->key())) {
-            if (auto grad_value = as_grad_value(inputs[0])) {
+            if (auto&& grad_value = as_grad_value(inputs[0])) {
                 // TODO: assert grad_fn
                 return {BoolValue::make(true)};
             }

@@ -416,7 +424,7 @@ ValueRefList GradTransformation::apply_transformation(
         return {BoolValue::make(false)};
     } else if (auto* set_grad = op.as<SetGrad>()) {
         // TODO: merge SetGrad and ApplyOp
-        auto grad_fn = std::make_shared<GradFn>();
+        auto grad_fn = LocalPtr<GradFn>::make();
         auto& backward =
                 std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
         size_t nr_inputs = set_grad->nr_inputs();

@@ -433,7 +441,7 @@ ValueRefList GradTransformation::apply_transformation(
         grad_fn->m_slots.resize(nr_outputs);
         grad_fn->m_dests.reserve(nr_inputs);
         for (size_t i = 0; i < nr_inputs; ++i) {
-            if (auto grad_value = as_grad_value(inputs_[i])) {
+            if (auto&& grad_value = as_grad_value(inputs_[i])) {
                 auto& input_grad_slot = grad_value->m_slot;
                 grad_fn->m_dests.emplace_back(grad_value->m_slot);
                 grad_fn->m_dests.back().m_producer_record.insert_after(

@@ -461,21 +469,21 @@ ValueRefList GradTransformation::apply_transformation(
         }
         return {FunctionValue::make(make_backward_closure(inputs))};
     } else if (op.is<DetachGrad>()) {
-        if (auto grad_value = as_grad_value(inputs[0])) {
+        if (auto&& grad_value = as_grad_value(inputs[0])) {
             return {grad_value->m_value};
         } else {
             return {inputs[0]};
         }
     } else if (op.is<GetGradKey>()) {
         for (auto&& input : inputs) {
-            if (auto grad_value = as_grad_value(input)) {
+            if (auto&& grad_value = as_grad_value(input)) {
                 return {GradKeyValue::make(grad_value->m_key)};
             }
         }
         return imperative::apply(op, inputs);
     } else if (op.kind() == Operator::IdentityLike) {
         mgb_assert(inputs.size() == 1);
-        if (auto grad_value = as_grad_value(inputs[0])) {
+        if (auto&& grad_value = as_grad_value(inputs[0])) {
             auto output = imperative::apply(op, grad_value->m_value)[0];
             auto grad_output = GradValue::make(
                     output, grad_value->key(), grad_value->slot_for(m_key));

@@ -493,7 +501,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
     auto grad_key = m_key;
     std::vector<GradSlotPtr> y_slots;
     for (auto&& y : ys) {
-        if (auto grad_value = as_grad_value(y)) {
+        if (auto&& grad_value = as_grad_value(y)) {
             y_slots.push_back(grad_value->slot_for(grad_key));
         } else {
             y_slots.emplace_back();
imperative/src/impl/transformations/scalar.cpp

@@ -13,6 +13,7 @@
 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
+#include "megbrain/imperative/utils/stats.h"

 namespace mgb {
 namespace imperative {

@@ -185,7 +186,7 @@ ValueRefList subtensor_rule(
     bool is_scalar;
     mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input");
     if (auto shape = input.shape()) {
-        size_t ndim = input.shape()->ndim;
+        size_t ndim = shape->ndim;
         for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
             if (idx) {
                 ndim--;

@@ -193,6 +194,7 @@ ValueRefList subtensor_rule(
         }
         is_scalar = ndim == 0;
     } else {
+        // assume not scalar
         is_scalar = false;
     }
     auto outputs = imperative::apply(subtensor, inputs);

@@ -341,12 +343,16 @@ ValueRefList ScalarTransformation::apply_transformation(
     if (auto* get_attr = op.as<GetAttr>()) {
+        // fastpath for GetAttr
         return apply_get_attr(*get_attr, inputs);
+    } else if (auto* apply_op = op.as<ApplyOp>()) {
+        if (apply_op->op().same_type<FastpathCopy>()) {
+            return inputs[0];
+        }
     }
     size_t nr_inputs = inputs.size();
     ValueRefList unwrapped_inputs(nr_inputs);
-    bool inputs_mask[nr_inputs];
+    SmallVector<bool> inputs_mask(nr_inputs);
     for (size_t i = 0; i < inputs.size(); ++i) {
-        if (auto scalar_value = inputs[i].as_ref<ScalarValue>()) {
+        if (auto&& scalar_value = inputs[i].as_ref<ScalarValue>()) {
             unwrapped_inputs[i] = scalar_value->value();
             inputs_mask[i] = true;
         } else {

@@ -358,8 +364,7 @@ ValueRefList ScalarTransformation::apply_transformation(
     if (auto apply_op = op.as<ApplyOp>()) {
         auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
         if (iter != scalar_rules.end()) {
-            return iter->second(
-                    apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs});
+            return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask);
         } else {
             // TODO: repeat op
             return fallback();
imperative/src/impl/value.cpp

@@ -215,8 +215,8 @@ ValueRefList::ValueRefList(size_t nr_elems) {
     init(nr_elems);
 }

-ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
-        : ValueRefList(values.begin(), values.end()) {}
+/*
+ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
+        : ValueRefList(values.begin(), values.end()) {}
+*/

 ValueRefList::ValueRefList(const ValueRefList& rhs)
         : ValueRefList(rhs.cbegin(), rhs.cend()) {}
imperative/src/include/megbrain/imperative/basic_values.h

@@ -25,14 +25,16 @@ class GradKey;
 using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;

-class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> {
+class ShapeValue final
+        : public MixinValueImpl<ShapeValue, ValueKind::Primitive, ValueShape> {
 public:
     using MixinValueImpl::MixinValueImpl;

     std::string to_string() const override;
 };

-class CompNodeValue final : public MixinValueImpl<CompNodeValue, CompNode> {
+class CompNodeValue final
+        : public MixinValueImpl<CompNodeValue, ValueKind::Primitive, CompNode> {
 public:
     using MixinValueImpl::MixinValueImpl;

@@ -40,7 +42,7 @@ public:
 };

 // TODO: override factory method
-class BoolValue final : public ValueImpl<BoolValue> {
+class BoolValue final : public ValueImpl<BoolValue, ValueKind::Primitive> {
 private:
     std::optional<bool> m_value;

@@ -53,14 +55,17 @@ public:
     void clear() override { m_value.reset(); }
 };

-class HostStorage final : public MixinValueImpl<HostStorage, HostTensorStorage> {
+class HostStorage final
+        : public MixinValueImpl<HostStorage, ValueKind::Primitive, HostTensorStorage> {
 public:
     using MixinValueImpl::MixinValueImpl;

     std::string to_string() const override;
 };

-class DeviceStorage final : public MixinValueImpl<DeviceStorage, DeviceTensorStorage> {
+class DeviceStorage final
+        : public MixinValueImpl<
+                  DeviceStorage, ValueKind::Primitive, DeviceTensorStorage> {
 public:
     using MixinValueImpl::MixinValueImpl;

@@ -71,7 +76,7 @@ public:
  * \brief like HostTensorND mixin, but allow scalar value
  *
  */
-class HostValue final : public ValueImpl<HostValue> {
+class HostValue final : public ValueImpl<HostValue, ValueKind::Primitive> {
 private:
     DType m_dtype;
     ValueShape m_shape;

@@ -94,9 +99,9 @@ public:
     }

     DType dtype() const { return m_dtype; }
-    ValueShape shape() const { return m_shape; }
+    const ValueShape& shape() const { return m_shape; }
     CompNode device() const { return m_storage.comp_node(); }
-    HostTensorStorage storage() const { return m_storage; }
+    const HostTensorStorage& storage() const { return m_storage; }
     DTypeScalar item() const {
         mgb_assert(m_shape.is_scalar());
         return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr());

@@ -109,7 +114,7 @@ public:
  * \brief like DeviceTensorND mixin, but allow scalar value
  *
  */
-class DeviceValue final : public ValueImpl<DeviceValue> {
+class DeviceValue final : public ValueImpl<DeviceValue, ValueKind::Primitive> {
 private:
     DType m_dtype;
     ValueShape m_shape;

@@ -117,8 +122,8 @@ private:
 public:
     DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage)
-            : m_dtype(dtype), m_shape(shape), m_storage(storage) {}
-    DeviceValue(DeviceTensorND value)
+            : m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {}
+    DeviceValue(const DeviceTensorND& value)
             : DeviceValue(
                       value.dtype(), ValueShape::from(value.shape()), value.storage()) {
     }

@@ -132,28 +137,31 @@ public:
     }

     DType dtype() const { return m_dtype; }
-    ValueShape shape() const { return m_shape; }
+    const ValueShape& shape() const { return m_shape; }
     CompNode device() const { return m_storage.comp_node(); }
-    DeviceTensorStorage storage() const { return m_storage; }
+    const DeviceTensorStorage& storage() const { return m_storage; }

     DeviceTensorND as_nd(bool allow_scalar = false) const;
 };

-class FunctionValue final : public MixinValueImpl<FunctionValue, GenericFunction> {
+class FunctionValue final
+        : public MixinValueImpl<FunctionValue, ValueKind::Primitive, GenericFunction> {
 public:
     using MixinValueImpl::MixinValueImpl;

     std::string to_string() const override;
 };

-class DTypeValue final : public MixinValueImpl<DTypeValue, DType> {
+class DTypeValue final
+        : public MixinValueImpl<DTypeValue, ValueKind::Primitive, DType> {
 public:
     using MixinValueImpl::MixinValueImpl;

     std::string to_string() const override;
 };

-class StringValue final : public MixinValueImpl<StringValue, std::string> {
+class StringValue final
+        : public MixinValueImpl<StringValue, ValueKind::Primitive, std::string> {
 public:
     using MixinValueImpl::MixinValueImpl;

@@ -171,7 +179,8 @@ public:
     std::string message() const { return m_message; }
 };

-class ErrorValue final : public MixinValueImpl<ErrorValue, Error> {
+class ErrorValue final
+        : public MixinValueImpl<ErrorValue, ValueKind::Primitive, Error> {
 public:
     using MixinValueImpl::MixinValueImpl;
imperative/src/include/megbrain/imperative/dispatch.h

@@ -47,9 +47,14 @@ constexpr bool is_all_value_ref_v =
         (... && (std::is_base_of_v<ValueRef, std::decay_t<TArgs>> ||
                  std::is_same_v<ValueRef, std::decay_t<TArgs>>));

+template <typename T>
+static ValueRefList apply(T&& op, const ValueRef& arg) {
+    return imperative::apply(std::forward<T&&>(op), Span<ValueRef>{&arg, 1});
+}
+
 template <typename T, typename... TArgs>
-static auto apply(T&& op, TArgs&&... args)
-        -> std::enable_if_t<is_all_value_ref_v<TArgs...>, ValueRefList> {
+static auto apply(T&& op, TArgs&&... args) -> std::enable_if_t<
+        is_all_value_ref_v<TArgs...> && sizeof...(args) != 1, ValueRefList> {
     ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...};
     return imperative::apply(
             std::forward<T&&>(op),
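Note (not part of the commit): the new single-argument overload dispatches one ValueRef through a Span of length one instead of first packing it into a temporary array, and the sizeof...(args) != 1 condition keeps the variadic overload out of that case. A standalone toy showing the same overload split; the types and functions below are invented for illustration and are not MegEngine APIs:

    #include <cstddef>
    #include <type_traits>
    #include <utility>
    #include <vector>

    struct Ref { int v = 0; };
    using List = std::vector<Ref>;

    // stand-in for the span-based dispatch entry point
    inline List apply_span(const Ref* data, size_t len) { return List(data, data + len); }

    // single-argument fast path: no temporary array is built
    inline List apply(const Ref& arg) { return apply_span(&arg, 1); }

    // variadic path, disabled for exactly one argument so the overload above wins
    template <typename... TArgs>
    auto apply(TArgs&&... args) -> std::enable_if_t<(sizeof...(args) != 1), List> {
        Ref arr[sizeof...(TArgs)] = {std::forward<TArgs>(args)...};
        return apply_span(arr, sizeof...(TArgs));
    }

    int main() {
        Ref a, b;
        return apply(a).size() + apply(a, b).size() == 3 ? 0 : 1;  // fast path + array path
    }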
imperative/src/include/megbrain/imperative/graph_cache.h

@@ -54,6 +54,11 @@ struct OpMethArgs {
         return extras == rhs.extras;
     }

+    template <size_t i>
+    auto& extra() {
+        return std::get<i>(extras);
+    }
+
     struct hash_t {
         size_t operator()(const OpMethArgs& key) const { return key.hash(); }
     };
imperative/src/include/megbrain/imperative/transformations/eval.h

@@ -60,7 +60,7 @@ public:
 };

 class InterpreterValue final
-        : public MixinValueImpl<InterpreterValue, InterpreterInfo> {
+        : public MixinValueImpl<InterpreterValue, ValueKind::Object, InterpreterInfo> {
 public:
     using MixinValueImpl::MixinValueImpl;
imperative/src/include/megbrain/imperative/transformations/grad.h

@@ -104,37 +104,15 @@ struct ToStringTrait<GradSlot> {
     std::string operator()(const GradSlot& value) const { return value.to_string(); }
 };

-class GradFn {
-private:
-    std::weak_ptr<GradKey> m_key;
-    std::vector<GradSlot> m_slots;
-    std::vector<GradSlotProducerPtr> m_dests;
-    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;
-
-public:
-    void clear() {
-        m_key.reset();
-        m_slots.clear();
-        m_dests.clear();
-        m_backward.emplace<std::monostate>();
-    }
-
-    std::string to_string() const;
-
-    friend class GradSlotPtr;
-    friend class GradKey;
-    friend class GradTransformation;
-};
-
 class GradSlotPtr {
 private:
-    std::shared_ptr<GradFn> m_fn;
+    LocalPtr<GradFn> m_fn;
     size_t m_index = 0;

 public:
-    GradSlotPtr(std::shared_ptr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
+    GradSlotPtr(LocalPtr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
     GradSlotPtr() = default;
-    GradSlot* operator->() const { return &m_fn->m_slots[m_index]; }
+    GradSlot* operator->() const;
     operator bool() const { return bool(m_fn); }

@@ -171,7 +149,33 @@ struct ToStringTrait<GradSlotProducerPtr> {
     }
 };

-class GradValue final : public ValueImpl<GradValue> {
+class GradFn {
+private:
+    std::weak_ptr<GradKey> m_key;
+    SmallVector<GradSlot> m_slots;
+    SmallVector<GradSlotProducerPtr> m_dests;
+    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;
+
+public:
+    void clear() {
+        m_key.reset();
+        m_slots.clear();
+        m_dests.clear();
+        m_backward.emplace<std::monostate>();
+    }
+
+    std::string to_string() const;
+
+    friend class GradSlotPtr;
+    friend class GradKey;
+    friend class GradTransformation;
+};
+
+inline GradSlot* GradSlotPtr::operator->() const {
+    return &m_fn->m_slots[m_index];
+}
+
+class GradValue final : public ValueImpl<GradValue, ValueKind::Object> {
 private:
     ValueRef m_value;
     std::shared_ptr<GradKey> m_key;

@@ -179,7 +183,7 @@ private:
 public:
     GradValue(ValueRef value, std::shared_ptr<GradKey> key, GradSlotPtr slot = {})
-            : m_value(value), m_key(key), m_slot(slot) {}
+            : m_value(std::move(value)), m_key(std::move(key)), m_slot(slot) {}

     std::string to_string() const override;

@@ -209,12 +213,13 @@ public:
 class GradKey : public std::enable_shared_from_this<GradKey> {
 private:
     std::string m_name;
-    std::vector<std::pair<std::weak_ptr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
-    std::vector<std::pair<std::shared_ptr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
+    std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
+    std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
     bool m_frozen = false;

 public:
+    GradKey() { m_tape.reserve(4 * 1024); }
     void backward();
     GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback);
     const std::string& name() const { return m_name; }

@@ -225,7 +230,8 @@ public:
 };

 class GradKeyValue final
-        : public MixinValueImpl<GradKeyValue, std::shared_ptr<GradKey>> {
+        : public MixinValueImpl<
+                  GradKeyValue, ValueKind::Primitive, std::shared_ptr<GradKey>> {
 public:
     using MixinValueImpl::MixinValueImpl;

@@ -248,7 +254,7 @@ public:
         return tensor;
     }

-    bool is_grad_value(ValueRef value) {
+    bool is_grad_value(const ValueRef& value) {
         if (auto* grad_value = value.as<GradValue>()) {
             if (grad_value->has_key(m_key)) {
                 return true;

@@ -266,13 +272,14 @@ public:
      * \param value
      * \return GradValue::ref_t
      */
-    GradValue::ref_t as_grad_value(ValueRef value) {
-        if (auto grad_value = value.as_ref<GradValue>()) {
+    const GradValue::ref_t& as_grad_value(const ValueRef& value) {
+        auto&& grad_value = value.as_ref<GradValue>();
+        if (grad_value) {
             if (grad_value->has_key(m_key)) {
                 return grad_value;
             }
         }
-        return {};
+        return GradValue::ref_t::nil;
     }

     bool has_key(std::shared_ptr<GradKey> key) {
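Note (not part of the commit): GradFn's definition moves below GradSlotPtr, GradSlotPtr::m_fn becomes a LocalPtr<GradFn>, and operator-> is now only declared inside the class and defined inline once GradFn is complete. A self-contained toy (names invented, std::shared_ptr standing in for LocalPtr) showing why the definition has to follow the complete type:

    #include <cstddef>
    #include <memory>
    #include <vector>

    struct Slot { int grad = 0; };

    class Fn;  // forward declaration is enough for the smart-pointer member below

    class Ptr {
        std::shared_ptr<Fn> m_fn;  // grad.h uses LocalPtr<GradFn> here instead
        size_t m_index = 0;

    public:
        Ptr() = default;
        Ptr(std::shared_ptr<Fn> fn, size_t index) : m_fn(std::move(fn)), m_index(index) {}
        Slot* operator->() const;  // body needs the complete Fn, so it is only declared here
    };

    class Fn {
        std::vector<Slot> m_slots{Slot{}};
        friend class Ptr;
    };

    // defined after Fn is complete, mirroring the new inline GradSlotPtr::operator->
    inline Slot* Ptr::operator->() const {
        return &m_fn->m_slots[m_index];
    }

    int main() {
        Ptr p{std::make_shared<Fn>(), 0};
        return p->grad;  // 0
    }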
imperative/src/include/megbrain/imperative/transformations/lazy.h

@@ -39,7 +39,8 @@ public:
     std::string name() const { return m_name; }
 };

-class LazyEvalValue final : public MixinValueImpl<LazyEvalValue, LazyEvalInfo> {
+class LazyEvalValue final
+        : public MixinValueImpl<LazyEvalValue, ValueKind::Object, LazyEvalInfo> {
 public:
     using MixinValueImpl::MixinValueImpl;
imperative/src/include/megbrain/imperative/transformations/scalar.h

@@ -17,7 +17,7 @@
 namespace mgb::imperative {

-class ScalarValue final : public ValueImpl<ScalarValue> {
+class ScalarValue final : public ValueImpl<ScalarValue, ValueKind::Object> {
 private:
     ValueRef m_value;
imperative/src/include/megbrain/imperative/transformations/symbol.h

@@ -22,7 +22,7 @@
 namespace mgb::imperative {

-class SymbolValue final : public ValueImpl<SymbolValue> {
+class SymbolValue final : public ValueImpl<SymbolValue, ValueKind::Object> {
 private:
     VarNode* m_node = nullptr;
imperative/src/include/megbrain/imperative/transformations/trace.h

@@ -111,7 +111,8 @@ public:
     size_t id() const { return m_id; }
 };

-class TracingValue final : public MixinValueImpl<TracingValue, TracingInfo> {
+class TracingValue final
+        : public MixinValueImpl<TracingValue, ValueKind::Object, TracingInfo> {
 public:
     using MixinValueImpl::MixinValueImpl;

@@ -256,7 +257,8 @@ public:
     }
 };

-class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> {
+class TracedValue final
+        : public MixinValueImpl<TracedValue, ValueKind::Object, TracedInfo> {
 public:
     using MixinValueImpl::MixinValueImpl;
imperative/src/include/megbrain/imperative/utils/stats.h
new file mode 100644

#pragma once

#include <chrono>
#include <iostream>
#include <string>
#include <vector>

namespace mgb {
namespace imperative {

namespace stats {

#define MGE_ENABLE_STATS 0

class Timer {
public:
    using clock_t = std::chrono::system_clock;

private:
    clock_t::duration m_duration = clock_t::duration{0};
    size_t m_timing = 0;
    const char* m_name = nullptr;
    uint64_t m_count = 0;
    size_t m_enabled = 1;
    bool m_default_enabled = true;

    struct TimeScopeRecursive {
        Timer& timer;
        clock_t::time_point start;
        bool released = false;

        TimeScopeRecursive(Timer& timer) : timer(timer) {
            if (timer.m_enabled && !timer.m_timing++) {
                start = clock_t::now();
            }
        }

        ~TimeScopeRecursive() { release(); }

        void release() {
            if (released) {
                return;
            }
            if (timer.m_enabled) {
                if (!--timer.m_timing) {
                    timer.m_duration += (clock_t::now() - start);
                }
                timer.m_count++;
            }
            released = true;
        }
    };

    struct EnableScope {
        Timer& timer;
        bool released = false;

        EnableScope(Timer& timer) : timer(timer) { timer.m_enabled++; }

        ~EnableScope() { release(); }

        void release() {
            if (released) {
                return;
            }
            timer.m_enabled--;
            released = true;
        }
    };

    using TimeScope = TimeScopeRecursive;

public:
    Timer(const char* name, bool default_enabled);

    const char* name() { return m_name; }
    auto time_scope() { return TimeScope(*this); }
    auto time_scope_recursive() { return TimeScopeRecursive(*this); };
    auto enable_scope() { return EnableScope(*this); }
    void reset() {
        m_duration = clock_t::duration{0};
        m_count = 0;
        m_enabled = m_default_enabled ? 1 : 0;
    }

    clock_t::duration get() const { return m_duration; }
    uint64_t count() const { return m_count; }
};
}  // namespace stats

struct Stats {
    static inline std::vector<stats::Timer*> sm_timers;

    // register your timers here
    // for example:
    //
    // static inline stats::Timer mytimer;
    //
    // then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code

    static void print() {
        std::vector<const char*> unused_timers;
        for (auto* timer : sm_timers) {
            if (timer->count() == 0) {
                unused_timers.push_back(timer->name());
            } else {
                printf("%s costs %ld ns, happens %ld times\n", timer->name(),
                       timer->get().count(), timer->count());
            }
        }
        if (!unused_timers.empty()) {
            printf("%zu timers unused\n", unused_timers.size());
        }
    }

    static void reset() {
        for (auto* timer : sm_timers) {
            timer->reset();
        }
    }
};

inline stats::Timer::Timer(const char* name, bool default_enabled)
        : m_name(name), m_default_enabled(default_enabled) {
    Stats::sm_timers.push_back(this);
}

#if MGE_ENABLE_STATS
#define MGE_TIMER_SCOPE(name) auto name = Stats::name.time_scope()
#define MGE_TIMER_SCOPE_RELEASE(name) name.release()
#define MGE_TIMER_SCOPE_ENABLE(name) auto name = Stats::name.enable_scope()
#else
#define MGE_TIMER_SCOPE(name) (void)0
#define MGE_TIMER_SCOPE_RELEASE(name) (void)0
#define MGE_TIMER_SCOPE_ENABLE(name) (void)0
#endif

}  // namespace imperative
}  // namespace mgb
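Note (not part of the commit): stats collection ships disabled (MGE_ENABLE_STATS is 0, so the MGE_TIMER_SCOPE macros compile to no-ops). A hedged sketch of the workflow the comment in stats.h describes, using a hypothetical free-standing timer directly instead of a Stats member:

    #include "megbrain/imperative/utils/stats.h"

    using namespace mgb::imperative;

    // hypothetical timer; the intended workflow is to add it as a static member of
    // Stats, as the comment in stats.h suggests
    static stats::Timer my_timer{"my_timer", /*default_enabled=*/true};

    void hot_path() {
        auto scope = my_timer.time_scope();  // RAII: accumulates into my_timer until scope exit
        // ... code being measured ...
    }

    void report() {
        Stats::print();  // per used timer: "<name> costs <ns> ns, happens <n> times"
        Stats::reset();  // zero all registered timers
    }

The print_stats and reset_stats bindings added in tensor.cpp above expose Stats::print() and Stats::reset() to Python.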
imperative/src/include/megbrain/imperative/value.h

@@ -23,6 +23,7 @@
 #include "megbrain/imperative/utils/debug.h"
 #include "megbrain/imperative/utils/local_ptr.h"
 #include "megbrain/imperative/utils/span.h"
+#include "megbrain/imperative/utils/stats.h"

 namespace mgb {
 namespace imperative {

@@ -58,6 +59,11 @@ public:
     inline size_t code() const { return m_code; }
 };

+enum class ValueKind {
+    Primitive,
+    Object,
+};
+
 /**
  * \brief an smart reference of value
  *

@@ -129,10 +135,10 @@ public:
      * \return TypedValueRef<TValue> reference if success, otherwise empty reference
      */
     template <typename TValue>
-    inline TypedValueRef<TValue> as_ref(Type<TValue> type = {}) const;
+    inline const TypedValueRef<TValue>& as_ref(Type<TValue> type = {}) const;

     template <typename TValue>
-    inline TypedValueRef<TValue> cast_ref(Type<TValue> type = {}) const;
+    inline const TypedValueRef<TValue>& cast_ref(Type<TValue> type = {}) const;

     template <typename TValue>
     void on_cast_failure() const;

@@ -161,14 +167,18 @@ public:
     static bool any_watching();

+    static const ValueRef nil;
+
     friend class ValueWeakRef;
-    template <typename T>
+    template <typename>
     friend class TypedValueRef;
-    template <typename T>
+    template <typename, ValueKind>
     friend class ValueImpl;
     friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
 };

+inline const ValueRef ValueRef::nil;
+
 template <>
 struct ToStringTrait<ValueRef> {
 public:

@@ -241,7 +251,7 @@ public:
     friend class ValueRef;
     friend class ValueWeakRef;
-    template <typename T>
+    template <typename, ValueKind>
     friend class ValueImpl;
     template <typename T>
     friend class TypedValueRef;

@@ -257,7 +267,7 @@ private:
  *
  * \tparam T type of value
  */
-template <typename T>
+template <typename T, ValueKind Kind>
 class ValueImpl : public Value {
 protected:
     ValueImpl() : Value(TYPE_CODE) {}

@@ -267,6 +277,7 @@ public:
     using weak_ref_t = TypedValueWeakRef<T>;

     static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
+    static constexpr ValueKind KIND = Kind;

     /**
      * \brief helper function for construct a value

@@ -288,8 +299,8 @@ public:
  * \tparam T type of value
  * \tparam TMixin type of mixin class
  */
-template <typename T, typename TMixin>
-class MixinValueImpl : public ValueImpl<T>, public TMixin {
+template <typename T, ValueKind Kind, typename TMixin>
+class MixinValueImpl : public ValueImpl<T, Kind>, public TMixin {
 public:
     using TMixin::TMixin;

@@ -309,12 +320,14 @@ inline ValueRef::ValueRef(storage_t storage) {
 template <typename TValue>
 inline const TValue* ValueRef::as(Type<TValue> type) const {
-    static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>);
+    // auto _ = Stats::time_value_as.time_scope();
+    static_assert(std::is_base_of_v<Value, TValue>);
     return static_cast<const TValue*>(as(type.code()));
 }

 template <typename TValue>
 inline const TValue& ValueRef::cast(Type<TValue> type) const {
+    // auto _ = Stats::time_value_cast.time_scope();
     auto* ptr = as<TValue>(type);
     if (mgb_unlikely(!ptr)) {
         on_cast_failure<TValue>();

@@ -324,26 +337,27 @@ inline const TValue& ValueRef::cast(Type<TValue> type) const {
 template <typename TValue>
 inline bool ValueRef::is(Type<TValue> type) const {
+    // auto _ = Stats::time_value_is.time_scope();
     return is(type.code());
 }

 template <typename TValue>
-inline TypedValueRef<TValue> ValueRef::as_ref(Type<TValue> type) const {
+inline const TypedValueRef<TValue>& ValueRef::as_ref(Type<TValue> type) const {
     if (!is<TValue>(type)) {
-        return {};
+        return TypedValueRef<TValue>::nil;
     }
-    return TypedValueRef<TValue>(*this);
+    return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
 }

 template <typename TValue>
-inline TypedValueRef<TValue> ValueRef::cast_ref(Type<TValue> type) const {
+inline const TypedValueRef<TValue>& ValueRef::cast_ref(Type<TValue> type) const {
     if (!m_storage) {
-        return {};
+        return TypedValueRef<TValue>::nil;
     }
     if (mgb_unlikely(!is<TValue>(type))) {
         on_cast_failure<TValue>();
     }
-    return TypedValueRef<TValue>(*this);
+    return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
 }

 template <typename TValue>

@@ -363,12 +377,31 @@ void ValueRef::on_cast_failure() const {
 template <typename T>
 class TypedValueRef : public ValueRef {
 private:
-    TypedValueRef(ValueRef value) : ValueRef(value) {}
+    TypedValueRef(ValueRef value) : ValueRef(std::move(value)) {}

 public:
     TypedValueRef() = default;
-    const T& operator*() const { return this->template cast<T>(); }
-    const T* operator->() const { return this->template as<T>(); }
+
+    const T& operator*() const {
+        if constexpr (T::KIND == ValueKind::Object) {
+            return this->template cast<T>();
+        } else if constexpr (T::KIND == ValueKind::Primitive) {
+            if (!m_storage) {
+                on_cast_failure<T>();
+            }
+            return static_cast<const T&>(*m_storage);
+        } else {
+            static_assert(!std::is_same_v<T, T>);
+        }
+    }
+
+    const T* operator->() const {
+        if constexpr (T::KIND == ValueKind::Object) {
+            return this->template as<T>();
+        } else if constexpr (T::KIND == ValueKind::Primitive) {
+            return static_cast<const T*>(m_storage.get());
+        } else {
+            static_assert(!std::is_same_v<T, T>);
+        }
+    }

     /**
      * \brief reset underlying value to another value

@@ -376,6 +409,7 @@ public:
      * \param successor new value
      */
     inline void reset(ValueRef successor) {
+        static_assert(T::KIND == ValueKind::Object);
         mgb_assert(m_storage);
         mgb_assert(!m_storage->m_successor);
         if (m_storage->m_watching) {

@@ -385,9 +419,11 @@ public:
         m_storage->m_successor = ValueRef(successor.storage());
     }

+    static inline const TypedValueRef nil;
+
     friend class ValueRef;
-    template <typename U>
+    template <typename, ValueKind>
     friend class ValueImpl;
 };

@@ -423,7 +459,7 @@ public:
     ValueRefList() = default;
     ValueRefList(size_t nr_elems);
     ValueRefList(ValueRef item);
-    ValueRefList(std::initializer_list<ValueRef> values);
+    // ValueRefList(std::initializer_list<ValueRef> values);
     template <typename TIterator>
     ValueRefList(TIterator begin, TIterator end);
     ValueRefList(const ValueRefList& rhs);