Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
ca001777
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
ca001777
编写于
1月 20, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(dispatch): speed up dispatch system
GitOrigin-RevId: eabbe3e0219ff989801751c726eb0828b1b7a740
上级
187c1dc0
变更
40
隐藏空白更改
内联
并排
Showing
40 changed file
with
1220 addition
and
685 deletion
+1220
-685
imperative/python/megengine/core/tensor/array_method.py
imperative/python/megengine/core/tensor/array_method.py
+2
-5
imperative/python/megengine/optimizer/sgd.py
imperative/python/megengine/optimizer/sgd.py
+1
-1
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+3
-2
imperative/python/src/grad.cpp
imperative/python/src/grad.cpp
+4
-4
imperative/python/src/grad_override.cpp
imperative/python/src/grad_override.cpp
+18
-18
imperative/python/src/module_trace.h
imperative/python/src/module_trace.h
+5
-6
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+47
-34
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+2
-2
imperative/python/src/transformation.h
imperative/python/src/transformation.h
+18
-0
imperative/src/impl/basic_operators.cpp
imperative/src/impl/basic_operators.cpp
+1
-1
imperative/src/impl/dispatch.cpp
imperative/src/impl/dispatch.cpp
+44
-47
imperative/src/impl/interpreter/stack_manager.h
imperative/src/impl/interpreter/stack_manager.h
+1
-1
imperative/src/impl/operator.cpp
imperative/src/impl/operator.cpp
+1
-1
imperative/src/impl/physical_tensor.cpp
imperative/src/impl/physical_tensor.cpp
+14
-11
imperative/src/impl/profiler/chrome_timeline.cpp
imperative/src/impl/profiler/chrome_timeline.cpp
+2
-1
imperative/src/impl/transformations/eval.cpp
imperative/src/impl/transformations/eval.cpp
+98
-60
imperative/src/impl/transformations/grad.cpp
imperative/src/impl/transformations/grad.cpp
+42
-44
imperative/src/impl/transformations/lazy.cpp
imperative/src/impl/transformations/lazy.cpp
+4
-4
imperative/src/impl/transformations/scalar.cpp
imperative/src/impl/transformations/scalar.cpp
+177
-215
imperative/src/impl/transformations/tangent.cpp
imperative/src/impl/transformations/tangent.cpp
+25
-0
imperative/src/impl/transformations/trace.cpp
imperative/src/impl/transformations/trace.cpp
+101
-60
imperative/src/impl/value.cpp
imperative/src/impl/value.cpp
+128
-29
imperative/src/include/megbrain/imperative/basic_operators.h
imperative/src/include/megbrain/imperative/basic_operators.h
+6
-10
imperative/src/include/megbrain/imperative/basic_values.h
imperative/src/include/megbrain/imperative/basic_values.h
+5
-1
imperative/src/include/megbrain/imperative/dispatch.h
imperative/src/include/megbrain/imperative/dispatch.h
+6
-6
imperative/src/include/megbrain/imperative/operator.h
imperative/src/include/megbrain/imperative/operator.h
+10
-9
imperative/src/include/megbrain/imperative/profiler.h
imperative/src/include/megbrain/imperative/profiler.h
+0
-1
imperative/src/include/megbrain/imperative/transformation.h
imperative/src/include/megbrain/imperative/transformation.h
+7
-3
imperative/src/include/megbrain/imperative/transformations/eval.h
...ve/src/include/megbrain/imperative/transformations/eval.h
+36
-12
imperative/src/include/megbrain/imperative/transformations/grad.h
...ve/src/include/megbrain/imperative/transformations/grad.h
+16
-23
imperative/src/include/megbrain/imperative/transformations/lazy.h
...ve/src/include/megbrain/imperative/transformations/lazy.h
+1
-1
imperative/src/include/megbrain/imperative/transformations/scalar.h
.../src/include/megbrain/imperative/transformations/scalar.h
+4
-1
imperative/src/include/megbrain/imperative/transformations/symbol.h
.../src/include/megbrain/imperative/transformations/symbol.h
+4
-4
imperative/src/include/megbrain/imperative/transformations/tangent.h
...src/include/megbrain/imperative/transformations/tangent.h
+36
-0
imperative/src/include/megbrain/imperative/transformations/trace.h
...e/src/include/megbrain/imperative/transformations/trace.h
+43
-21
imperative/src/include/megbrain/imperative/utils/allocator.h
imperative/src/include/megbrain/imperative/utils/allocator.h
+113
-3
imperative/src/include/megbrain/imperative/utils/local_ptr.h
imperative/src/include/megbrain/imperative/utils/local_ptr.h
+44
-7
imperative/src/include/megbrain/imperative/utils/mempool.h
imperative/src/include/megbrain/imperative/utils/mempool.h
+4
-4
imperative/src/include/megbrain/imperative/utils/value_shape.h
...ative/src/include/megbrain/imperative/utils/value_shape.h
+2
-0
imperative/src/include/megbrain/imperative/value.h
imperative/src/include/megbrain/imperative/value.h
+145
-33
未找到文件。
imperative/python/megengine/core/tensor/array_method.py
浏览文件 @
ca001777
...
...
@@ -16,6 +16,7 @@ import numpy as np
from
..
import
_config
from
.._imperative_rt.common
import
CompNode
from
.._imperative_rt.core2
import
SymbolVar
,
Tensor
,
apply
,
dtype_promotion
from
.._imperative_rt.core2
import
reduce_to_scalar
as
_reduce_to_scalar
from
..ops
import
builtin
from
.
import
amp
from
.indexing
import
getitem
,
setitem
...
...
@@ -508,12 +509,8 @@ def _reduce(mode):
elif
self
.
dtype
==
np
.
bool_
:
data
=
data
.
astype
(
"int32"
)
if
axis
is
None
:
data
=
data
.
reshape
(
-
1
)
assert
not
keepdims
,
"can not set axis=None and keepdims=True"
op
=
builtin
.
Reduce
(
mode
=
mode
,
axis
=
0
)
(
result
,)
=
apply
(
op
,
data
)
result
=
_remove_axis
(
result
,
0
)
result
=
_reduce_to_scalar
(
builtin
.
Reduce
(
mode
=
mode
),
data
)
elif
isinstance
(
axis
,
collections
.
abc
.
Iterable
):
axis
=
_normalize_axis
(
self
.
ndim
,
axis
,
reverse
=
True
)
for
ai
in
axis
:
...
...
imperative/python/megengine/optimizer/sgd.py
浏览文件 @
ca001777
...
...
@@ -69,7 +69,7 @@ class SGD(Optimizer):
inplace_mode
=
int
(
os
.
getenv
(
"MEGENGINE_INPLACE_UPDATE"
,
"0"
))
if
inplace_mode
:
_neg_lr
=
tensor
(
-
lr
,
dtype
=
"float32"
)
c1
=
tensor
(
[
1.0
]
)
c1
=
tensor
(
1.0
)
for
param
in
param_group
[
"params"
]:
if
param
.
grad
is
None
:
...
...
imperative/python/megengine/tensor.py
浏览文件 @
ca001777
...
...
@@ -84,14 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
device
:
str
=
None
,
is_const
:
bool
=
False
,
no_cache
:
bool
=
False
,
name
:
str
=
""
,
name
:
str
=
None
,
):
if
name
is
None
:
name
=
""
else
:
self
.
_set_name
(
name
)
self
.
_custom_name
=
name
self
.
_name
=
name
self
.
_short_name
=
name
self
.
_set_name
(
self
.
_name
)
self
.
_prefix
=
None
@
property
...
...
imperative/python/src/grad.cpp
浏览文件 @
ca001777
...
...
@@ -46,17 +46,17 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
if
(
args
[
1
]
!=
Py_None
)
{
callback
=
py
::
reinterpret_borrow
<
py
::
object
>
(
args
[
1
]);
}
GenericFunction
generic_callback
=
[
=
](
Span
<
ValueRef
>
inputs
)
->
std
::
vector
<
ValueRef
>
{
GenericFunction
generic_callback
=
[
=
](
Span
<
ValueRef
>
inputs
)
->
ValueRefList
{
mgb_assert
(
inputs
.
size
()
==
1
);
if
(
callback
)
{
callback
(
TensorWrapper
::
make
(
py_tensor_type
,
inputs
[
0
]));
}
return
{};
};
tw
->
m_tensor
->
reset
(
imperative
::
apply
(
auto
attached_value
=
imperative
::
apply
(
AttachGrad
(
m_key
),
tw
->
m_tensor
->
data
(),
FunctionValue
::
make
(
generic_callback
))[
0
]);
FunctionValue
::
make
(
generic_callback
))[
0
];
tw
->
m_tensor
->
reset
(
attached_value
);
}
void
GradKeyWrapper
::
backward
(
GradKeyWrapper
*
self
,
py
::
list
tensors
,
py
::
list
grads
)
{
...
...
imperative/python/src/grad_override.cpp
浏览文件 @
ca001777
...
...
@@ -98,7 +98,7 @@ ValueRef make_empty_tensor(
return
res
;
}
std
::
optional
<
std
::
vector
<
ValueRef
>
>
elemwise_grad_rule
(
std
::
optional
<
ValueRefList
>
elemwise_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
auto
&
elemwise
=
op
.
cast_final_safe
<
Elemwise
>
();
...
...
@@ -117,7 +117,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule(
maker
.
backward
([
shapes
=
std
::
move
(
input_shapes
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
std
::
vector
<
ValueRef
>
ret
(
2
);
ValueRefList
ret
(
2
);
if
(
!
grad
)
{
return
ret
;
}
...
...
@@ -132,7 +132,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule(
return
imperative
::
apply
(
ApplyOp
(
op
),
inputs
);
}
std
::
optional
<
std
::
vector
<
ValueRef
>
>
reshape_grad_rule
(
std
::
optional
<
ValueRefList
>
reshape_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
...
...
@@ -147,7 +147,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule(
maker
.
backward
([
shapes
=
std
::
move
(
input_shapes
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
std
::
vector
<
ValueRef
>
ret
(
2
);
ValueRefList
ret
(
2
);
if
(
!
grad
)
{
return
ret
;
}
...
...
@@ -162,7 +162,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule(
return
imperative
::
apply
(
ApplyOp
(
op
),
inputs
);
}
std
::
optional
<
std
::
vector
<
ValueRef
>
>
subtensor_grad_rule
(
std
::
optional
<
ValueRefList
>
subtensor_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
auto
&&
subtensor
=
op
.
cast_final_safe
<
Subtensor
>
();
...
...
@@ -180,9 +180,9 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule(
grad_op_
=
std
::
move
(
grad_op
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
std
::
vector
<
ValueRef
>
ret
(
1
);
ValueRefList
ret
(
1
);
if
(
grad
&&
inputs
[
0
])
{
SmallVector
<
ValueRef
>
args_
(
inputs
.
size
()
+
1
);
ValueRefList
args_
(
inputs
.
size
()
+
1
);
auto
&&
zeros
=
make_empty_tensor
(
grad
.
device
(),
inputs
[
0
],
grad
.
dtype
());
args_
[
0
]
=
zeros
;
args_
[
1
]
=
grad
;
...
...
@@ -197,7 +197,7 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule(
return
imperative
::
apply
(
ApplyOp
(
op
),
inputs
);
}
std
::
optional
<
std
::
vector
<
ValueRef
>
>
indexingMultiAxisVec_grad_rule
(
std
::
optional
<
ValueRefList
>
indexingMultiAxisVec_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
auto
&&
indexingMultiAxisVec
=
op
.
cast_final_safe
<
IndexingMultiAxisVec
>
();
...
...
@@ -215,9 +215,9 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
grad_op_
=
std
::
move
(
grad_op
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
std
::
vector
<
ValueRef
>
ret
(
1
);
ValueRefList
ret
(
1
);
if
(
grad
&&
inputs
[
0
])
{
SmallVector
<
ValueRef
>
args_
(
inputs
.
size
()
+
1
);
ValueRefList
args_
(
inputs
.
size
()
+
1
);
auto
&&
zeros
=
make_empty_tensor
(
grad
.
device
(),
inputs
[
0
],
grad
.
dtype
());
args_
[
0
]
=
zeros
;
args_
[
1
]
=
grad
;
...
...
@@ -232,7 +232,7 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
return
imperative
::
apply
(
ApplyOp
(
op
),
inputs
);
}
std
::
optional
<
std
::
vector
<
ValueRef
>
>
reduce_grad_rule
(
std
::
optional
<
ValueRefList
>
reduce_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
auto
&
reduce
=
op
.
cast_final_safe
<
Reduce
>
();
...
...
@@ -251,7 +251,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule(
maker
.
backward
([
shapes
=
std
::
move
(
input_shapes
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
std
::
vector
<
ValueRef
>
ret
(
1
);
ValueRefList
ret
(
1
);
if
(
grad
&&
shapes
[
0
])
{
ret
[
0
]
=
broadcast_to
(
grad
,
shapes
[
0
]);
}
...
...
@@ -261,7 +261,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule(
return
imperative
::
apply
(
ApplyOp
(
op
),
inputs
);
}
std
::
optional
<
std
::
vector
<
ValueRef
>
>
addAxis_grad_rule
(
std
::
optional
<
ValueRefList
>
addAxis_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
auto
&&
addAxis
=
op
.
cast_final_safe
<
AddAxis
>
();
...
...
@@ -274,7 +274,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule(
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
),
flag_
=
flag
](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
std
::
vector
<
ValueRef
>
ret
(
1
);
ValueRefList
ret
(
1
);
if
(
grad
&&
flag_
)
{
ret
[
0
]
=
imperative
::
apply
(
*
grad_op_
,
grad
)[
0
];
}
...
...
@@ -284,7 +284,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule(
return
imperative
::
apply
(
op
,
inputs
);
}
std
::
optional
<
std
::
vector
<
ValueRef
>
>
removeAxis_grad_rule
(
std
::
optional
<
ValueRefList
>
removeAxis_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
auto
&&
removeAxis
=
op
.
cast_final_safe
<
RemoveAxis
>
();
...
...
@@ -297,7 +297,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
),
flag_
=
flag
](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
std
::
vector
<
ValueRef
>
ret
(
1
);
ValueRefList
ret
(
1
);
if
(
grad
&&
flag_
)
{
ret
[
0
]
=
imperative
::
apply
(
*
grad_op_
,
grad
)[
0
];
}
...
...
@@ -307,7 +307,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
return
imperative
::
apply
(
op
,
inputs
);
}
std
::
optional
<
std
::
vector
<
ValueRef
>
>
fastpathcopy_grad_rule
(
std
::
optional
<
ValueRefList
>
fastpathcopy_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
...
...
@@ -316,7 +316,7 @@ std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule(
maker
.
backward
([](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
std
::
vector
<
ValueRef
>
ret
(
1
);
ValueRefList
ret
(
1
);
if
(
grad
)
{
ret
[
0
]
=
grad
;
}
...
...
imperative/python/src/module_trace.h
浏览文件 @
ca001777
...
...
@@ -25,24 +25,23 @@ private:
py
::
function
m_hook_fn
;
int
m_enabled
=
0
;
std
::
vector
<
ValueRef
>
apply_module_trace_hook
(
const
OpDef
&
op
,
Span
<
ValueRef
>
input_values
)
{
ValueRefList
apply_module_trace_hook
(
const
OpDef
&
op
,
Span
<
ValueRef
>
input_values
)
{
py
::
list
input_tws
;
for
(
auto
&&
input_value
:
input_values
)
{
input_tws
.
append
(
TensorWrapper
::
make
(
py_tensor_type
,
input_value
));
}
py
::
list
output_tws
=
m_hook_fn
(
py
::
cast
(
op
.
shared_from_this
()),
*
input_tws
);
std
::
vector
<
ValueRef
>
outputs
;
ValueRefList
outputs
(
output_tws
.
size
());
auto
it
=
outputs
.
begin
();
for
(
auto
&&
output_tw
:
output_tws
)
{
outputs
.
push_back
(
TensorWrapper
::
try_cast
(
output_tw
.
ptr
())
->
m_tensor
->
data
());
*
(
it
++
)
=
TensorWrapper
::
try_cast
(
output_tw
.
ptr
())
->
m_tensor
->
data
();
}
return
outputs
;
}
public:
ModuleTraceTransformation
(
py
::
function
hook_fn
)
:
m_hook_fn
(
hook_fn
)
{}
std
::
vector
<
ValueRef
>
apply_transformation
(
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
{
if
(
op
.
is
<
ApplyOp
>
()
&&
m_enabled
>
0
)
{
auto
outputs
=
apply_module_trace_hook
(
op
.
cast
<
ApplyOp
>
().
op
(),
inputs
);
...
...
imperative/python/src/tensor.cpp
浏览文件 @
ca001777
...
...
@@ -87,7 +87,7 @@ PyObject* py_apply(
--
nargs
;
auto
op
=
py
::
handle
(
py_op
).
cast
<
std
::
shared_ptr
<
OpDef
>>
();
SmallVector
<
ValueRef
,
64
>
tensors
(
nargs
);
SmallVector
<
ValueRef
,
8
>
tensors
(
nargs
);
if
(
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
args
[
0
])))
{
// swap to a special context to reuse scalar handle
...
...
@@ -100,16 +100,15 @@ PyObject* py_apply(
Transformation
::
top
());
std
::
make_shared
<
ScalarTransformation
>
()
->
register_at
(
Transformation
::
top
());
SmallVector
<
ValueRef
>
inputs
(
nargs
);
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
auto
*
py_input
=
py
::
handle
(
args
[
i
]).
cast
<
PySymbolVar
*>
();
ValueRef
input
=
SymbolValue
::
make
(
py_input
->
m_node
);
if
(
py_input
->
is_scalar
)
{
input
=
ScalarValue
::
make
(
input
);
}
input
s
[
i
]
=
input
;
tensor
s
[
i
]
=
input
;
}
auto
outputs
=
imperative
::
apply
(
*
op
,
input
s
);
auto
outputs
=
imperative
::
apply
(
*
op
,
tensor
s
);
auto
ret
=
pybind11
::
tuple
(
outputs
.
size
());
auto
typeobj
=
py
::
handle
(
args
[
0
]).
get_type
();
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
...
...
@@ -140,7 +139,7 @@ PyObject* py_apply(
}
}
auto
outputs
=
imperative
::
apply
(
ApplyOp
(
*
op
),
{
tensors
.
data
(),
nargs
}
);
auto
outputs
=
imperative
::
apply
(
*
op
,
tensors
);
size_t
nout
=
outputs
.
size
();
auto
ret
=
py
::
tuple
(
nout
);
for
(
size_t
i
=
0
;
i
<
nout
;
++
i
)
{
...
...
@@ -214,16 +213,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
if
(
!
name
.
empty
())
{
m_tensor
->
reset
(
imperative
::
apply
(
RenameValue
(
name
),
m_tensor
->
data
())[
0
]);
mgb_assert
(
((
std
::
string
&
)
*
m_tensor
->
data
().
name
())
==
name
,
"result name incorrect"
);
}
if
(
data
.
ndim
()
==
0
)
{
mgb_assert
(
m_tensor
->
is_scalar
(),
"result should be scalar"
);
}
}
}
mgb_assert
(
m_tensor
->
data
());
}
PyObject
*
TensorWrapper
::
module_trace_info
()
{
...
...
@@ -1384,15 +1377,20 @@ void init_tensor(py::module m) {
std
::
function
<
bool
(
py
::
object
,
py
::
object
)
>
array_comparator
;
bool
compare_value
(
ValueRef
lhs
,
ValueRef
rhs
)
{
if
(
!
lhs
.
shape
()
->
eq
(
*
rhs
.
shape
()))
{
auto
lvalue
=
lhs
.
numpy
();
auto
rvalue
=
rhs
.
numpy
();
if
(
lvalue
->
shape
()
!=
rvalue
->
shape
())
{
return
false
;
}
HostTensorND
lvalue
=
lhs
.
numpy
()
->
as_nd
(
true
);
HostTensorND
rvalue
=
rhs
.
numpy
()
->
as_nd
(
true
);
if
(
lvalue
->
shape
().
is_scalar
())
{
return
lvalue
->
item
()
==
rvalue
->
item
();
}
HostTensorND
lnd
=
lvalue
->
as_nd
(
true
);
HostTensorND
rnd
=
rvalue
->
as_nd
(
true
);
auto
larr
=
py
::
reinterpret_steal
<
py
::
array
>
(
npy
::
ndarray_from_tensor
(
l
value
,
npy
::
ShareType
::
TRY_SHARE
));
npy
::
ndarray_from_tensor
(
l
nd
,
npy
::
ShareType
::
TRY_SHARE
));
auto
rarr
=
py
::
reinterpret_steal
<
py
::
array
>
(
npy
::
ndarray_from_tensor
(
r
value
,
npy
::
ShareType
::
TRY_SHARE
));
npy
::
ndarray_from_tensor
(
r
nd
,
npy
::
ShareType
::
TRY_SHARE
));
return
array_comparator
(
larr
,
rarr
);
}
...
...
@@ -1539,6 +1537,19 @@ void init_tensor(py::module m) {
}
});
m
.
def
(
"reduce_to_scalar"
,
[](
py
::
object
op
,
py
::
object
tensor
)
{
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
());
auto
make_scalar_shape
=
[
&
](
CompNode
device
)
{
return
imperative
::
apply
(
CreateTensor
(
CreateTensor
::
Const
,
device
,
dtype
::
Int32
(),
{
0
}),
HostStorage
::
make
(
device
))[
0
];
};
auto
output
=
imperative
::
apply
(
*
op
.
cast
<
std
::
shared_ptr
<
OpDef
>>
(),
tw
->
m_tensor
->
data
(),
make_scalar_shape
(
tw
->
m_tensor
->
comp_node
()))[
0
];
return
TensorWrapper
::
make
(
py_tensor_type
,
output
);
});
m
.
def
(
"name_tensor"
,
[](
std
::
string
name
,
py
::
object
tensor
)
{
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
());
auto
output
=
imperative
::
apply
(
TraceMarkVar
(
name
),
tw
->
m_tensor
->
data
())[
0
];
...
...
@@ -1546,9 +1557,9 @@ void init_tensor(py::module m) {
});
m
.
def
(
"is_grad_attached"
,
[](
std
::
vector
<
py
::
object
>
tensors
)
->
bool
{
SmallVector
<
ValueRef
>
values
;
for
(
auto
&&
tensor
:
tensors
)
{
values
.
push_back
(
tensor
.
cast
<
TensorWrapper
>
().
m_tensor
->
data
()
);
ValueRefList
values
(
tensors
.
size
())
;
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
values
[
i
]
=
tensors
[
i
].
cast
<
TensorWrapper
>
().
m_tensor
->
data
(
);
}
auto
outputs
=
imperative
::
apply
(
GetGradKey
(),
values
);
if
(
outputs
[
0
].
is
<
GradKeyValue
>
())
{
...
...
@@ -1559,9 +1570,9 @@ void init_tensor(py::module m) {
});
m
.
def
(
"get_grad_key"
,
[](
std
::
vector
<
py
::
object
>
tensors
)
->
py
::
object
{
SmallVector
<
ValueRef
>
values
;
for
(
auto
&&
tensor
:
tensors
)
{
values
.
push_back
(
tensor
.
cast
<
TensorWrapper
>
().
m_tensor
->
data
()
);
ValueRefList
values
(
tensors
.
size
())
;
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
values
[
i
]
=
tensors
[
i
].
cast
<
TensorWrapper
>
().
m_tensor
->
data
(
);
}
auto
outputs
=
imperative
::
apply
(
GetGradKey
(),
values
);
if
(
auto
*
grad_key_val
=
outputs
[
0
].
as
<
GradKeyValue
>
())
{
...
...
@@ -1578,7 +1589,7 @@ void init_tensor(py::module m) {
mgb_assert
(
GradKeyWrapper
::
wrap_t
::
type
().
isinstance
(
py_key
.
ptr
()));
auto
*
key
=
reinterpret_cast
<
GradKeyWrapper
::
wrap_t
*>
(
py_key
.
ptr
())
->
inst
();
GenericFunction
generic_backward_fn
=
[
backward_fn
](
Span
<
ValueRef
>
output_grads
)
->
std
::
vector
<
ValueRef
>
{
[
backward_fn
](
Span
<
ValueRef
>
output_grads
)
->
ValueRefList
{
py
::
list
output_grad_tws
;
for
(
auto
&&
output_grad
:
output_grads
)
{
if
(
output_grad
)
{
...
...
@@ -1589,23 +1600,25 @@ void init_tensor(py::module m) {
}
}
py
::
tuple
input_grad_tws
=
backward_fn
(
*
output_grad_tws
);
std
::
vector
<
ValueRef
>
input_grads
;
for
(
auto
&&
input_grad_tw
:
input_grad_tws
)
{
ValueRefList
input_grads
(
input_grad_tws
.
size
());
for
(
size_t
i
=
0
;
i
<
input_grad_tws
.
size
();
++
i
)
{
auto
input_grad_tw
=
input_grad_tws
[
i
];
if
(
!
input_grad_tw
.
is_none
())
{
input_grads
.
push_back
(
py
::
cast
<
TensorWrapper
>
(
input_grad_tw
).
m_tensor
->
data
()
)
;
input_grads
[
i
]
=
py
::
cast
<
TensorWrapper
>
(
input_grad_tw
).
m_tensor
->
data
();
}
else
{
input_grads
.
push_back
({})
;
input_grads
[
i
]
=
{}
;
}
}
return
input_grads
;
};
SmallVector
<
ValueRef
>
values
;
for
(
auto
&&
input
:
inputs
)
{
values
.
push_back
(
input
.
cast
<
TensorWrapper
>
().
m_tensor
->
data
()
);
ValueRefList
values
(
inputs
.
size
()
+
outputs
.
size
())
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
values
[
i
]
=
inputs
[
i
].
cast
<
TensorWrapper
>
().
m_tensor
->
data
(
);
}
for
(
auto
&&
output
:
outputs
)
{
values
.
push_back
(
output
.
cast
<
TensorWrapper
>
().
m_tensor
->
data
());
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
values
[
i
+
inputs
.
size
()]
=
outputs
[
i
].
cast
<
TensorWrapper
>
().
m_tensor
->
data
();
}
auto
wrapped_output_values
=
imperative
::
apply
(
SetGrad
(
key
->
m_key
,
generic_backward_fn
,
inputs
.
size
()),
values
);
...
...
imperative/python/src/tensor.h
浏览文件 @
ca001777
...
...
@@ -39,7 +39,7 @@ namespace mgb::imperative::python {
extern
interpreter
::
Interpreter
::
Channel
*
interpreter_for_py
;
extern
PyTypeObject
*
py_tensor_type
;
struct
Tensor
:
std
::
enable_shared_from_this
<
Tensor
>
,
NonCopyableObj
{
struct
Tensor
:
NonCopyableObj
{
private:
std
::
string
m_name
;
ValueRef
m_data
;
...
...
@@ -52,7 +52,7 @@ public:
~
Tensor
()
=
default
;
inline
std
::
shared_ptr
<
Tensor
>
copy
()
{
auto
ret
=
std
::
make_shared
<
Tensor
>
(
m_data
.
unwrap
()
);
auto
ret
=
std
::
make_shared
<
Tensor
>
(
m_data
);
ret
->
m_name
=
m_name
;
return
ret
;
}
...
...
imperative/python/src/transformation.h
浏览文件 @
ca001777
...
...
@@ -11,7 +11,15 @@
#pragma once
#include <optional>
#include <string>
#include "pybind11/pybind11.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/value.h"
#include "megbrain/utils/small_vector.h"
namespace
mgb
::
imperative
::
python
{
struct
TransformationManager
{
...
...
@@ -58,4 +66,14 @@ struct TransformationManager {
return
sl_instance
;
}
};
class
PyValue
final
:
public
MixinValueImpl
<
PyValue
,
pybind11
::
object
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
{
return
pybind11
::
str
((
const
pybind11
::
object
&
)
*
this
).
cast
<
std
::
string
>
();
}
};
}
// namespace mgb::imperative::python
imperative/src/impl/basic_operators.cpp
浏览文件 @
ca001777
...
...
@@ -45,7 +45,7 @@ CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout)
layout
.
is_contiguous
()
||
layout
.
is_empty
(),
"layout should be contiguous"
);
}
auto
CreateTensor
::
parse
(
Span
<
ValueRef
>
inputs
)
->
Args
{
auto
CreateTensor
::
parse
(
Span
<
ValueRef
>
inputs
)
const
->
Args
{
Args
result
;
for
(
auto
&&
input
:
inputs
)
{
if
(
auto
host_storage
=
input
.
as_ref
<
HostStorage
>
())
{
...
...
imperative/src/impl/dispatch.cpp
浏览文件 @
ca001777
...
...
@@ -16,70 +16,67 @@
#include "megbrain/imperative/utils/map.h"
namespace
mgb
{
void
imperative_log_profile_begin
(
const
char
*
message
);
void
imperative_log_profile
(
const
char
*
message
);
void
imperative_log_profile_end
(
const
char
*
message
);
namespace
imperative
{
std
::
vector
<
ValueRef
>
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
static
bool
log_dispatch
=
MGB_GETENV
(
"MGE_LOG_OP_DISPATCH"
);
bool
enable_watch
=
ValueRef
::
any_watching
();
auto
&
context
=
Transformation
::
get_context
();
size_t
&
depth
=
context
.
next_transformation
;
static
const
char
tabs_storage
[]
=
"
\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t
"
;
const
char
*
tabs
=
tabs_storage
+
sizeof
(
tabs_storage
)
/
sizeof
(
char
)
-
depth
-
1
;
bool
log_current_dispatch
=
log_dispatch
;
if
(
enable_watch
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
&
input
=
inputs
[
i
];
if
(
input
.
watching
())
{
log_current_dispatch
=
true
;
mgb_log_debug
(
"%sinput[%zu] is %s"
,
tabs
,
i
,
input
.
to_string
().
c_str
());
debug
::
notify_event
(
"apply"
);
}
}
}
// entrance
std
::
vector
<
ValueRef
>
outputs
;
if
(
depth
>=
context
.
transformations
.
size
())
{
// fallback
if
(
log_current_dispatch
)
{
mgb_log_debug
(
"%sfallback apply %s in %s"
,
tabs
,
op
.
to_string
().
c_str
(),
imperative
::
to_string
(
inputs
).
c_str
());
namespace
{
MGB_NOINLINE
void
copy_outputs
(
ForwardAllocator
<
ValueRef
>&
allocator
,
ValueRefList
&
outputs
)
{
size_t
nr_outputs
=
outputs
.
size
();
if
(
mgb_likely
(
nr_outputs
==
1
))
{
ValueRef
output_copy
;
output_copy
=
outputs
[
0
];
allocator
.
clear
();
outputs
=
ValueRefList
({
output_copy
});
}
else
if
(
!
outputs
.
empty
())
{
SmallVector
<
ValueRef
>
outputs_copy
(
nr_outputs
);
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
outputs_copy
[
i
]
=
outputs
[
i
];
}
outputs
=
op
.
fallback
(
inputs
);
outputs
.
clear
();
allocator
.
clear
();
outputs
=
{
outputs_copy
.
begin
(),
outputs_copy
.
end
()};
}
else
{
// dispatch to stack top
auto
&
transformation
=
*
context
.
transformations
[
depth
];
++
depth
;
context
.
frames
.
push_back
({
op
,
inputs
});
CleanupGuard
_
{[
&
]
{
context
.
frames
.
pop_back
();
--
depth
;
}};
if
(
log_current_dispatch
)
{
mgb_log_debug
(
"%s%s apply %s in %s"
,
tabs
,
transformation
.
name
().
c_str
(),
op
.
to_string
().
c_str
(),
imperative
::
to_string
(
inputs
).
c_str
());
}
outputs
=
transformation
.
apply_transformation
(
op
,
inputs
);
allocator
.
clear
();
}
if
(
log_current_dispatch
)
{
mgb_log_debug
(
"%sreturn %s"
,
tabs
,
imperative
::
to_string
(
outputs
).
c_str
());
}
}
// namespace
ValueRefList
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
auto
&
context
=
Transformation
::
get_context
();
size_t
&
depth
=
context
.
next_transformation
;
bool
top
=
depth
==
0
;
auto
outputs
=
([
&
]
{
if
(
mgb_unlikely
(
depth
>=
context
.
transformations
.
size
()))
{
return
op
.
fallback
(
inputs
);
}
else
{
auto
&
transformation
=
*
context
.
transformations
[
depth
++
];
CleanupGuard
_
{[
&
]
{
--
depth
;
}};
return
transformation
.
apply_transformation
(
op
,
inputs
);
}
})();
if
(
mgb_unlikely
(
top
))
{
copy_outputs
(
context
.
allocator
,
outputs
);
}
return
outputs
;
}
std
::
vector
<
ValueRef
>
apply
(
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
apply
(
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
)
{
return
imperative
::
apply
(
ApplyOp
{
def
},
inputs
);
}
std
::
vector
<
ValueRef
>
apply
(
Subgraph
graph
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
apply
(
const
Subgraph
&
graph
,
Span
<
ValueRef
>
inputs
)
{
SmallVector
<
ValueRef
>
inputs_storage
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
inputs_storage
.
push_back
(
inputs
[
i
]);
}
auto
apply_functor
=
[](
std
::
shared_ptr
<
OpDef
>
op
,
SmallVector
<
ValueRef
>
inputs
,
size_t
)
{
auto
outputs
=
imperative
::
apply
(
ApplyOp
(
*
op
)
,
inputs
);
auto
outputs
=
imperative
::
apply
(
*
op
,
inputs
);
return
SmallVector
<
ValueRef
>
(
outputs
.
begin
(),
outputs
.
end
());
};
auto
make_const
=
[](
TensorPtr
constant
)
->
ValueRef
{
...
...
@@ -101,7 +98,7 @@ std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) {
DeviceStorage
::
make
(
device_value
.
storage
()))[
0
];
};
auto
outputs
=
graph
.
apply
(
inputs_storage
,
apply_functor
,
make_const
);
return
{
outputs
.
begin
(),
outputs
.
end
()};
return
ValueRefList
{
outputs
.
begin
(),
outputs
.
end
()};
}
}
// namespace imperative
...
...
imperative/src/impl/interpreter/stack_manager.h
浏览文件 @
ca001777
...
...
@@ -126,7 +126,7 @@ public:
m_frames
[
m_frames
.
size
()
-
1
-
i
]
=
{
node
,
node
->
version
()};
node
=
node
->
parent
();
}
mgb_assert
(
node
->
is_root
()
,
""
);
mgb_assert
(
node
->
is_root
());
}
Trace
()
=
default
;
std
::
string
to_string
()
const
{
...
...
imperative/src/impl/operator.cpp
浏览文件 @
ca001777
...
...
@@ -3,7 +3,7 @@
namespace
mgb
{
namespace
imperative
{
std
::
vector
<
ValueRef
>
Operator
::
fallback
(
Span
<
ValueRef
>
inputs
)
const
{
ValueRefList
Operator
::
fallback
(
Span
<
ValueRef
>
inputs
)
const
{
mgb_throw
(
MegBrainError
,
"no fallback implementation for %s"
,
to_string
().
c_str
());
}
...
...
imperative/src/impl/physical_tensor.cpp
浏览文件 @
ca001777
...
...
@@ -99,19 +99,22 @@ Tensor::Tensor(
Tensor
::
Tensor
(
const
HostTensorND
&
hv
)
:
Tensor
(
hv
.
layout
(),
hv
.
comp_node
())
{
constexpr
int
size_threshold
=
TensorShape
::
MAX_NDIM
;
if
(
hv
.
layout
().
total_nr_elems
()
<=
size_threshold
)
{
size_t
nr_elems
=
hv
.
layout
().
total_nr_elems
();
if
(
nr_elems
<=
size_threshold
)
{
m_value
=
hv
;
}
MGB_RECORD_EVENT
(
profiler
::
HostToDeviceEvent
,
hv
.
layout
(),
hv
.
comp_node
(),
hv
.
raw_ptr
(),
dev_tensor
().
raw_ptr
());
dev_tensor
().
copy_from_fixlayout
(
hv
);
// even though hv is saved in m_value, Tensor itself could be
// released before copy completes
MGB_RECORD_EVENT
(
profiler
::
HostToDeviceFinishEvent
,
hv
.
layout
(),
hv
.
comp_node
(),
hv
.
raw_ptr
(),
dev_tensor
().
raw_ptr
());
AsyncReleaser
::
inst
()
->
add
(
hv
);
if
(
nr_elems
)
{
MGB_RECORD_EVENT
(
profiler
::
HostToDeviceEvent
,
hv
.
layout
(),
hv
.
comp_node
(),
hv
.
raw_ptr
(),
dev_tensor
().
raw_ptr
());
dev_tensor
().
copy_from_fixlayout
(
hv
);
// even though hv is saved in m_value, Tensor itself could be
// released before copy completes
MGB_RECORD_EVENT
(
profiler
::
HostToDeviceFinishEvent
,
hv
.
layout
(),
hv
.
comp_node
(),
hv
.
raw_ptr
(),
dev_tensor
().
raw_ptr
());
AsyncReleaser
::
inst
()
->
add
(
hv
);
}
}
Tensor
::
Tensor
(
const
DeviceTensorND
&
dv
,
const
HostTensorND
&
hv
)
{
...
...
imperative/src/impl/profiler/chrome_timeline.cpp
浏览文件 @
ca001777
...
...
@@ -310,7 +310,8 @@ struct ChromeTimelineEventVisitor : EventVisitor<ChromeTimelineEventVisitor> {
}
else
if
constexpr
(
std
::
is_same_v
<
TEvent
,
TensorGetPropEvent
>
)
{
new_host_event
(
"TensorGetProp"
,
'X'
)
.
dur
(
0
)
.
args
(
current_tensor
->
detail
(
current
->
time
));
.
args
(
current_tensor
->
detail
(
current
->
time
))
.
arg
(
"kind"
,
imperative
::
to_string
(
event
.
prop
));
}
else
if
constexpr
(
std
::
is_same_v
<
TEvent
,
TensorWaitPropEvent
>
)
{
new_host_event
(
"TensorWaitProp"
,
'B'
);
}
else
if
constexpr
(
std
::
is_same_v
<
TEvent
,
TensorWaitPropFinishEvent
>
)
{
...
...
imperative/src/impl/transformations/eval.cpp
浏览文件 @
ca001777
...
...
@@ -15,71 +15,109 @@
namespace
mgb
{
namespace
imperative
{
std
::
vector
<
ValueRef
>
InterpreterTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
op_val
=
op
.
as
<
ApplyOp
>
())
{
if
(
op_val
->
op
().
same_type
<
FastpathCopy
>
())
{
return
{
inputs
[
0
]};
}
SmallVector
<
Handle
>
input_handles
;
SmallVector
<
Handle
>
output_handles
;
CleanupGuard
_
{[
&
]
{
for
(
auto
handle
:
output_handles
)
{
if
(
handle
)
{
m_channel
->
del
(
handle
);
}
DTypeValue
::
ref_t
InterpreterInfo
::
dtype
()
const
{
if
(
!
m_dtype
)
{
m_dtype
=
DTypeValue
::
make
(
handle
()
->
channel
()
->
get_dtype
(
handle
()
->
handle
()));
}
return
m_dtype
;
}
CompNodeValue
::
ref_t
InterpreterInfo
::
comp_node
()
const
{
if
(
!
m_comp_node
)
{
m_comp_node
=
CompNodeValue
::
make
(
handle
()
->
channel
()
->
get_device
(
handle
()
->
handle
()));
}
return
m_comp_node
;
}
ShapeValue
::
ref_t
InterpreterInfo
::
shape
()
const
{
if
(
!
m_shape
)
{
m_shape
=
ShapeValue
::
make
(
ValueShape
::
from
(
handle
()
->
channel
()
->
get_shape
(
handle
()
->
handle
())));
}
return
m_shape
;
}
ValueRefList
InterpreterTransformation
::
apply_op
(
const
ApplyOp
&
apply_op
,
Span
<
ValueRef
>
inputs
)
{
if
(
apply_op
.
op
().
same_type
<
FastpathCopy
>
())
{
return
{
inputs
[
0
]};
}
SmallVector
<
Handle
>
input_handles
;
SmallVector
<
Handle
>
output_handles
;
CleanupGuard
_
{[
&
]
{
for
(
auto
handle
:
output_handles
)
{
if
(
handle
)
{
m_channel
->
del
(
handle
);
}
}};
for
(
auto
input
:
inputs
)
{
input_handles
.
push_back
(
*
input
.
cast
<
InterpreterValue
>
().
handle
());
}
output_handles
=
m_channel
->
apply_op
(
op_val
->
op
().
shared_from_this
(),
input_handles
);
std
::
vector
<
ValueRef
>
outputs
;
for
(
auto
&
handle
:
output_handles
)
{
outputs
.
push_back
(
InterpreterValue
::
make
(
share_handle
(
handle
)));
handle
=
nullptr
;
}
return
outputs
;
}};
for
(
auto
input
:
inputs
)
{
input_handles
.
push_back
(
input
.
cast
<
InterpreterValue
>
().
handle
()
->
handle
());
}
output_handles
=
m_channel
->
apply_op
(
apply_op
.
op
().
shared_from_this
(),
input_handles
);
ValueRefList
outputs
(
output_handles
.
size
());
for
(
size_t
i
=
0
;
i
<
output_handles
.
size
();
++
i
)
{
outputs
[
i
]
=
InterpreterValue
::
make
(
share_handle
(
output_handles
[
i
]));
output_handles
[
i
]
=
nullptr
;
}
return
outputs
;
}
ValueRefList
InterpreterTransformation
::
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
)
{
auto
&
input
=
inputs
.
item
().
cast
<
InterpreterValue
>
();
ValueRef
output
;
switch
(
get_attr
.
attr
())
{
case
GetAttr
::
DType
:
output
=
input
.
dtype
();
break
;
case
GetAttr
::
Shape
:
output
=
input
.
shape
();
break
;
case
GetAttr
::
Device
:
output
=
input
.
comp_node
();
break
;
case
GetAttr
::
Value
:
output
=
HostValue
::
make
(
m_channel
->
get_value
(
input
.
handle
()
->
handle
()));
break
;
case
GetAttr
::
Data
:
output
=
DeviceValue
::
make
(
m_channel
->
get_dev_tensor
(
input
.
handle
()
->
handle
()));
break
;
default:
mgb_throw
(
MegBrainError
,
"Interpreter: malformed GetAttr: %s"
,
get_attr
.
to_string
().
c_str
());
}
return
{
output
};
}
ValueRefList
InterpreterTransformation
::
apply_create_tensor
(
const
CreateTensor
&
create_tensor
,
Span
<
ValueRef
>
inputs
)
{
auto
args
=
create_tensor
.
parse
(
inputs
);
if
(
!
args
.
device
)
{
// implies H2D
mgb_assert
(
args
.
host
,
"neither host and device value is valid"
);
return
{
InterpreterValue
::
make
(
share_handle
(
m_channel
->
put
(
*
args
.
host
,
args
.
kind
==
CreateTensor
::
Unique
)))};
}
else
{
return
{
InterpreterValue
::
make
(
share_handle
(
m_channel
->
put
(
*
args
.
device
,
args
.
host
?
*
args
.
host
:
HostTensorND
())))};
}
}
ValueRefList
InterpreterTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
op_val
=
op
.
as
<
ApplyOp
>
())
{
return
apply_op
(
*
op_val
,
inputs
);
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
Handle
handle
=
*
inputs
[
0
].
cast
<
InterpreterValue
>
().
handle
();
ValueRef
output
;
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
DType
:
output
=
DTypeValue
::
make
(
m_channel
->
get_dtype
(
handle
));
break
;
case
GetAttr
::
Shape
:
output
=
ShapeValue
::
make
(
ValueShape
::
from
(
m_channel
->
get_shape
(
handle
)));
break
;
case
GetAttr
::
Device
:
output
=
CompNodeValue
::
make
(
m_channel
->
get_device
(
handle
));
break
;
case
GetAttr
::
Value
:
output
=
HostValue
::
make
(
m_channel
->
get_value
(
handle
));
break
;
case
GetAttr
::
Data
:
output
=
DeviceValue
::
make
(
m_channel
->
get_dev_tensor
(
handle
));
break
;
default:
mgb_throw
(
MegBrainError
,
"Interpreter: malformed GetAttr: %s"
,
op
.
to_string
().
c_str
());
}
return
{
output
};
return
apply_get_attr
(
*
get_attr
,
inputs
);
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
auto
args
=
create_tensor
->
parse
(
inputs
);
if
(
!
args
.
device
)
{
// implies H2D
mgb_assert
(
args
.
host
,
"neither host and device value is valid"
);
return
{
InterpreterValue
::
make
(
share_handle
(
m_channel
->
put
(
*
args
.
host
,
args
.
kind
==
CreateTensor
::
Unique
)))};
}
else
{
return
{
InterpreterValue
::
make
(
share_handle
(
m_channel
->
put
(
*
args
.
device
,
args
.
host
?
*
args
.
host
:
HostTensorND
())))};
}
return
apply_create_tensor
(
*
create_tensor
,
inputs
);
}
else
if
(
auto
*
dtr_command
=
op
.
as
<
DTRCommand
>
())
{
auto
handle
=
*
inputs
[
0
].
cast
<
InterpreterValue
>
().
handle
();
auto
handle
=
inputs
[
0
].
cast
<
InterpreterValue
>
().
handle
()
->
handle
();
switch
(
dtr_command
->
kind
())
{
case
DTRCommand
::
Drop
:
m_channel
->
drop
(
handle
);
...
...
imperative/src/impl/transformations/grad.cpp
浏览文件 @
ca001777
...
...
@@ -64,12 +64,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
size_t
count
=
std
::
count_if
(
save_for_backward
.
begin
(),
save_for_backward
.
end
(),
ranges
::
identity
{});
if
(
!
backward_graph
->
precomp
.
empty
())
{
SmallVector
<
ValueRef
>
inputs_and_outputs
;
ValueRefList
inputs_and_outputs
(
inputs
.
size
()
+
outputs
.
size
());
auto
it
=
inputs_and_outputs
.
begin
();
for
(
auto
&&
input
:
inputs
)
{
inputs_and_outputs
.
push_back
(
input
)
;
*
it
++
=
input
;
}
for
(
auto
&&
output
:
outputs
)
{
inputs_and_outputs
.
push_back
(
output
)
;
*
it
++
=
output
;
}
auto
precomp
=
imperative
::
apply
(
backward_graph
->
precomp
,
inputs_and_outputs
);
closure
.
reserve
(
precomp
.
size
()
+
count
);
...
...
@@ -89,7 +90,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
}
}
void
BackwardGraphWithClosure
::
operator
()(
std
::
vector
<
ValueRef
>
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
)
{
ValueRefList
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
)
{
ValueRef
args
[
closure
.
size
()
+
grads
.
size
()];
size_t
nargs
=
0
;
for
(
auto
&&
value
:
closure
)
{
...
...
@@ -120,7 +121,7 @@ void BackwardGraphWithClosure::operator()(
}
void
CustomBackward
::
operator
()(
std
::
vector
<
ValueRef
>
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
)
{
ValueRefList
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
)
{
size_t
nargs
=
grads
.
size
();
ValueRef
args
[
nargs
];
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
...
...
@@ -201,9 +202,10 @@ void GradKey::backward() {
mgb_throw
(
AssertionError
,
"invalid backward"
);
}
else
{
mgb_assert
(
grad_fn
->
m_slots
.
size
()
>
0
);
std
::
vector
<
ValueRef
>
grads
;
ValueRefList
grads
(
grad_fn
->
m_slots
.
size
());
auto
iter
=
grads
.
begin
();
for
(
auto
&&
slot
:
grad_fn
->
m_slots
)
{
grads
.
push_back
(
slot
.
m_grad
)
;
*
iter
++
=
slot
.
m_grad
;
}
backward
(
grads
,
grad_receiver
);
}
...
...
@@ -254,21 +256,28 @@ void GradKey::freeze() {
m_frozen
=
true
;
}
std
::
vector
<
ValueRef
>
GradTransformation
::
apply_transformation
(
ValueRefList
GradTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
auto
unwrap_inputs
=
[
this
](
Span
<
ValueRef
>
inputs
)
->
SmallVector
<
ValueRef
>
{
SmallVector
<
ValueRef
>
unwrapped_inputs
;
for
(
auto
&&
input
:
inputs
)
{
if
(
auto
grad_value
=
as_grad_value
(
input
))
{
unwrapped_inputs
.
push_back
(
grad_value
->
m_value
)
;
auto
fallback
=
[
&
]
{
ValueRefList
unwrapped_inputs
(
inputs
.
size
())
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
auto
grad_value
=
as_grad_value
(
input
s
[
i
]
))
{
unwrapped_inputs
[
i
]
=
grad_value
->
m_value
;
}
else
{
unwrapped_inputs
.
push_back
(
input
)
;
unwrapped_inputs
[
i
]
=
inputs
[
i
]
;
}
}
return
unwrapped_inputs
;
return
imperative
::
apply
(
op
,
unwrapped_inputs
)
;
};
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
if
(
auto
grad_value
=
as_grad_value
(
inputs
.
item
()))
{
return
imperative
::
apply
(
op
,
grad_value
->
m_value
);
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
}
if
(
m_suppressed
)
{
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)
);
return
fallback
(
);
}
if
(
auto
*
op_val
=
op
.
as
<
ApplyOp
>
())
{
size_t
nr_require_grad
=
0
;
...
...
@@ -284,20 +293,21 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
if
(
nr_require_grad
==
0
)
{
return
imperative
::
apply
(
op
,
inputs
);
}
SmallVector
<
ValueRef
>
captured_inputs
;
SmallVector
<
bool
>
inputs_require_grad
;
ValueRefList
captured_inputs
(
inputs
.
size
())
;
SmallVector
<
bool
>
inputs_require_grad
(
inputs
.
size
())
;
// capture value so that trace could assume input as same
auto
capture_value
=
[](
ValueRef
value
)
{
// TODO: fastpath copy shouldn't be an OpDef
return
imperative
::
apply
(
ApplyOp
(
*
FastpathCopy
::
make
()),
{
&
value
,
1
})[
0
];
};
for
(
auto
&
input
:
inputs
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
&
input
=
inputs
[
i
];
if
(
auto
grad_value
=
as_grad_value
(
input
))
{
captured_inputs
.
push_back
(
capture_value
(
grad_value
->
m_value
)
);
inputs_require_grad
.
push_back
(
true
)
;
captured_inputs
[
i
]
=
capture_value
(
grad_value
->
m_value
);
inputs_require_grad
[
i
]
=
true
;
}
else
{
captured_inputs
.
push_back
(
capture_value
(
input
)
);
inputs_require_grad
.
push_back
(
false
)
;
captured_inputs
[
i
]
=
capture_value
(
input
);
inputs_require_grad
[
i
]
=
false
;
}
}
decltype
(
std
::
declval
<
GradFn
>
().
m_backward
)
backward_storage
;
...
...
@@ -373,9 +383,11 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
mgb_assert
(
!
grad_fn
->
m_slots
.
empty
());
m_key
->
m_tape
.
push_back
({
grad_fn
,
op_val
->
op
().
shared_from_this
()});
return
outputs
;
}
else
if
(
op
.
is
<
CreateTensor
>
())
{
return
imperative
::
apply
(
op
,
inputs
);
}
else
if
(
auto
*
attach_grad
=
op
.
as
<
AttachGrad
>
())
{
if
(
!
has_key
(
attach_grad
->
key
()))
{
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)
);
return
fallback
(
);
}
auto
tensor
=
inputs
[
0
];
GenericFunction
callback
=
(
GenericFunction
&
)
inputs
[
1
].
cast
<
FunctionValue
>
();
...
...
@@ -386,7 +398,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
return
{
record_grad
(
output
)};
}
else
if
(
auto
*
grad_backward
=
op
.
as
<
GradBackward
>
())
{
if
(
!
has_key
(
grad_backward
->
key
()))
{
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)
);
return
fallback
(
);
}
size_t
nr_grads
=
inputs
.
size
()
/
2
;
mgb_assert
(
nr_grads
*
2
==
inputs
.
size
());
...
...
@@ -416,7 +428,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
backward
.
m_output_attrs
=
SmallVector
(
nr_outputs
,
CustomBackward
::
OutputAttr
{
true
,
true
});
backward
.
m_backward
=
set_grad
->
grad_fn
();
std
::
vector
<
ValueRef
>
outputs
;
ValueRefList
outputs
(
nr_outputs
)
;
grad_fn
->
m_key
=
m_key
;
grad_fn
->
m_slots
.
resize
(
nr_outputs
);
grad_fn
->
m_dests
.
reserve
(
nr_inputs
);
...
...
@@ -439,13 +451,13 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
}
else
{
grad_value
=
GradValue
::
make
(
output
,
m_key
,
GradSlotPtr
(
grad_fn
,
i
));
}
outputs
.
push_back
(
record_grad
(
grad_value
)
);
outputs
[
i
]
=
record_grad
(
grad_value
);
}
m_key
->
m_tape
.
push_back
({
grad_fn
,
nullptr
});
return
outputs
;
}
else
if
(
auto
*
gbc
=
op
.
as
<
GetBackwardColsure
>
())
{
if
(
gbc
->
key
()
!=
m_key
)
{
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)
);
return
fallback
(
);
}
return
{
FunctionValue
::
make
(
make_backward_closure
(
inputs
))};
}
else
if
(
op
.
is
<
DetachGrad
>
())
{
...
...
@@ -471,21 +483,8 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
}
else
if
(
op
.
is
<
CreateTensor
>
())
{
return
imperative
::
apply
(
op
,
inputs
);
}
else
{
SmallVector
<
ValueRef
>
unwrapped_inputs
;
for
(
auto
&&
input
:
inputs
)
{
if
(
auto
grad_value
=
as_grad_value
(
input
))
{
unwrapped_inputs
.
push_back
(
grad_value
->
m_value
);
}
else
{
unwrapped_inputs
.
push_back
(
input
);
}
}
auto
outputs
=
imperative
::
apply
(
op
,
{
unwrapped_inputs
.
data
(),
unwrapped_inputs
.
size
()});
mgb_assert
(
op
.
kind
()
==
Operator
::
GetAttrLike
||
outputs
.
empty
());
return
outputs
;
return
fallback
();
}
}
...
...
@@ -500,8 +499,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
y_slots
.
emplace_back
();
}
}
GenericFunction
closure
=
[
grad_key
,
y_slots
](
Span
<
ValueRef
>
dys
)
->
std
::
vector
<
ValueRef
>
{
GenericFunction
closure
=
[
grad_key
,
y_slots
](
Span
<
ValueRef
>
dys
)
->
ValueRefList
{
size_t
nr_grads
=
y_slots
.
size
();
mgb_assert
(
dys
.
size
()
==
nr_grads
);
for
(
size_t
i
=
0
;
i
<
nr_grads
;
++
i
)
{
...
...
imperative/src/impl/transformations/lazy.cpp
浏览文件 @
ca001777
...
...
@@ -21,7 +21,7 @@
namespace
mgb
{
namespace
imperative
{
std
::
vector
<
ValueRef
>
LazyEvalTransformation
::
apply_transformation
(
ValueRefList
LazyEvalTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
op_val
=
op
.
as
<
ApplyOp
>
())
{
static
std
::
unordered_set
<
Typeinfo
*>
mm_io_ops
=
{
...
...
@@ -59,9 +59,9 @@ std::vector<ValueRef> LazyEvalTransformation::apply_transformation(
mgb_assert
(
!
output_nodes
.
empty
());
m_io_link
=
SymbolVar
(
output_nodes
[
0
]);
}
std
::
vector
<
ValueRef
>
outputs
;
for
(
auto
&&
output_node
:
output_nodes
)
{
outputs
.
push_back
(
record_var
(
output_node
)
);
ValueRefList
outputs
(
output_nodes
.
size
())
;
for
(
size_t
i
=
0
;
i
<
output_nodes
.
size
();
++
i
)
{
outputs
[
i
]
=
record_var
(
output_nodes
[
i
]
);
}
return
outputs
;
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
...
...
imperative/src/impl/transformations/scalar.cpp
浏览文件 @
ca001777
...
...
@@ -19,26 +19,8 @@ namespace imperative {
namespace
{
using
ScalarRule
=
std
::
function
<
std
::
vector
<
ValueRef
>
(
const
OpDef
&
,
Span
<
ValueRef
>
)
>
;
static
std
::
unordered_map
<
Typeinfo
*
,
std
::
function
<
std
::
vector
<
ValueRef
>
(
const
OpDef
&
,
Span
<
ValueRef
>
)
>>
scalar_rules
;
ValueRef
unwrap_input
(
ValueRef
input
)
{
if
(
auto
scalar_input
=
input
.
as_ref
<
ScalarValue
>
())
{
return
scalar_input
->
value
();
}
else
{
return
input
;
}
}
std
::
vector
<
ValueRef
>
unwrap_inputs
(
Span
<
ValueRef
>
inputs
)
{
std
::
vector
<
ValueRef
>
unwrapped_inputs
;
for
(
auto
&&
input
:
inputs
)
{
unwrapped_inputs
.
push_back
(
unwrap_input
(
input
));
}
return
unwrapped_inputs
;
}
using
ScalarRule
=
ValueRefList
(
*
)(
const
OpDef
&
,
Span
<
ValueRef
>
,
Span
<
bool
>
);
static
std
::
unordered_map
<
Typeinfo
*
,
ScalarRule
>
scalar_rules
;
ValueRef
make_scalar_shape
(
CompNode
device
)
{
HostTensorND
scalar_shape
(
device
,
{
1
},
dtype
::
Int32
());
...
...
@@ -49,9 +31,6 @@ ValueRef make_scalar_shape(CompNode device) {
}
bool
is_scalar_shape
(
ValueRef
shape
)
{
if
(
shape
.
is
<
ScalarValue
>
())
{
return
false
;
}
// may have performance issue
auto
shape_of_shape
=
shape
.
shape
();
if
(
!
shape_of_shape
)
{
...
...
@@ -61,74 +40,65 @@ bool is_scalar_shape(ValueRef shape) {
return
*
shape_of_shape
==
ValueShape
{
0
};
}
template
<
typename
T
>
void
register_scalar_rule
(
std
::
vector
<
ValueRef
>
(
*
rule
)(
const
T
&
,
Span
<
ValueRef
>
))
{
scalar_rules
[
T
::
typeinfo
()]
=
[
rule
](
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
)
{
return
(
*
rule
)(
def
.
cast_final_safe
<
T
>
(),
inputs
);
template
<
typename
T
,
ValueRefList
(
*
rule
)(
const
T
&
,
Span
<
ValueRef
>,
Span
<
bool
>
)
>
void
register_scalar_rule
()
{
scalar_rules
[
T
::
typeinfo
()]
=
[](
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
return
(
*
rule
)(
def
.
cast_final_safe
<
T
>
(),
inputs
,
inputs_mask
);
};
}
std
::
vector
<
ValueRef
>
elemwise_rule
(
const
Elemwise
&
elem
,
Span
<
ValueRef
>
inputs
)
{
template
<
typename
TOpDef
,
size_t
nr_inputs
>
ValueRefList
elemwise_rule
(
const
TOpDef
&
op_def
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
if
constexpr
(
nr_inputs
!=
0
)
{
mgb_assert
(
inputs
.
size
()
==
inputs
.
size
(),
"inputs size mismatch"
);
}
bool
all_scalar
=
true
;
for
(
auto
&&
input
:
inputs
)
{
if
(
!
input
.
is
<
ScalarValue
>
()
)
{
for
(
auto
&&
input
_mask
:
inputs_mask
)
{
if
(
!
input
_mask
)
{
all_scalar
=
false
;
break
;
}
}
auto
output
=
imperative
::
apply
(
elem
,
unwrap_inputs
(
inputs
))[
0
]
;
auto
output
s
=
imperative
::
apply
(
op_def
,
inputs
)
;
if
(
all_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
outputs
[
0
]
=
ScalarValue
::
make
(
outputs
[
0
]);
}
return
outputs
;
}
std
::
vector
<
ValueRef
>
remove_axis_rule
(
const
RemoveAxis
&
remove_axis
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
mgb_assert
(
!
inputs
[
0
].
is
<
ScalarValue
>
());
auto
output
=
imperative
::
apply
(
remove_axis
,
inputs
)[
0
];
bool
is_scalar
=
inputs
[
0
].
shape
()
->
ndim
==
remove_axis
.
axis
.
size
();
ValueRefList
remove_axis_rule
(
const
RemoveAxis
&
remove_axis
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
mgb_assert
(
!
inputs_mask
.
item
());
bool
is_scalar
=
inputs
.
item
().
shape
()
->
ndim
==
remove_axis
.
axis
.
size
();
if
(
is_scalar
&&
remove_axis
.
axis
.
size
()
==
1
)
{
return
{
ScalarValue
::
make
(
inputs
.
item
())};
}
auto
outputs
=
imperative
::
apply
(
remove_axis
,
inputs
);
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
outputs
[
0
]
=
ScalarValue
::
make
(
outputs
[
0
]);
}
return
outputs
;
}
std
::
vector
<
ValueRef
>
reduce_rule
(
const
Reduce
&
reduce
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
reduce_rule
(
const
Reduce
&
reduce
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
if
(
inputs
.
size
()
==
1
)
{
return
imperative
::
apply
(
reduce
,
unwrap_inputs
(
inputs
)
);
return
imperative
::
apply
(
reduce
,
inputs
);
}
mgb_assert
(
inputs
.
size
()
==
2
);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
if
(
is_scalar
)
{
auto
unwrapped_input
=
unwrap_input
(
inputs
[
0
]);
CompNode
device
=
*
unwrapped_input
.
device
();
return
{
ScalarValue
::
make
(
imperative
::
apply
(
reduce
,
unwrapped_input
,
make_scalar_shape
(
device
))[
0
])};
}
auto
output
=
imperative
::
apply
(
reduce
,
unwrap_inputs
(
inputs
))[
0
];
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
}
}
std
::
vector
<
ValueRef
>
typecvt_rule
(
const
TypeCvt
&
typecvt
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
if
(
auto
scalar_input
=
inputs
[
0
].
as_ref
<
ScalarValue
>
())
{
CompNode
device
=
*
inputs
[
0
].
device
();
return
{
ScalarValue
::
make
(
imperative
::
apply
(
typecvt
,
scalar_input
->
value
())[
0
])};
}
else
{
return
imperative
::
apply
(
typecvt
,
inputs
);
imperative
::
apply
(
reduce
,
inputs
[
0
],
make_scalar_shape
(
device
))[
0
])};
}
return
imperative
::
apply
(
reduce
,
inputs
);
}
std
::
vector
<
ValueRef
>
collective_comm_rule
(
const
CollectiveComm
&
collective_comm
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
collective_comm_rule
(
const
CollectiveComm
&
collective_comm
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
static
std
::
unordered_set
<
CollectiveComm
::
Mode
>
modes
=
{
CollectiveComm
::
Mode
::
ALL_REDUCE_MAX
,
CollectiveComm
::
Mode
::
ALL_REDUCE_MIN
,
...
...
@@ -138,17 +108,17 @@ std::vector<ValueRef> collective_comm_rule(
if
(
modes
.
count
(
collective_comm
.
mode
)
==
0
)
{
return
imperative
::
apply
(
collective_comm
,
inputs
);
}
if
(
auto
scalar_input
=
inputs
[
0
].
as_ref
<
ScalarValue
>
())
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
collective_comm
,
scalar_input
->
value
())[
0
])};
if
(
inputs_mask
.
item
())
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
collective_comm
,
inputs
[
0
])[
0
])};
}
else
{
return
imperative
::
apply
(
collective_comm
,
inputs
);
}
}
std
::
vector
<
ValueRef
>
param_pack_split_rule
(
const
ParamPackSplit
&
param_pack_split
,
Span
<
ValueRef
>
inputs
)
{
auto
outputs
=
imperative
::
apply
(
param_pack_split
,
unwrap_inputs
(
inputs
));
ValueRefList
param_pack_split_rule
(
const
ParamPackSplit
&
param_pack_split
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
auto
outputs
=
imperative
::
apply
(
param_pack_split
,
inputs
);
size_t
nr_outputs
=
outputs
.
size
();
mgb_assert
(
nr_outputs
==
param_pack_split
.
shapes
.
size
());
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
...
...
@@ -159,29 +129,28 @@ std::vector<ValueRef> param_pack_split_rule(
return
outputs
;
}
std
::
vector
<
ValueRef
>
dot_rule
(
const
Dot
&
dot
,
Span
<
ValueRef
>
inputs
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
dot
,
unwrap_inputs
(
inputs
)
)[
0
])};
ValueRefList
dot_rule
(
const
Dot
&
dot
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
dot
,
inputs
)[
0
])};
}
std
::
vector
<
ValueRef
>
add_axis_rule
(
const
AddAxis
&
add_axis
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
add_axis_rule
(
const
AddAxis
&
add_axis
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
if
(
auto
scalar_input
=
inputs
[
0
].
as_ref
<
ScalarValue
>
())
{
if
(
inputs_mask
.
item
())
{
mgb_assert
(
add_axis
.
axis
[
0
]
==
0
);
if
(
add_axis
.
axis
.
size
()
==
1
)
{
return
{
scalar_input
->
value
()
};
return
{
inputs
[
0
]
};
}
else
{
std
::
vector
<
int32_t
>
axis
(
add_axis
.
axis
.
begin
()
+
1
,
add_axis
.
axis
.
end
());
return
imperative
::
apply
(
ApplyOp
(
*
AddAxis
::
make
(
axis
,
add_axis
.
scope
())),
scalar_input
->
value
());
return
imperative
::
apply
(
*
AddAxis
::
make
(
axis
,
add_axis
.
scope
()),
inputs
[
0
]);
}
}
else
{
return
imperative
::
apply
(
add_axis
,
inputs
);
}
}
std
::
vector
<
ValueRef
>
remote_recv_rule
(
const
RemoteRecv
&
remote_recv
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
remote_recv_rule
(
const
RemoteRecv
&
remote_recv
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
if
(
remote_recv
.
shape
.
empty
())
{
std
::
vector
<
int32_t
>
shape
=
{
1
};
auto
remote_recv_no_scalar
=
RemoteRecv
::
make
(
...
...
@@ -189,32 +158,32 @@ std::vector<ValueRef> remote_recv_rule(
remote_recv
.
rank_from
,
remote_recv
.
cn
,
shape
,
remote_recv
.
dtype
,
remote_recv
.
backend
);
remote_recv_no_scalar
->
set_scope
(
remote_recv
.
scope
());
return
imperative
::
apply
(
ApplyOp
(
*
remote_recv_no_scalar
),
unwrap_inputs
(
inputs
));
return
imperative
::
apply
(
ApplyOp
(
*
remote_recv_no_scalar
),
inputs
);
}
else
{
return
imperative
::
apply
(
remote_recv
,
unwrap_inputs
(
inputs
)
);
return
imperative
::
apply
(
remote_recv
,
inputs
);
}
}
std
::
vector
<
ValueRef
>
check_no_finite_rule
(
const
CheckNonFinite
&
check_no_finite
,
Span
<
ValueRef
>
inputs
)
{
auto
outputs
=
imperative
::
apply
(
check_no_finite
,
unwrap_inputs
(
inputs
));
ValueRefList
check_no_finite_rule
(
const
CheckNonFinite
&
check_no_finite
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
auto
outputs
=
imperative
::
apply
(
check_no_finite
,
inputs
);
mgb_assert
(
outputs
.
size
()
==
inputs
.
size
()
+
1
,
"output size mismatch"
);
outputs
.
back
()
=
ScalarValue
::
make
(
outputs
.
back
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
inputs
[
i
].
is
<
ScalarValue
>
()
)
{
if
(
inputs
_mask
[
i
]
)
{
outputs
[
i
]
=
ScalarValue
::
make
(
outputs
[
i
]);
}
}
return
outputs
;
}
std
::
vector
<
ValueRef
>
subtensor_rule
(
const
Subtensor
&
subtensor
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
subtensor_rule
(
const
Subtensor
&
subtensor
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
mgb_assert
(
inputs
.
size
()
>=
1
);
auto
input
=
inputs
[
0
];
bool
is_scalar
;
mgb_assert
(
!
input
.
is
<
ScalarValue
>
()
,
"subtensor shouldn't have scalar input"
);
mgb_assert
(
!
input
s_mask
[
0
]
,
"subtensor shouldn't have scalar input"
);
if
(
auto
shape
=
input
.
shape
())
{
size_t
ndim
=
input
.
shape
()
->
ndim
;
for
(
auto
&&
[
axis
,
begin
,
end
,
step
,
idx
]
:
subtensor
.
items
)
{
...
...
@@ -226,25 +195,25 @@ std::vector<ValueRef> subtensor_rule(
}
else
{
is_scalar
=
false
;
}
auto
output
=
imperative
::
apply
(
subtensor
,
unwrap_inputs
(
inputs
))[
0
]
;
auto
output
s
=
imperative
::
apply
(
subtensor
,
inputs
)
;
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
outputs
[
0
]
=
ScalarValue
::
make
(
outputs
[
0
]);
}
return
outputs
;
}
std
::
vector
<
ValueRef
>
get_var_shape_rule
(
const
GetVarShape
&
get_var_shape
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
get_var_shape_rule
(
const
GetVarShape
&
get_var_shape
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
bool
all_scalar
=
true
;
mgb_assert
(
inputs
.
size
()
>=
1
);
for
(
auto
&&
input
:
inputs
)
{
if
(
!
input
.
is
<
ScalarValue
>
()
)
{
for
(
auto
&&
input
_mask
:
inputs_mask
)
{
if
(
!
input
_mask
)
{
all_scalar
=
false
;
}
}
if
(
all_scalar
)
{
auto
device
=
inputs
[
0
].
cast
<
ScalarValue
>
().
value
().
device
();
auto
device
=
inputs
[
0
].
device
();
auto
storage
=
HostStorage
::
make
(
*
device
);
// storage->ensure_size(1);
return
imperative
::
apply
(
...
...
@@ -252,88 +221,49 @@ std::vector<ValueRef> get_var_shape_rule(
CreateTensor
::
Const
,
*
device
,
dtype
::
Int32
(),
ValueShape
{
0
}),
storage
);
}
else
{
return
imperative
::
apply
(
get_var_shape
,
unwrap_inputs
(
inputs
));
}
}
std
::
vector
<
ValueRef
>
fastpath_copy_rule
(
const
FastpathCopy
&
fastpath_copy
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
bool
is_scalar
=
inputs
[
0
].
is
<
ScalarValue
>
();
auto
output
=
imperative
::
apply
(
fastpath_copy
,
unwrap_inputs
(
inputs
))[
0
];
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
return
imperative
::
apply
(
get_var_shape
,
inputs
);
}
}
std
::
vector
<
ValueRef
>
reshape_rule
(
const
Reshape
&
reshape
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
reshape_rule
(
const
Reshape
&
reshape
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
auto
unwrapped_input
=
inputs
[
0
].
is
<
ScalarValue
>
()
?
inputs
[
0
].
cast
<
ScalarValue
>
().
value
()
:
inputs
[
0
];
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
reshape
,
unwrapped_input
,
make_scalar_shape
(
*
unwrapped_input
.
device
()))[
0
])};
reshape
,
inputs
[
0
],
make_scalar_shape
(
*
inputs
[
0
].
device
()))[
0
])};
}
else
{
return
imperative
::
apply
(
reshape
,
unwrap_inputs
(
inputs
)
);
return
imperative
::
apply
(
reshape
,
inputs
);
}
}
std
::
vector
<
ValueRef
>
broadcast_rule
(
const
Broadcast
&
broadcast
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
broadcast_rule
(
const
Broadcast
&
broadcast
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
auto
unwrapped_input
=
inputs
[
0
].
is
<
ScalarValue
>
()
?
inputs
[
0
].
cast
<
ScalarValue
>
().
value
()
:
inputs
[
0
];
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
broadcast
,
unwrapped_input
,
make_scalar_shape
(
*
unwrapped_input
.
device
()))[
0
])};
}
else
{
return
imperative
::
apply
(
broadcast
,
unwrap_inputs
(
inputs
));
}
}
std
::
vector
<
ValueRef
>
copy_rule
(
const
Copy
&
copy
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
bool
is_scalar
=
inputs
[
0
].
is
<
ScalarValue
>
();
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
copy
,
unwrap_inputs
(
inputs
))[
0
])};
}
else
{
return
imperative
::
apply
(
copy
,
unwrap_inputs
(
inputs
));
}
}
std
::
vector
<
ValueRef
>
inplace_add_rule
(
const
InplaceAdd
&
inplace_add
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
4
);
bool
is_scalar
=
inputs
[
0
].
is
<
ScalarValue
>
();
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
inplace_add
,
unwrap_inputs
(
inputs
))[
0
])};
broadcast
,
inputs
[
0
],
make_scalar_shape
(
*
inputs
[
0
].
device
()))[
0
])};
}
else
{
return
imperative
::
apply
(
inplace_add
,
unwrap_inputs
(
inputs
)
);
return
imperative
::
apply
(
broadcast
,
inputs
);
}
}
template
<
typename
T
>
std
::
vector
<
ValueRef
>
subgraph_op_rule
(
const
T
&
op
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
subgraph_op_rule
(
const
T
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_mask
,
const
Type
<
ScalarValue
>&
scalar_type
)
{
// TODO: add flag instead of assume
bool
all_scalar
=
true
;
for
(
auto
&&
input
:
inputs
)
{
if
(
!
input
.
is
<
ScalarValue
>
()
)
{
for
(
auto
&&
input
_mask
:
inputs_mask
)
{
if
(
!
input
_mask
)
{
all_scalar
=
false
;
}
}
auto
outputs
=
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)
);
auto
outputs
=
imperative
::
apply
(
op
,
inputs
);
if
(
all_scalar
)
{
for
(
auto
&
output
:
outputs
)
{
output
=
ScalarValue
::
make
(
output
);
output
=
scalar_type
.
make
(
output
);
}
}
return
outputs
;
...
...
@@ -341,67 +271,54 @@ std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
struct
ScalarRuleRegistry
{
ScalarRuleRegistry
()
{
register_scalar_rule
(
elemwise_rule
);
register_scalar_rule
(
remove_axis_rule
);
register_scalar_rule
(
reduce_rule
);
register_scalar_rule
(
typecvt_rule
);
register_scalar_rule
(
collective_comm_rule
);
register_scalar_rule
(
param_pack_split_rule
);
register_scalar_rule
(
dot_rule
);
register_scalar_rule
(
add_axis_rule
);
register_scalar_rule
(
remote_recv_rule
);
register_scalar_rule
(
check_no_finite_rule
);
register_scalar_rule
(
subtensor_rule
);
register_scalar_rule
(
get_var_shape_rule
);
register_scalar_rule
(
fastpath_copy_rule
);
register_scalar_rule
(
reshape_rule
);
register_scalar_rule
(
broadcast_rule
);
register_scalar_rule
(
copy_rule
);
register_scalar_rule
(
inplace_add_rule
);
register_scalar_rule
(
subgraph_op_rule
<
SubgraphOp
>
);
register_scalar_rule
(
subgraph_op_rule
<
CompiledOp
>
);
register_scalar_rule
<
Elemwise
,
elemwise_rule
<
Elemwise
,
0
>>
(
);
register_scalar_rule
<
RemoveAxis
,
remove_axis_rule
>
(
);
register_scalar_rule
<
Reduce
,
reduce_rule
>
(
);
register_scalar_rule
<
TypeCvt
,
elemwise_rule
<
TypeCvt
,
1
>>
(
);
register_scalar_rule
<
CollectiveComm
,
collective_comm_rule
>
(
);
register_scalar_rule
<
ParamPackSplit
,
param_pack_split_rule
>
(
);
register_scalar_rule
<
Dot
,
dot_rule
>
(
);
register_scalar_rule
<
AddAxis
,
add_axis_rule
>
(
);
register_scalar_rule
<
RemoteRecv
,
remote_recv_rule
>
(
);
register_scalar_rule
<
CheckNonFinite
,
check_no_finite_rule
>
(
);
register_scalar_rule
<
Subtensor
,
subtensor_rule
>
(
);
register_scalar_rule
<
GetVarShape
,
get_var_shape_rule
>
(
);
register_scalar_rule
<
FastpathCopy
,
elemwise_rule
<
FastpathCopy
,
1
>>
(
);
register_scalar_rule
<
Reshape
,
reshape_rule
>
(
);
register_scalar_rule
<
Broadcast
,
broadcast_rule
>
(
);
register_scalar_rule
<
Copy
,
elemwise_rule
<
Copy
,
1
>>
(
);
register_scalar_rule
<
InplaceAdd
,
elemwise_rule
<
InplaceAdd
,
4
>>
(
);
register_scalar_rule
<
SubgraphOp
,
subgraph_op_rule
<
SubgraphOp
>>
(
);
register_scalar_rule
<
CompiledOp
,
subgraph_op_rule
<
CompiledOp
>>
(
);
}
}
_
;
}
// namespace
std
::
vector
<
ValueRef
>
ScalarTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
apply_op
=
op
.
as
<
ApplyOp
>
())
{
auto
iter
=
scalar_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
if
(
iter
!=
scalar_rules
.
end
())
{
return
iter
->
second
(
apply_op
->
op
(),
inputs
);
}
else
{
// TODO: repeat op
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
));
}
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
if
(
create_tensor
->
shape
().
is_scalar
())
{
ValueShape
scalar_shape
=
{
1
};
CreateTensor
scalar_op
(
create_tensor
->
kind
(),
create_tensor
->
device
(),
create_tensor
->
dtype
(),
scalar_shape
);
return
{
ScalarValue
::
make
(
imperative
::
apply
(
scalar_op
,
inputs
)[
0
])};
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
bool
is_scalar
=
inputs
.
as_array
<
1
>
()[
0
].
is
<
ScalarValue
>
();
auto
output
=
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
))[
0
];
if
(
!
is_scalar
)
{
return
{
output
};
ValueRefList
ScalarTransformation
::
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
)
{
auto
&&
input
=
inputs
.
item
();
bool
is_scalar
=
input
.
is
<
ScalarValue
>
();
if
(
!
is_scalar
)
{
return
imperative
::
apply
(
get_attr
,
input
);
}
auto
unwrapped_input
=
input
.
cast
<
ScalarValue
>
().
value
();
if
(
get_attr
.
attr
()
==
GetAttr
::
Shape
)
{
if
(
!
m_empty_shape
)
{
m_empty_shape
=
ShapeValue
::
make
();
}
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
Shape
:
{
// Scalar Shape
return
{
ShapeValue
::
make
()}
;
}
return
{
m_empty_shape
};
}
else
{
auto
outputs
=
imperative
::
apply
(
get_attr
,
unwrapped_input
);
auto
&
output
=
outputs
[
0
]
;
switch
(
get_attr
.
attr
())
{
case
GetAttr
::
Value
:
{
auto
&
hv
=
output
.
cast
<
HostValue
>
();
mgb_assert
(
hv
.
shape
()
==
ValueShape
({
1
}),
"underlying value should has shape {1}, got %s"
,
hv
.
shape
().
to_string
().
c_str
());
return
{
HostValue
::
make
(
hv
.
dtype
(),
ValueShape
(),
hv
.
storage
())};
output
=
HostValue
::
make
(
hv
.
dtype
(),
ValueShape
(),
hv
.
storage
());
break
;
}
case
GetAttr
::
Data
:
{
auto
&
dv
=
output
.
cast
<
DeviceValue
>
();
...
...
@@ -409,22 +326,67 @@ std::vector<ValueRef> ScalarTransformation::apply_transformation(
dv
.
shape
()
==
ValueShape
({
1
}),
"underlying value should has shape {1}, got %s"
,
dv
.
shape
().
to_string
().
c_str
());
return
{
DeviceValue
::
make
(
dv
.
dtype
(),
ValueShape
(),
dv
.
storage
())};
output
=
DeviceValue
::
make
(
dv
.
dtype
(),
ValueShape
(),
dv
.
storage
());
break
;
}
default:
return
{
output
};
break
;
}
return
outputs
;
}
}
ValueRefList
ScalarTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
// fastpath for GetAttr
return
apply_get_attr
(
*
get_attr
,
inputs
);
}
size_t
nr_inputs
=
inputs
.
size
();
ValueRefList
unwrapped_inputs
(
nr_inputs
);
bool
inputs_mask
[
nr_inputs
];
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
auto
scalar_value
=
inputs
[
i
].
as_ref
<
ScalarValue
>
())
{
unwrapped_inputs
[
i
]
=
scalar_value
->
value
();
inputs_mask
[
i
]
=
true
;
}
else
{
unwrapped_inputs
[
i
]
=
inputs
[
i
];
inputs_mask
[
i
]
=
false
;
}
}
auto
fallback
=
[
&
]
{
return
imperative
::
apply
(
op
,
unwrapped_inputs
);
};
if
(
auto
apply_op
=
op
.
as
<
ApplyOp
>
())
{
auto
iter
=
scalar_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
if
(
iter
!=
scalar_rules
.
end
())
{
return
iter
->
second
(
apply_op
->
op
(),
unwrapped_inputs
,
{
inputs_mask
,
nr_inputs
});
}
else
{
// TODO: repeat op
return
fallback
();
}
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
if
(
create_tensor
->
shape
().
is_scalar
())
{
ValueShape
scalar_shape
=
{
1
};
CreateTensor
scalar_op
(
create_tensor
->
kind
(),
create_tensor
->
device
(),
create_tensor
->
dtype
(),
scalar_shape
);
return
{
ScalarValue
::
make
(
imperative
::
apply
(
scalar_op
,
inputs
)[
0
])};
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
}
else
if
(
op
.
as
<
IsScalar
>
())
{
return
{
BoolValue
::
make
(
inputs
.
as_array
<
1
>
()[
0
].
is
<
ScalarValue
>
())};
mgb_assert
(
nr_inputs
==
1
);
return
{
BoolValue
::
make
(
inputs_mask
[
0
])};
}
else
if
(
op
.
is
<
Operator
::
IdentityLike
>
())
{
bool
is_scalar
=
inputs
.
as_array
<
1
>
()[
0
].
is
<
ScalarValue
>
();
mgb_assert
(
nr_inputs
==
1
);
bool
is_scalar
=
inputs_mask
[
0
];
auto
outputs
=
fallback
();
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
))[
0
])};
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
outputs
[
0
]
=
ScalarValue
::
make
(
outputs
[
0
]);
}
return
outputs
;
}
else
{
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)
);
return
fallback
(
);
}
};
...
...
imperative/src/impl/transformations/tangent.cpp
0 → 100644
浏览文件 @
ca001777
/**
* \file imperative/src/impl/transformations/tangent.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/transformations/tangent.h"
namespace
mgb
{
namespace
imperative
{
ValueRefList
TangentTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
}
mgb_assert
(
false
);
}
}
// namespace imperative
}
// namespace mgb
imperative/src/impl/transformations/trace.cpp
浏览文件 @
ca001777
...
...
@@ -153,7 +153,7 @@ VarNodeArray TraceResult::dump(
return
output_nodes
;
}
std
::
vector
<
ValueRef
>
TracingTransformation
::
apply_transformation
(
ValueRefList
TracingTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
op_value
=
op
.
as
<
ApplyOp
>
())
{
SmallVector
<
ValueRef
>
unwrapped_inputs
;
...
...
@@ -180,11 +180,12 @@ std::vector<ValueRef> TracingTransformation::apply_transformation(
}
const_cast
<
OpDef
&>
(
op_value
->
op
()).
set_scope
(
scopes_join
);
auto
unwrapped_outputs
=
imperative
::
apply
(
op
,
unwrapped_inputs
);
std
::
vector
<
ValueRef
>
wrapped_outputs
;
ValueRefList
wrapped_outputs
(
unwrapped_outputs
.
size
())
;
SmallVector
<
size_t
>
output_ids
;
for
(
auto
&&
output
:
unwrapped_outputs
)
{
for
(
size_t
i
=
0
;
i
<
unwrapped_outputs
.
size
();
++
i
)
{
auto
&&
output
=
unwrapped_outputs
[
i
];
auto
wrapped_output
=
record_var
(
output
,
false
,
VarKind
::
Internal
);
wrapped_outputs
.
push_back
(
wrapped_output
)
;
wrapped_outputs
[
i
]
=
wrapped_output
;
output_ids
.
push_back
(
wrapped_output
->
id
());
}
m_seq
.
push_back
({
op_value
->
op
().
shared_from_this
(),
input_ids
,
output_ids
});
...
...
@@ -375,6 +376,11 @@ void CompiledTransformation::compile() {
return
accessor
;
};
std
::
vector
<
VarAccessor
>
var_accessors
(
m_vars
.
size
());
auto
exc_setter
=
std
::
bind
(
&
CompiledTransformation
::
set_exception
,
this
,
std
::
placeholders
::
_1
);
for
(
auto
&&
accessor
:
var_accessors
)
{
accessor
.
exc_setter
=
exc_setter
;
}
for
(
auto
&&
item
:
m_seq
)
{
bool
require_link
=
bool
(
item
.
op
)
&&
mm_io_ops
.
count
(
item
.
op
->
dyn_typeinfo
());
VarNodeArray
input_vars
;
...
...
@@ -509,8 +515,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
}
}
TracedValue
::
ref_t
CompiledTransformation
::
trace_output
(
size_t
id
)
{
auto
traced_value
=
TracedValue
::
make
(
id
);
auto
CompiledTransformation
::
trace_output
(
size_t
id
)
->
TracedValue
::
ref_t
{
auto
traced_value
=
TracedValue
::
make
(
id
,
&
m_vars
[
id
],
&
m_var_accessors
[
id
]
);
m_weak_values
.
push_back
(
traced_value
);
return
traced_value
;
}
...
...
@@ -520,64 +526,99 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() {
return
m_seq
[
m_pc
++
];
}
std
::
vector
<
ValueRef
>
CompiledTransformation
::
apply_transformation
(
ShapeValue
::
ref_t
CompiledTransformation
::
TracedInfo
::
shape
()
const
{
if
(
!
m_shape
)
{
trace_assert
(
m_accessor
->
shape_getter
,
"shape unreadable"
);
m_shape
=
ShapeValue
::
make
(
ValueShape
::
from
(
m_accessor
->
shape_getter
()));
}
return
m_shape
;
}
DTypeValue
::
ref_t
CompiledTransformation
::
TracedInfo
::
dtype
()
const
{
if
(
!
m_dtype
)
{
m_dtype
=
DTypeValue
::
make
(
m_var
->
dtype
);
}
return
m_dtype
;
}
CompNodeValue
::
ref_t
CompiledTransformation
::
TracedInfo
::
comp_node
()
const
{
if
(
!
m_comp_node
)
{
m_comp_node
=
CompNodeValue
::
make
(
m_var
->
device
);
}
return
m_comp_node
;
}
auto
CompiledTransformation
::
TracedInfo
::
accessor
()
const
->
const
VarAccessor
&
{
return
*
m_accessor
;
}
ValueRefList
CompiledTransformation
::
apply_op
(
const
ApplyOp
&
apply_op
,
Span
<
ValueRef
>
inputs
)
{
auto
&
item
=
next_instruction
();
trace_assert
(
inputs
.
size
()
==
item
.
inputs
.
size
(),
"input size mismatch"
);
trace_assert
(
apply_op
.
op
().
is_same
(
*
item
.
op
),
"operator mismatch"
);
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
trace_input
(
item
.
inputs
[
i
],
inputs
[
i
]);
}
ValueRefList
outputs
(
item
.
outputs
.
size
());
for
(
size_t
i
=
0
;
i
<
item
.
outputs
.
size
();
++
i
)
{
outputs
[
i
]
=
trace_output
(
item
.
outputs
[
i
]);
}
return
outputs
;
}
ValueRefList
CompiledTransformation
::
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
traced_value
=
inputs
[
0
].
as
<
TracedValue
>
())
{
ValueRef
output
;
auto
&
var_accessor
=
traced_value
->
accessor
();
switch
(
get_attr
.
attr
())
{
case
GetAttr
::
Shape
:
output
=
traced_value
->
shape
();
break
;
case
GetAttr
::
Data
:
trace_assert
(
var_accessor
.
data_getter
,
"data unreadable"
);
output
=
DeviceValue
::
make
(
var_accessor
.
data_getter
());
break
;
case
GetAttr
::
Value
:
trace_assert
(
var_accessor
.
value_getter
,
"value unreadable"
);
output
=
HostValue
::
make
(
var_accessor
.
value_getter
());
break
;
case
GetAttr
::
DType
:
output
=
traced_value
->
dtype
();
break
;
case
GetAttr
::
Device
:
output
=
traced_value
->
comp_node
();
default:
break
;
}
return
{
output
};
}
else
{
return
imperative
::
apply
(
get_attr
,
inputs
);
}
}
ValueRefList
CompiledTransformation
::
apply_create_tensor
(
const
CreateTensor
&
create_tensor
,
Span
<
ValueRef
>
inputs
)
{
if
(
create_tensor
.
kind
()
==
CreateTensor
::
NoTrace
)
{
return
imperative
::
apply
(
create_tensor
,
inputs
);
}
auto
&
item
=
next_instruction
();
trace_assert
(
item
.
op
==
nullptr
,
"operator mismatch"
);
auto
input_id
=
item
.
inputs
[
0
];
auto
output_id
=
item
.
outputs
[
0
];
auto
tensor
=
imperative
::
apply
(
create_tensor
,
inputs
)[
0
];
trace_input
(
input_id
,
tensor
);
return
{
trace_output
(
output_id
)};
}
ValueRefList
CompiledTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
op_value
=
op
.
as
<
ApplyOp
>
())
{
auto
&
item
=
next_instruction
();
SmallVector
<
ValueRef
>
unwrapped_inputs
;
SmallVector
<
ValueRef
>
wrapped_inputs
;
trace_assert
(
inputs
.
size
()
==
item
.
inputs
.
size
(),
"input size mismatch"
);
trace_assert
(
op_value
->
op
().
is_same
(
*
item
.
op
),
"operator mismatch"
);
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
trace_input
(
item
.
inputs
[
i
],
inputs
[
i
]);
}
std
::
vector
<
ValueRef
>
outputs
;
for
(
auto
&&
output_id
:
item
.
outputs
)
{
outputs
.
push_back
(
trace_output
(
output_id
));
}
return
outputs
;
return
apply_op
(
*
op_value
,
inputs
);
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
if
(
create_tensor
->
kind
()
==
CreateTensor
::
NoTrace
)
{
return
imperative
::
apply
(
op
,
inputs
);
}
auto
&
item
=
next_instruction
();
trace_assert
(
item
.
op
==
nullptr
,
"operator mismatch"
);
auto
input_id
=
item
.
inputs
[
0
];
auto
output_id
=
item
.
outputs
[
0
];
auto
tensor
=
imperative
::
apply
(
op
,
inputs
)[
0
];
trace_input
(
input_id
,
tensor
);
return
{
trace_output
(
output_id
)};
return
apply_create_tensor
(
*
create_tensor
,
inputs
);
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
if
(
auto
*
traced_value
=
inputs
[
0
].
as
<
TracedValue
>
())
{
ValueRef
output
;
auto
&
var
=
m_vars
[
traced_value
->
id
()];
auto
&
var_accessor
=
m_var_accessors
[
traced_value
->
id
()];
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
Shape
:
trace_assert
(
var_accessor
.
shape_getter
,
"shape unreadable"
);
output
=
ShapeValue
::
make
(
ValueShape
::
from
(
var_accessor
.
shape_getter
()));
break
;
case
GetAttr
::
Data
:
trace_assert
(
var_accessor
.
data_getter
,
"data unreadable"
);
output
=
DeviceValue
::
make
(
var_accessor
.
data_getter
());
break
;
case
GetAttr
::
Value
:
trace_assert
(
var_accessor
.
value_getter
,
"value unreadable"
);
output
=
HostValue
::
make
(
var_accessor
.
value_getter
());
break
;
case
GetAttr
::
DType
:
output
=
DTypeValue
::
make
(
var
.
dtype
);
break
;
case
GetAttr
::
Device
:
output
=
CompNodeValue
::
make
(
var
.
device
);
default:
break
;
}
return
{
output
};
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
return
apply_get_attr
(
*
get_attr
,
inputs
);
}
else
if
(
auto
*
trace_mark_var
=
op
.
as
<
TraceMarkVar
>
())
{
auto
&
item
=
next_instruction
();
trace_assert
(
item
.
op
==
nullptr
,
"operator mismatch"
);
...
...
imperative/src/impl/value.cpp
浏览文件 @
ca001777
...
...
@@ -8,50 +8,58 @@ namespace mgb {
namespace
imperative
{
namespace
{
static
thread_local
size_t
nr_watched_values
=
0
;
static
thread_local
uint64_t
nr_values
=
0
;
static
thread_local
bool
recording_values
=
false
;
static
thread_local
std
::
vector
<
ValueWeakRef
>
recorded_values
;
static
/*thread_local*/
size_t
nr_watched_values
=
0
;
static
/*thread_local*/
uint64_t
nr_values
=
0
;
static
/*thread_local*/
bool
recording_values
=
false
;
static
/*thread_local*/
std
::
vector
<
ValueWeakRef
>
recorded_values
;
static
WeakValueMap
<
uint64_t
,
ValueWeakRef
>
registered_values
;
}
// namespace
ValueRef
::
storage_t
&
ValueRef
::
storage
()
const
{
if
(
!
m_storage
)
{
if
(
mgb_likely
(
!
m_storage
->
m_successor
.
m_storage
)
)
{
return
m_storage
;
}
if
(
auto
&
storage
=
m_storage
->
m_successor
.
m_storage
)
{
while
(
storage
->
m_successor
.
m_storage
)
{
storage
=
storage
->
m_successor
.
m_storage
;
}
return
storage
;
}
else
{
return
m_storage
;
while
(
m_storage
->
m_successor
.
m_storage
)
{
m_storage
=
m_storage
->
m_successor
.
m_storage
;
}
return
m_storage
;
}
const
Value
*
ValueRef
::
as
(
size_t
typecode
)
const
{
auto
&&
storage
=
this
->
storage
();
if
(
storage
->
m_typecode
!=
typecode
)
{
return
nullptr
;
}
return
static_cast
<
Value
*>
(
storage
.
get
());
}
bool
ValueRef
::
is
(
size_t
typecode
)
const
{
return
this
->
storage
()
->
m_typecode
==
typecode
;
}
TypedValueRef
<
DeviceValue
>
ValueRef
::
dev_tensor
()
const
{
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
Data
),
*
this
)[
0
].
as
_ref
<
DeviceValue
>
();
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
Data
),
*
this
)[
0
].
cast
_ref
<
DeviceValue
>
();
}
TypedValueRef
<
HostValue
>
ValueRef
::
numpy
()
const
{
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
Value
),
*
this
)[
0
].
as
_ref
<
HostValue
>
();
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
Value
),
*
this
)[
0
].
cast
_ref
<
HostValue
>
();
}
TypedValueRef
<
CompNodeValue
>
ValueRef
::
device
()
const
{
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
Device
),
*
this
)[
0
]
.
as
_ref
<
CompNodeValue
>
();
.
cast
_ref
<
CompNodeValue
>
();
}
TypedValueRef
<
ShapeValue
>
ValueRef
::
shape
()
const
{
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
Shape
),
*
this
)[
0
].
as
_ref
<
ShapeValue
>
();
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
Shape
),
*
this
)[
0
].
cast
_ref
<
ShapeValue
>
();
}
TypedValueRef
<
DTypeValue
>
ValueRef
::
dtype
()
const
{
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
DType
),
*
this
)[
0
].
as
_ref
<
DTypeValue
>
();
return
imperative
::
apply
(
GetAttr
(
GetAttr
::
DType
),
*
this
)[
0
].
cast
_ref
<
DTypeValue
>
();
}
TypedValueRef
<
StringValue
>
ValueRef
::
name
()
const
{
return
imperative
::
apply
(
GetName
(),
*
this
)[
0
].
as
_ref
<
StringValue
>
();
return
imperative
::
apply
(
GetName
(),
*
this
)[
0
].
cast
_ref
<
StringValue
>
();
}
bool
ValueRef
::
is_scalar
()
const
{
...
...
@@ -75,13 +83,15 @@ void ValueRef::unwatch() const {
}
ValueRef
ValueRef
::
unwrap
()
const
{
ValueRef
value
=
*
this
;
auto
&
context
=
Transformation
::
get_context
();
for
(
size_t
i
=
0
;
i
<
context
.
next_transformation
;
++
i
)
{
value
=
context
.
transformations
[
i
]
->
unwrap
(
value
);
if
(
mgb_unlikely
(
context
.
next_transformation
))
{
ValueRef
value
=
*
this
;
for
(
size_t
i
=
0
;
i
<
context
.
next_transformation
;
++
i
)
{
value
=
context
.
transformations
[
i
]
->
unwrap
(
value
);
}
return
value
;
}
mgb_assert
(
value
);
return
value
;
return
*
this
;
}
std
::
string
ValueRef
::
to_string
()
const
{
...
...
@@ -101,13 +111,11 @@ std::string ValueRef::raw_type() const {
return
types
[
m_storage
->
m_typecode
].
name
();
}
uint64_t
ValueRef
::
id
()
const
{
return
m_storage
?
m_storage
->
m_id
:
std
::
numeric_limits
<
uint64_t
>::
max
();
}
bool
ValueRef
::
watching
()
const
{
auto
storage
=
this
->
storage
();
return
storage
&&
storage
->
m_watching
;
if
(
!
m_storage
)
{
return
false
;
}
return
this
->
storage
()
->
m_watching
;
}
ValueRef
ValueRef
::
make
(
ValueRef
::
storage_t
storage
)
{
...
...
@@ -186,5 +194,96 @@ void Value::try_rethrow() {
}
}
inline
void
ValueRefList
::
init
(
size_t
nr_elems
)
{
m_size
=
nr_elems
;
if
(
m_size
>
0
)
{
if
(
m_size
==
1
)
{
m_data
=
inline_storage
();
}
else
{
auto
&
context
=
Transformation
::
get_context
();
m_data
=
context
.
allocator
.
allocate
(
m_size
);
}
for
(
size_t
i
=
0
;
i
<
m_size
;
++
i
)
{
new
(
m_data
+
i
)
ValueRef
();
}
}
else
{
m_data
=
nullptr
;
}
}
ValueRefList
::
ValueRefList
(
size_t
nr_elems
)
{
init
(
nr_elems
);
}
ValueRefList
::
ValueRefList
(
std
::
initializer_list
<
ValueRef
>
values
)
:
ValueRefList
(
values
.
begin
(),
values
.
end
())
{}
ValueRefList
::
ValueRefList
(
const
ValueRefList
&
rhs
)
:
ValueRefList
(
rhs
.
cbegin
(),
rhs
.
cend
())
{}
ValueRefList
::
ValueRefList
(
ValueRefList
&&
rhs
)
:
ValueRefList
()
{
m_size
=
rhs
.
m_size
;
if
(
rhs
.
m_data
==
rhs
.
inline_storage
())
{
m_data
=
inline_storage
();
new
(
m_data
)
ValueRef
();
m_data
[
0
]
=
std
::
move
(
rhs
.
m_data
[
0
]);
}
else
{
m_data
=
rhs
.
m_data
;
rhs
.
m_data
=
nullptr
;
rhs
.
m_size
=
0
;
}
}
ValueRefList
&
ValueRefList
::
operator
=
(
const
ValueRefList
&
rhs
)
{
if
(
this
==
&
rhs
)
{
return
*
this
;
}
clear
();
init
(
rhs
.
m_size
);
for
(
size_t
i
=
0
;
i
<
m_size
;
++
i
)
{
m_data
[
i
]
=
rhs
.
m_data
[
i
];
}
return
*
this
;
}
ValueRefList
&
ValueRefList
::
operator
=
(
ValueRefList
&&
rhs
)
{
if
(
this
==
&
rhs
)
{
return
*
this
;
}
clear
();
if
(
rhs
.
m_data
==
rhs
.
inline_storage
())
{
m_data
=
inline_storage
();
new
(
m_data
)
ValueRef
();
m_data
[
0
]
=
rhs
.
m_data
[
0
];
m_size
=
1
;
rhs
.
clear
();
}
else
{
m_data
=
rhs
.
m_data
;
m_size
=
rhs
.
m_size
;
rhs
.
m_data
=
nullptr
;
rhs
.
m_size
=
0
;
}
return
*
this
;
}
ValueRefList
::~
ValueRefList
()
{
clear
();
}
void
ValueRefList
::
clear
()
{
for
(
size_t
i
=
0
;
i
<
m_size
;
++
i
)
{
m_data
[
i
].
~
ValueRef
();
}
if
(
m_data
)
{
if
(
m_size
!=
1
)
{
Transformation
::
get_context
().
allocator
.
deallocate
(
m_data
,
m_size
);
}
else
{
mgb_assert
(
m_data
==
inline_storage
());
}
}
m_data
=
nullptr
;
m_size
=
0
;
}
}
// namespace imperative
}
// namespace mgb
imperative/src/include/megbrain/imperative/basic_operators.h
浏览文件 @
ca001777
...
...
@@ -24,8 +24,6 @@ namespace imperative {
class
GradKey
;
using
GenericFunction
=
std
::
function
<
std
::
vector
<
ValueRef
>
(
Span
<
ValueRef
>
)
>
;
/**
* \brief apply an OpDef to values
*
...
...
@@ -37,7 +35,7 @@ private:
public:
ApplyOp
(
const
OpDef
&
op
)
:
m_op
(
op
)
{}
const
OpDef
&
op
()
{
return
m_op
;
}
const
OpDef
&
op
()
const
{
return
m_op
;
}
std
::
string
to_string
()
const
override
;
};
...
...
@@ -106,7 +104,7 @@ public:
* \param inputs contains host_storage and device_storage
* \return Args unpacked args
*/
Args
parse
(
Span
<
ValueRef
>
inputs
);
Args
parse
(
Span
<
ValueRef
>
inputs
)
const
;
Kind
kind
()
const
{
return
m_kind
;
}
CompNode
device
()
const
{
return
m_device
;
}
...
...
@@ -129,11 +127,11 @@ private:
public:
DTRCommand
(
Kind
kind
)
:
m_kind
(
kind
)
{}
Kind
kind
()
{
return
m_kind
;
}
Kind
kind
()
const
{
return
m_kind
;
}
std
::
string
to_string
()
const
override
;
std
::
vector
<
ValueRef
>
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
return
{};
}
ValueRefList
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
return
{};
}
};
// deprecated
...
...
@@ -141,9 +139,7 @@ class GetName final : public OperatorImpl<GetName, Operator::GetAttrLike> {
public:
std
::
string
to_string
()
const
override
;
std
::
vector
<
ValueRef
>
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
return
{
ValueRef
()};
}
ValueRefList
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
return
{
ValueRef
()};
}
};
/**
...
...
@@ -161,7 +157,7 @@ public:
std
::
string
to_string
()
const
override
;
std
::
vector
<
ValueRef
>
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
ValueRefList
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
return
{
inputs
.
as_array
<
1
>
()[
0
]};
}
};
...
...
imperative/src/include/megbrain/imperative/basic_values.h
浏览文件 @
ca001777
...
...
@@ -23,7 +23,7 @@ namespace imperative {
class
GradKey
;
using
GenericFunction
=
std
::
function
<
std
::
vector
<
ValueRef
>
(
Span
<
ValueRef
>
)
>
;
using
GenericFunction
=
std
::
function
<
ValueRefList
(
Span
<
ValueRef
>
)
>
;
class
ShapeValue
final
:
public
MixinValueImpl
<
ShapeValue
,
ValueShape
>
{
public:
...
...
@@ -97,6 +97,10 @@ public:
ValueShape
shape
()
const
{
return
m_shape
;
}
CompNode
device
()
const
{
return
m_storage
.
comp_node
();
}
HostTensorStorage
storage
()
const
{
return
m_storage
;
}
DTypeScalar
item
()
const
{
mgb_assert
(
m_shape
.
is_scalar
());
return
DTypeScalar
::
make_from_raw
(
m_dtype
,
m_storage
.
ptr
());
}
HostTensorND
as_nd
(
bool
allow_scalar
=
false
)
const
;
};
...
...
imperative/src/include/megbrain/imperative/dispatch.h
浏览文件 @
ca001777
...
...
@@ -36,11 +36,11 @@ namespace imperative {
*
* \param op
* \param inputs
* \return
std::vector<ValueRef>
* \return
ValueRefList
*/
std
::
vector
<
ValueRef
>
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
);
std
::
vector
<
ValueRef
>
apply
(
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
);
std
::
vector
<
ValueRef
>
apply
(
Subgraph
graph
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply
(
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply
(
const
Subgraph
&
graph
,
Span
<
ValueRef
>
inputs
);
template
<
typename
...
TArgs
>
constexpr
bool
is_all_value_ref_v
=
...
...
@@ -49,7 +49,7 @@ constexpr bool is_all_value_ref_v =
template
<
typename
T
,
typename
...
TArgs
>
static
auto
apply
(
T
&&
op
,
TArgs
&&
...
args
)
->
std
::
enable_if_t
<
is_all_value_ref_v
<
TArgs
...
>
,
std
::
vector
<
ValueRef
>
>
{
->
std
::
enable_if_t
<
is_all_value_ref_v
<
TArgs
...
>
,
ValueRefList
>
{
ValueRef
args_arr
[
sizeof
...(
TArgs
)]
=
{
std
::
forward
<
TArgs
&&>
(
args
)...};
return
imperative
::
apply
(
std
::
forward
<
T
&&>
(
op
),
...
...
@@ -63,7 +63,7 @@ static auto apply(T&& op, TContainer&& container) -> std::enable_if_t<
ValueRef
>
&&
std
::
is_same_v
<
decltype
(
container
.
size
()),
size_t
>
&&
!
std
::
is_same_v
<
std
::
decay_t
<
TContainer
>
,
Span
<
ValueRef
>>
,
std
::
vector
<
ValueRef
>
>
{
ValueRefList
>
{
return
imperative
::
apply
(
std
::
forward
<
T
&&>
(
op
),
Span
<
ValueRef
>
(
container
.
data
(),
container
.
size
()));
}
...
...
imperative/src/include/megbrain/imperative/operator.h
浏览文件 @
ca001777
...
...
@@ -25,6 +25,8 @@
namespace
mgb
{
namespace
imperative
{
using
GenericFunction
=
std
::
function
<
ValueRefList
(
Span
<
ValueRef
>
)
>
;
/**
* \brief base class for all operators
*
...
...
@@ -49,25 +51,24 @@ public:
Kind
kind
()
const
{
return
m_kind
;
}
template
<
typename
U
>
U
*
as
()
const
{
const
U
*
as
()
const
{
if
(
m_typecode
!=
U
::
TYPE_CODE
)
{
return
nullptr
;
}
return
static_cast
<
U
*>
(
const_cast
<
Operator
*>
(
this
)
);
return
static_cast
<
const
U
*>
(
this
);
}
template
<
typename
U
>
bool
is
()
const
{
return
as
<
U
>
()
!=
nullptr
;
return
m_typecode
==
U
::
TYPE_CODE
;
}
template
<
Kind
kKind
>
bool
is
()
const
{
return
kind
()
==
kKind
;
}
template
<
typename
U
>
U
&
cast
()
const
{
U
*
ptr
=
as
<
U
>
();
mgb_assert
(
ptr
);
return
*
ptr
;
const
U
&
cast
()
const
{
mgb_assert
(
m_typecode
==
U
::
TYPE_CODE
);
return
static_cast
<
const
U
&>
(
*
this
);
}
virtual
std
::
string
to_string
()
const
=
0
;
...
...
@@ -77,9 +78,9 @@ public:
* implementation.
*
* \param inputs
* \return
std::vector<ValueRef>
* \return
ValueRefList
*/
virtual
std
::
vector
<
ValueRef
>
fallback
(
Span
<
ValueRef
>
inputs
)
const
;
virtual
ValueRefList
fallback
(
Span
<
ValueRef
>
inputs
)
const
;
std
::
type_index
type
()
const
{
return
registered_types
()[
m_typecode
];
}
...
...
imperative/src/include/megbrain/imperative/profiler.h
浏览文件 @
ca001777
...
...
@@ -123,7 +123,6 @@ public:
template
<
typename
T
,
typename
...
TArgs
>
static
uint64_t
record
(
TArgs
&&
...
args
)
{
auto
&
profiler
=
get_instance
();
// auto& mem_pool = get_mem_pool<T>();
if
constexpr
(
sm_debug
)
{
Status
expected
=
Running
;
mgb_assert
(
profiler
.
m_status
.
compare_exchange_strong
(
expected
,
Recording
));
...
...
imperative/src/include/megbrain/imperative/transformation.h
浏览文件 @
ca001777
...
...
@@ -18,6 +18,7 @@
#include "megbrain/common.h"
#include "megbrain/imperative/subgraph.h"
#include "megbrain/imperative/utils/allocator.h"
#include "megbrain/imperative/utils/local_ptr.h"
#include "megbrain/imperative/utils/span.h"
...
...
@@ -25,6 +26,7 @@ namespace mgb {
namespace
imperative
{
class
ValueRef
;
class
ValueRefList
;
class
Operator
;
class
Transformation
;
...
...
@@ -43,6 +45,7 @@ struct TransformationContext {
// TODO: deprecate TransformationGuard, let next_transformation == frames.size()
size_t
next_transformation
=
0
;
std
::
vector
<
TransformationFrame
>
frames
;
ForwardAllocator
<
ValueRef
>
allocator
;
};
/**
...
...
@@ -86,9 +89,9 @@ public:
*
* \param op
* \param inputs
* \return
std::vector<ValueRef>
* \return
ValueRefList
*/
virtual
std
::
vector
<
ValueRef
>
apply_transformation
(
virtual
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
=
0
;
virtual
ValueRef
unwrap
(
ValueRef
value
)
=
0
;
...
...
@@ -187,11 +190,12 @@ public:
std
::
swap
(
context
.
transformations
,
current_context
.
transformations
);
std
::
swap
(
context
.
scopes
,
current_context
.
scopes
);
std
::
swap
(
context
.
next_transformation
,
current_context
.
next_transformation
);
std
::
swap
(
context
.
allocator
,
current_context
.
allocator
);
}
static
TransformationContext
&
get_context
();
friend
std
::
vector
<
ValueRef
>
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
);
friend
ValueRefList
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
);
friend
class
ValueRef
;
};
...
...
imperative/src/include/megbrain/imperative/transformations/eval.h
浏览文件 @
ca001777
...
...
@@ -23,16 +23,38 @@ public:
using
Handle
=
interpreter
::
Interpreter
::
Handle
;
using
Channel
=
interpreter
::
Interpreter
::
Channel
;
class
RAIIHandle
:
public
NonCopyableObj
{
private:
Handle
m_handle
=
nullptr
;
Channel
*
m_channel
=
nullptr
;
public:
RAIIHandle
(
Handle
handle
,
Channel
*
channel
)
:
m_handle
(
handle
),
m_channel
(
channel
)
{}
~
RAIIHandle
()
{
m_channel
->
del
(
m_handle
);
}
Handle
handle
()
const
{
return
m_handle
;
}
Channel
*
channel
()
const
{
return
m_channel
;
}
};
private:
std
::
shared_ptr
<
Handle
>
m_handle
=
nullptr
;
LocalPtr
<
RAIIHandle
>
m_handle
;
std
::
string
m_name
;
mutable
DTypeValue
::
ref_t
m_dtype
;
mutable
CompNodeValue
::
ref_t
m_comp_node
;
mutable
ShapeValue
::
ref_t
m_shape
;
public:
InterpreterInfo
()
=
default
;
InterpreterInfo
(
std
::
shared_ptr
<
Handle
>
handle
,
std
::
string
name
=
{})
InterpreterInfo
(
LocalPtr
<
RAII
Handle
>
handle
,
std
::
string
name
=
{})
:
m_handle
(
handle
),
m_name
(
name
)
{}
std
::
shared_ptr
<
Handle
>
handle
()
const
{
return
m_handle
;
}
const
LocalPtr
<
RAIIHandle
>&
handle
()
const
{
return
m_handle
;
}
DTypeValue
::
ref_t
dtype
()
const
;
CompNodeValue
::
ref_t
comp_node
()
const
;
ShapeValue
::
ref_t
shape
()
const
;
std
::
string
name
()
const
{
return
m_name
;
}
};
...
...
@@ -60,6 +82,7 @@ class InterpreterTransformation final : public Transformation {
public:
using
Interpreter
=
interpreter
::
Interpreter
;
using
Handle
=
Interpreter
::
Handle
;
using
SharedHandle
=
LocalPtr
<
InterpreterInfo
::
RAIIHandle
>
;
using
Channel
=
Interpreter
::
Channel
;
private:
...
...
@@ -71,7 +94,14 @@ public:
Channel
*
channel
()
{
return
m_channel
.
get
();
}
std
::
vector
<
ValueRef
>
apply_transformation
(
ValueRefList
apply_op
(
const
ApplyOp
&
apply_op
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply_create_tensor
(
const
CreateTensor
&
create_tensor
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
...
...
@@ -81,14 +111,8 @@ public:
std
::
string
name
()
const
override
{
return
"InterpreterTransformation"
;
}
std
::
shared_ptr
<
Handle
>
share_handle
(
Handle
handle
)
{
return
std
::
shared_ptr
<
Handle
>
(
new
Handle
(
handle
),
[
channel
=
m_channel
.
get
()](
Handle
*
ptr
)
{
if
(
ptr
)
{
channel
->
del
(
*
ptr
);
delete
ptr
;
}
});
SharedHandle
share_handle
(
Handle
handle
)
{
return
SharedHandle
::
make
(
handle
,
m_channel
.
get
());
}
};
...
...
imperative/src/include/megbrain/imperative/transformations/grad.h
浏览文件 @
ca001777
...
...
@@ -34,9 +34,7 @@ struct BackwardGraphWithClosure {
std
::
shared_ptr
<
OptimizedBackwardGraphResult
>
backward_graph
,
std
::
shared_ptr
<
OpDef
>
op
,
Span
<
ValueRef
>
inputs
,
Span
<
ValueRef
>
outputs
);
void
operator
()(
std
::
vector
<
ValueRef
>
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
);
void
operator
()(
ValueRefList
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
);
bool
input_has_grad
(
size_t
i
)
{
return
backward_graph
->
input_has_grad
[
i
];
}
...
...
@@ -50,12 +48,11 @@ struct BackwardGraphWithClosure {
struct
CustomBackward
;
using
GradRuleFn
=
std
::
function
<
std
::
vector
<
ValueRef
>
(
Span
<
ValueRef
>
inputs
,
CustomBackward
&
)
>
;
using
GradRuleFn
=
std
::
function
<
ValueRefList
(
Span
<
ValueRef
>
inputs
,
CustomBackward
&
)
>
;
struct
CustomBackward
{
using
BackwardFn
=
std
::
function
<
std
::
vector
<
ValueRef
>
(
Span
<
ValueRef
>
)
>
;
using
BackwardRule
=
std
::
function
<
std
::
optional
<
std
::
vector
<
ValueRef
>
>
(
using
BackwardFn
=
std
::
function
<
ValueRefList
(
Span
<
ValueRef
>
)
>
;
using
BackwardRule
=
std
::
function
<
std
::
optional
<
ValueRefList
>
(
const
OpDef
&
,
Span
<
ValueRef
>
,
Span
<
bool
>
,
CustomBackward
&
)
>
;
BackwardFn
m_backward
;
SmallVector
<
bool
,
8
>
m_input_has_grad
;
...
...
@@ -65,9 +62,7 @@ struct CustomBackward {
SmallVector
<
OutputAttr
>
m_output_attrs
;
public:
void
operator
()(
std
::
vector
<
ValueRef
>
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
);
void
operator
()(
ValueRefList
grads
,
std
::
function
<
void
(
size_t
,
ValueRef
)
>
receiver
);
bool
input_has_grad
(
size_t
i
)
{
return
m_input_has_grad
[
i
];
}
bool
output_requires_grad
(
size_t
i
)
{
return
m_output_attrs
[
i
].
requires_grad
;
}
...
...
@@ -188,7 +183,7 @@ public:
std
::
string
to_string
()
const
override
;
bool
has_key
(
std
::
shared_ptr
<
GradKey
>
key
)
const
{
return
m_key
==
key
;
}
bool
has_key
(
const
std
::
shared_ptr
<
GradKey
>&
key
)
const
{
return
m_key
==
key
;
}
const
GradSlotPtr
&
slot_for
(
std
::
shared_ptr
<
GradKey
>
key
)
const
{
mgb_assert
(
m_key
==
key
);
...
...
@@ -287,7 +282,7 @@ public:
return
false
;
}
std
::
vector
<
ValueRef
>
apply_transformation
(
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
...
...
@@ -314,7 +309,7 @@ private:
public:
std
::
string
to_string
()
const
override
{
return
"DetachValue"
;
}
std
::
vector
<
ValueRef
>
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
ValueRefList
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
return
{
inputs
.
as_array
<
1
>
()[
0
]};
}
};
...
...
@@ -325,7 +320,7 @@ private:
public:
AttachGrad
(
std
::
shared_ptr
<
GradKey
>
key
)
:
m_key
(
key
)
{}
std
::
shared_ptr
<
GradKey
>
key
()
{
return
m_key
;
}
std
::
shared_ptr
<
GradKey
>
key
()
const
{
return
m_key
;
}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"AttachGradValue{key=%s}"
,
m_key
->
name
().
c_str
());
...
...
@@ -339,7 +334,7 @@ private:
public:
GradBackward
(
std
::
shared_ptr
<
GradKey
>
key
)
:
m_key
(
key
)
{}
std
::
shared_ptr
<
GradKey
>
key
()
{
return
m_key
;
}
std
::
shared_ptr
<
GradKey
>
key
()
const
{
return
m_key
;
}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"GradBackwardValue{key=%s}"
,
m_key
->
name
().
c_str
());
...
...
@@ -352,13 +347,13 @@ private:
public:
IsAttachedTo
(
std
::
shared_ptr
<
GradKey
>
key
)
:
m_key
(
key
)
{}
std
::
shared_ptr
<
GradKey
>
key
()
{
return
m_key
;
}
std
::
shared_ptr
<
GradKey
>
key
()
const
{
return
m_key
;
}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"IsAttachedToValue{key=%s}"
,
m_key
->
name
().
c_str
());
}
std
::
vector
<
ValueRef
>
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
ValueRefList
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
return
{
BoolValue
::
make
(
false
)};
}
};
...
...
@@ -373,9 +368,9 @@ public:
SetGrad
(
std
::
shared_ptr
<
GradKey
>
key
,
GenericFunction
grad_fn
,
size_t
nr_inputs
)
:
m_key
(
key
),
m_grad_fn
(
grad_fn
),
m_nr_inputs
(
nr_inputs
)
{}
GenericFunction
grad_fn
()
{
return
m_grad_fn
;
}
GenericFunction
grad_fn
()
const
{
return
m_grad_fn
;
}
size_t
nr_inputs
()
{
return
m_nr_inputs
;
}
size_t
nr_inputs
()
const
{
return
m_nr_inputs
;
}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"SetGradValue{key=%s}"
,
m_key
->
name
().
c_str
());
...
...
@@ -388,9 +383,7 @@ public:
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"GetGradKeyValue{}"
);
}
std
::
vector
<
ValueRef
>
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
return
{
ValueRef
()};
}
ValueRefList
fallback
(
Span
<
ValueRef
>
inputs
)
const
override
{
return
{
ValueRef
()};
}
};
class
GetBackwardColsure
...
...
@@ -401,7 +394,7 @@ private:
public:
GetBackwardColsure
(
std
::
shared_ptr
<
GradKey
>
key
)
:
m_key
(
key
)
{}
std
::
shared_ptr
<
GradKey
>
key
()
{
return
m_key
;
}
std
::
shared_ptr
<
GradKey
>
key
()
const
{
return
m_key
;
}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"GetBackwardClosure{key=%s}"
,
m_key
->
name
().
c_str
());
...
...
imperative/src/include/megbrain/imperative/transformations/lazy.h
浏览文件 @
ca001777
...
...
@@ -81,7 +81,7 @@ public:
ComputingGraph
::
Options
&
options
()
{
return
m_graph
->
options
();
}
std
::
vector
<
ValueRef
>
apply_transformation
(
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
...
...
imperative/src/include/megbrain/imperative/transformations/scalar.h
浏览文件 @
ca001777
...
...
@@ -11,6 +11,7 @@
#pragma once
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
...
...
@@ -45,8 +46,10 @@ public:
*/
class
ScalarTransformation
final
:
public
Transformation
{
private:
ShapeValue
::
ref_t
m_empty_shape
;
// []
public:
std
::
vector
<
ValueRef
>
apply_transformation
(
ValueRefList
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
...
...
imperative/src/include/megbrain/imperative/transformations/symbol.h
浏览文件 @
ca001777
...
...
@@ -50,7 +50,7 @@ private:
public:
SymbolTransformation
(
ComputingGraph
*
graph
)
:
m_graph
(
graph
)
{}
std
::
vector
<
ValueRef
>
apply_transformation
(
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
{
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
SmallVector
<
VarNode
*>
input_nodes
;
...
...
@@ -58,9 +58,9 @@ public:
input_nodes
.
push_back
(
input
.
cast
<
SymbolValue
>
().
node
());
}
auto
output_nodes
=
OpDef
::
apply_on_var_node
(
apply_op
->
op
(),
input_nodes
);
std
::
vector
<
ValueRef
>
outputs
;
for
(
auto
&&
output_node
:
output_nodes
)
{
outputs
.
push_back
(
SymbolValue
::
make
(
output_node
)
);
ValueRefList
outputs
(
output_nodes
.
size
())
;
for
(
size_t
i
=
0
;
i
<
output_nodes
.
size
();
++
i
)
{
outputs
[
i
]
=
SymbolValue
::
make
(
output_nodes
[
i
]
);
}
return
outputs
;
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
...
...
imperative/src/include/megbrain/imperative/transformations/tangent.h
0 → 100644
浏览文件 @
ca001777
/**
* \file imperative/src/include/megbrain/imperative/grad.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/operator.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/value.h"
namespace
mgb
::
imperative
{
struct
TangentInfo
{
ValueRef
value
;
ValueRef
tangent
;
};
class
TangentTransformation
final
:
public
Transformation
{
public:
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
false
);
}
std
::
string
name
()
const
override
{
return
"Tangent"
;
}
};
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/transformations/trace.h
浏览文件 @
ca001777
...
...
@@ -126,25 +126,6 @@ public:
void
on_unwatch
()
override
{
value
().
unwatch
();
}
};
class
TracedInfo
{
private:
size_t
m_id
=
0
;
public:
TracedInfo
()
=
default
;
TracedInfo
(
size_t
id
)
:
m_id
(
id
)
{}
size_t
id
()
const
{
return
m_id
;
}
};
class
TracedValue
final
:
public
MixinValueImpl
<
TracedValue
,
TracedInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"TracedValue{
\"
id
\"
=%zu}"
,
id
());
}
};
/**
* \brief trace operation sequence to TraceResult
*
...
...
@@ -202,7 +183,7 @@ public:
return
value
;
}
std
::
vector
<
ValueRef
>
apply_transformation
(
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
...
...
@@ -248,6 +229,40 @@ public:
std
::
function
<
DeviceTensorND
()
>
data_getter
;
std
::
function
<
HostTensorND
()
>
value_getter
;
std
::
function
<
void
(
DeviceTensorND
)
>
data_setter
;
std
::
function
<
void
(
std
::
exception_ptr
)
>
exc_setter
;
};
class
TracedInfo
{
private:
size_t
m_id
=
0
;
VarInfo
*
m_var
=
nullptr
;
VarAccessor
*
m_accessor
=
nullptr
;
mutable
ShapeValue
::
ref_t
m_shape
;
mutable
DTypeValue
::
ref_t
m_dtype
;
mutable
CompNodeValue
::
ref_t
m_comp_node
;
public:
TracedInfo
()
=
default
;
TracedInfo
(
size_t
id
,
VarInfo
*
var
,
VarAccessor
*
accessor
)
:
m_id
(
id
),
m_var
(
var
),
m_accessor
(
accessor
)
{}
size_t
id
()
const
{
return
m_id
;
}
ShapeValue
::
ref_t
shape
()
const
;
DTypeValue
::
ref_t
dtype
()
const
;
CompNodeValue
::
ref_t
comp_node
()
const
;
const
VarAccessor
&
accessor
()
const
;
void
set_exception
(
std
::
exception_ptr
exc
)
const
{
m_accessor
->
exc_setter
(
exc
);
}
};
class
TracedValue
final
:
public
MixinValueImpl
<
TracedValue
,
TracedInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"TracedValue{
\"
id
\"
=%zu}"
,
id
());
}
};
private:
...
...
@@ -319,7 +334,14 @@ public:
TraceResult
::
SeqItem
&
next_instruction
();
std
::
vector
<
ValueRef
>
apply_transformation
(
ValueRefList
apply_op
(
const
ApplyOp
&
apply_op
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply_get_attr
(
const
GetAttr
&
get_attr
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply_create_tensor
(
const
CreateTensor
&
create_tensor
,
Span
<
ValueRef
>
inputs
);
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
void
on_unregister
()
noexcept
override
;
...
...
imperative/src/include/megbrain/imperative/utils/allocator.h
浏览文件 @
ca001777
...
...
@@ -36,12 +36,12 @@ private:
public:
Allocator
(
pool_type
*
pool
)
:
m_pool
(
pool
)
{}
T
*
allocate
(
size_type
n
)
{
pointer
allocate
(
size_type
n
)
{
mgb_assert
(
n
==
1
);
return
m_pool
->
alloc
(
sizeof
(
T
));
}
void
deallocate
(
pointer
*
p
,
size_type
n
)
{
void
deallocate
(
pointer
p
,
size_type
n
)
{
mgb_assert
(
n
==
1
);
m_pool
->
free
(
p
);
}
...
...
@@ -68,4 +68,114 @@ public:
bool
operator
!=
(
const
ThreadLocalAllocatorAdapter
&
rhs
)
const
{
return
false
;
}
};
}
// namespace mgb::imperative
\ No newline at end of file
template
<
typename
T
>
class
ForwardAllocator
{
public:
using
value_type
=
T
;
using
size_type
=
std
::
size_t
;
using
pointer
=
T
*
;
static
constexpr
size_t
alignment
=
alignof
(
T
);
static
constexpr
size_t
element_offset
=
sizeof
(
T
)
+
((
sizeof
(
T
)
%
alignment
)
?
0
:
(
alignment
-
sizeof
(
T
)
%
alignment
));
private:
struct
Block
{
std
::
unique_ptr
<
std
::
byte
[]
>
data
;
size_t
size
=
0
;
size_t
capacity
=
0
;
T
*
allocate
(
size_type
n
)
{
static_assert
(
element_offset
>
std
::
max
(
alignment
,
sizeof
(
T
)));
size_t
begin
=
size
;
size_t
end
=
begin
+
element_offset
*
n
;
if
(
end
>
capacity
)
{
return
nullptr
;
}
size
=
end
;
return
reinterpret_cast
<
T
*>
(
data
.
get
()
+
begin
);
}
void
reset
()
{
size
=
0
;
}
};
std
::
vector
<
Block
>
m_used
;
std
::
optional
<
Block
>
m_current
;
size_t
block_size
=
16
*
1024
*
1024
;
size_t
nr_allocated
=
0
;
private:
Block
allocate_block
()
{
block_size
*=
2
;
return
Block
{
std
::
make_unique
<
std
::
byte
[]
>
(
block_size
),
0
,
block_size
};
}
public:
pointer
allocate
(
size_type
n
)
{
if
(
!
m_current
)
{
m_current
.
emplace
(
allocate_block
());
}
pointer
pointer
=
m_current
->
allocate
(
n
);
while
(
pointer
==
nullptr
)
{
m_used
.
push_back
(
allocate_block
());
std
::
swap
(
m_used
.
back
(),
*
m_current
);
pointer
=
m_current
->
allocate
(
n
);
}
nr_allocated
++
;
return
pointer
;
}
void
deallocate
(
pointer
p
,
size_type
n
)
{
mgb_assert
(
nr_allocated
>
0
);
nr_allocated
--
;
}
void
clear
()
{
if
(
mgb_likely
(
m_used
.
empty
()))
{
// fastpath
if
(
m_current
)
{
m_current
->
reset
();
}
}
else
{
// trim
*
m_current
=
allocate_block
();
m_used
.
clear
();
}
mgb_assert
(
nr_allocated
==
0
);
}
bool
operator
==
(
const
ForwardAllocator
&
rhs
)
const
{
return
&
rhs
==
this
;
}
bool
operator
!=
(
const
ForwardAllocator
&
rhs
)
const
{
return
&
rhs
!=
this
;
}
};
template
<
typename
T
,
template
<
typename
>
typename
TAllocator
>
class
ProxyAllocator
{
public:
using
value_type
=
T
;
using
size_type
=
typename
TAllocator
<
T
>::
size_type
;
using
pointer
=
typename
TAllocator
<
T
>::
pointer
;
private:
TAllocator
<
T
>*
m_impl
;
public:
T
*
allocate
(
size_type
n
)
{
return
m_impl
->
allocate
(
n
);
}
void
deallocate
(
pointer
*
p
,
size_type
n
)
{
return
m_impl
->
deallocate
(
p
,
n
);
}
bool
operator
==
(
const
ProxyAllocator
<
T
,
TAllocator
>&
rhs
)
const
{
if
(
m_impl
==
rhs
.
m_impl
)
{
return
true
;
}
else
if
(
bool
(
m_impl
)
^
bool
(
rhs
.
m_impl
))
{
return
false
;
}
else
{
return
*
m_impl
==
*
rhs
.
m_impl
;
}
}
bool
operator
!=
(
const
ProxyAllocator
<
T
,
TAllocator
>&
rhs
)
const
{
return
!
((
*
this
)
==
rhs
);
}
};
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/utils/local_ptr.h
浏览文件 @
ca001777
...
...
@@ -16,6 +16,8 @@
#include "megbrain/imperative/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
#define MGB_FAT_LOCAL_PTR 0
namespace
mgb
::
imperative
{
template
<
typename
T
>
...
...
@@ -52,6 +54,8 @@ private:
}
}
size_t
ref_count
()
const
{
return
m_ref_count
;
}
template
<
typename
U
>
friend
class
LocalPtr
;
...
...
@@ -88,14 +92,24 @@ public:
using
storage_t
=
LocalPtrStorage
<
T
>
;
using
pool_t
=
MemPool
<
storage_t
>
;
using
weak_type
=
LocalWeakPtr
<
T
>
;
using
pointer_t
=
T
*
;
private:
storage_t
*
m_storage
=
nullptr
;
#if MGB_FAT_LOCAL_PTR
pointer_t
m_pointer
=
nullptr
;
#endif
// (m_storage == nullptr) == (m_pointer == nullptr)
void
emplace
(
storage_t
*
ptr
)
{
if
(
ptr
)
{
ptr
->
inc_ref
();
m_storage
=
ptr
;
#if MGB_FAT_LOCAL_PTR
m_pointer
=
ptr
->
m_pointer
;
#endif
}
}
...
...
@@ -103,8 +117,22 @@ private:
public:
LocalPtr
()
=
default
;
LocalPtr
(
const
LocalPtr
&
rhs
)
{
(
*
this
)
=
rhs
;
}
LocalPtr
(
LocalPtr
&&
rhs
)
{
(
*
this
)
=
std
::
move
(
rhs
);
}
LocalPtr
(
const
LocalPtr
&
rhs
)
{
auto
storage
=
rhs
.
m_storage
;
if
(
storage
)
{
storage
->
inc_ref
();
}
m_storage
=
storage
;
#if MGB_FAT_LOCAL_PTR
m_pointer
=
rhs
.
m_pointer
;
#endif
}
LocalPtr
(
LocalPtr
&&
rhs
)
{
std
::
swap
(
m_storage
,
rhs
.
m_storage
);
#if MGB_FAT_LOCAL_PTR
std
::
swap
(
m_pointer
,
rhs
.
m_pointer
);
#endif
}
LocalPtr
&
operator
=
(
const
LocalPtr
&
rhs
)
{
if
(
this
==
&
rhs
)
{
return
*
this
;
...
...
@@ -115,9 +143,11 @@ public:
}
if
(
m_storage
)
{
m_storage
->
dec_ref
();
// rhs.m_storage may be invalid here
}
m_storage
=
storage
;
#if MGB_FAT_LOCAL_PTR
m_pointer
=
rhs
.
m_pointer
;
#endif
return
*
this
;
}
LocalPtr
&
operator
=
(
LocalPtr
&&
rhs
)
{
...
...
@@ -125,6 +155,9 @@ public:
return
*
this
;
}
std
::
swap
(
m_storage
,
rhs
.
m_storage
);
#if MGB_FAT_LOCAL_PTR
std
::
swap
(
m_pointer
,
rhs
.
m_pointer
);
#endif
rhs
.
reset
();
return
*
this
;
}
...
...
@@ -186,10 +219,11 @@ public:
T
&
operator
*
()
const
{
return
*
get
();
}
T
*
get
()
const
{
if
((
!
m_storage
)
||
!
m_storage
->
m_pointer
)
{
return
nullptr
;
}
return
m_storage
->
m_pointer
;
#if MGB_FAT_LOCAL_PTR
return
m_pointer
;
#else
return
m_storage
?
m_storage
->
m_pointer
:
nullptr
;
#endif
}
T
*
operator
->
()
const
{
return
get
();
}
...
...
@@ -202,6 +236,9 @@ public:
if
(
m_storage
)
{
m_storage
->
dec_ref
();
m_storage
=
nullptr
;
#if MGB_FAT_LOCAL_PTR
m_pointer
=
nullptr
;
#endif
}
}
...
...
imperative/src/include/megbrain/imperative/utils/mempool.h
浏览文件 @
ca001777
...
...
@@ -49,8 +49,8 @@ public:
instance
=
std
::
make_unique
<
MemPool
<
T
>>
();
sm_instance
=
instance
.
get
();
}
mgb_assert
(
sm_instance
);
}
return
*
sm_instance
;
}
};
...
...
@@ -62,9 +62,9 @@ std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>>
MemPoolUtils
<
T
>::
sm_instances
;
template
<
typename
T
>
thread_local
MemPool
<
T
>*
MemPoolUtils
<
T
>::
tm_instance
;
thread_local
MemPool
<
T
>*
MemPoolUtils
<
T
>::
tm_instance
=
nullptr
;
template
<
typename
T
>
MemPool
<
T
>*
MemPoolUtils
<
T
>::
sm_instance
;
MemPool
<
T
>*
MemPoolUtils
<
T
>::
sm_instance
=
nullptr
;
}
// namespace mgb::imperative
\ No newline at end of file
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/utils/value_shape.h
浏览文件 @
ca001777
...
...
@@ -95,6 +95,8 @@ struct ValueShape {
}
return
true
;
}
bool
operator
!=
(
const
ValueShape
&
rhs
)
const
{
return
!
operator
==
(
rhs
);
}
};
static_assert
(
sizeof
(
size_t
)
>=
sizeof
(
int
));
...
...
imperative/src/include/megbrain/imperative/value.h
浏览文件 @
ca001777
...
...
@@ -47,6 +47,17 @@ class StringValue;
class
Operator
;
class
ValueRefList
;
template
<
typename
T
>
class
Type
{
private:
const
size_t
m_code
=
T
::
TYPE_CODE
;
public:
inline
size_t
code
()
const
{
return
m_code
;
}
};
/**
* \brief an smart reference of value
*
...
...
@@ -64,8 +75,9 @@ public:
protected:
mutable
storage_t
m_storage
;
size_t
m_id
=
std
::
numeric_limits
<
size_t
>::
max
();
ValueRef
(
storage_t
storage
)
{
m_storage
=
storage
;
}
inline
ValueRef
(
storage_t
storage
);
private:
/**
...
...
@@ -75,6 +87,10 @@ private:
*/
storage_t
&
storage
()
const
;
const
Value
*
as
(
size_t
typecode
)
const
;
bool
is
(
size_t
typecode
)
const
;
public:
ValueRef
()
=
default
;
...
...
@@ -86,7 +102,7 @@ public:
* \return false if empty or type of value is not TValue
*/
template
<
typename
TValue
>
bool
is
(
)
const
;
inline
bool
is
(
Type
<
TValue
>
type
=
{}
)
const
;
/**
* \brief try cast value as target type
...
...
@@ -95,7 +111,7 @@ public:
* \return TValue* raw pointer if success, otherwise nullptr
*/
template
<
typename
TValue
>
const
TValue
*
as
(
)
const
;
inline
const
TValue
*
as
(
Type
<
TValue
>
type
=
{}
)
const
;
/**
* \brief cast value to target type
...
...
@@ -104,7 +120,7 @@ public:
* \return TValue& reference of value
*/
template
<
typename
TValue
>
const
TValue
&
cast
(
)
const
;
inline
const
TValue
&
cast
(
Type
<
TValue
>
type
=
{}
)
const
;
/**
* \brief like as(), but returns TypedValueRef instead
...
...
@@ -113,7 +129,13 @@ public:
* \return TypedValueRef<TValue> reference if success, otherwise empty reference
*/
template
<
typename
TValue
>
inline
TypedValueRef
<
TValue
>
as_ref
()
const
;
inline
TypedValueRef
<
TValue
>
as_ref
(
Type
<
TValue
>
type
=
{})
const
;
template
<
typename
TValue
>
inline
TypedValueRef
<
TValue
>
cast_ref
(
Type
<
TValue
>
type
=
{})
const
;
template
<
typename
TValue
>
void
on_cast_failure
()
const
;
operator
bool
()
const
{
return
bool
(
m_storage
);
}
...
...
@@ -132,7 +154,7 @@ public:
ValueRef
unwrap
()
const
;
std
::
string
to_string
()
const
;
std
::
string
raw_type
()
const
;
uint64_t
id
()
const
;
uint64_t
id
()
const
{
return
m_id
;
}
size_t
hash
()
const
{
return
id
();
}
static
ValueRef
make
(
storage_t
storage
);
...
...
@@ -144,7 +166,7 @@ public:
friend
class
TypedValueRef
;
template
<
typename
T
>
friend
class
ValueImpl
;
friend
std
::
vector
<
ValueRef
>
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
);
friend
ValueRefList
apply
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
);
};
template
<
>
...
...
@@ -244,7 +266,7 @@ public:
using
ref_t
=
TypedValueRef
<
T
>
;
using
weak_ref_t
=
TypedValueWeakRef
<
T
>
;
static
inline
size_t
TYPE_CODE
=
[]
{
return
register_type
(
typeid
(
T
));
}();
static
inline
const
size_t
TYPE_CODE
=
[]
{
return
register_type
(
typeid
(
T
));
}();
/**
* \brief helper function for construct a value
...
...
@@ -254,7 +276,7 @@ public:
* \return TypedValueRef<T> reference of value
*/
template
<
typename
...
TArgs
>
static
TypedValueRef
<
T
>
make
(
TArgs
&&
...
args
)
{
static
MGB_NOINLINE
TypedValueRef
<
T
>
make
(
TArgs
&&
...
args
)
{
static_assert
(
std
::
is_final_v
<
T
>
);
return
ValueRef
::
make
(
LocalPtr
<
Value
>::
make
<
T
>
(
std
::
forward
<
TArgs
&&>
(
args
)...));
}
...
...
@@ -279,46 +301,60 @@ public:
bool
eq
(
const
TMixin
&
value
)
const
{
return
((
const
TMixin
&
)
*
this
)
==
value
;
}
};
inline
ValueRef
::
ValueRef
(
storage_t
storage
)
{
// mgb_assert(storage);
m_storage
=
storage
;
m_id
=
m_storage
->
m_id
;
}
template
<
typename
TValue
>
const
TValue
*
ValueRef
::
as
(
)
const
{
inline
const
TValue
*
ValueRef
::
as
(
Type
<
TValue
>
type
)
const
{
static_assert
(
std
::
is_base_of_v
<
ValueImpl
<
TValue
>
,
TValue
>
);
auto
storage
=
this
->
storage
();
if
(
!
storage
)
{
return
nullptr
;
}
if
(
storage
->
m_typecode
!=
TValue
::
TYPE_CODE
)
{
return
nullptr
;
}
return
static_cast
<
TValue
*>
(
storage
.
get
());
return
static_cast
<
const
TValue
*>
(
as
(
type
.
code
()));
}
template
<
typename
TValue
>
const
TValue
&
ValueRef
::
cast
()
const
{
auto
*
ptr
=
as
<
TValue
>
();
if
(
!
ptr
)
{
// if this is ErrorValue, rethrow directly
storage
()
->
try_rethrow
();
mgb_assert
(
ptr
,
"expect type %s, got %s"
,
typeid
(
TValue
).
name
(),
to_string
().
c_str
());
inline
const
TValue
&
ValueRef
::
cast
(
Type
<
TValue
>
type
)
const
{
auto
*
ptr
=
as
<
TValue
>
(
type
);
if
(
mgb_unlikely
(
!
ptr
))
{
on_cast_failure
<
TValue
>
();
}
return
*
ptr
;
return
static_cast
<
const
TValue
&>
(
*
ptr
);
}
template
<
typename
TValue
>
inline
bool
ValueRef
::
is
(
Type
<
TValue
>
type
)
const
{
return
is
(
type
.
code
());
}
template
<
typename
TValue
>
bool
ValueRef
::
is
()
const
{
auto
*
ptr
=
as
<
TValue
>
();
return
ptr
!=
nullptr
;
inline
TypedValueRef
<
TValue
>
ValueRef
::
as_ref
(
Type
<
TValue
>
type
)
const
{
if
(
!
is
<
TValue
>
(
type
))
{
return
{};
}
return
TypedValueRef
<
TValue
>
(
*
this
);
}
template
<
typename
TValue
>
TypedValueRef
<
TValue
>
ValueRef
::
as_ref
(
)
const
{
if
(
!
is
<
TValue
>
()
)
{
inline
TypedValueRef
<
TValue
>
ValueRef
::
cast_ref
(
Type
<
TValue
>
type
)
const
{
if
(
!
m_storage
)
{
return
{};
}
if
(
mgb_unlikely
(
!
is
<
TValue
>
(
type
)))
{
on_cast_failure
<
TValue
>
();
}
return
TypedValueRef
<
TValue
>
(
*
this
);
}
template
<
typename
TValue
>
void
ValueRef
::
on_cast_failure
()
const
{
// if this is ErrorValue, rethrow directly
storage
()
->
try_rethrow
();
mgb_assert
(
storage
()
->
m_typecode
!=
TValue
::
TYPE_CODE
,
"expect type %s, got %s"
,
typeid
(
TValue
).
name
(),
to_string
().
c_str
());
}
/**
* \brief ValueRef with concrete type, convenient for dereference
*
...
...
@@ -361,11 +397,87 @@ private:
public:
TypedValueWeakRef
(
ValueRef
value
)
:
ValueWeakRef
(
value
)
{}
TypedValueWeakRef
(
ValueWeakRef
value
)
:
ValueWeakRef
(
value
)
{}
TypedValueRef
<
T
>
lock
()
{
return
ValueWeakRef
::
lock
().
template
as_ref
<
T
>();
}
TypedValueRef
<
T
>
lock
()
{
auto
value
=
ValueWeakRef
::
lock
();
if
(
value
)
{
return
value
.
template
as_ref
<
T
>();
}
else
{
return
{};
}
}
};
// TODO: add proxy value type, which is meant to be reset in the end
class
ValueRefList
{
private:
ValueRef
*
m_data
=
nullptr
;
size_t
m_size
=
0
;
std
::
aligned_storage_t
<
sizeof
(
ValueRef
),
alignof
(
ValueRef
)
>
m_storage
;
private:
void
init
(
size_t
nr_elems
);
ValueRef
*
inline_storage
()
{
return
reinterpret_cast
<
ValueRef
*>
(
&
m_storage
);
}
public:
ValueRefList
()
=
default
;
ValueRefList
(
size_t
nr_elems
);
ValueRefList
(
ValueRef
item
);
ValueRefList
(
std
::
initializer_list
<
ValueRef
>
values
);
template
<
typename
TIterator
>
ValueRefList
(
TIterator
begin
,
TIterator
end
);
ValueRefList
(
const
ValueRefList
&
rhs
);
ValueRefList
(
ValueRefList
&&
rhs
);
ValueRefList
&
operator
=
(
const
ValueRefList
&
rhs
);
ValueRefList
&
operator
=
(
ValueRefList
&&
rhs
);
~
ValueRefList
();
void
clear
();
ValueRef
*
begin
()
{
return
m_data
;
}
ValueRef
*
end
()
{
return
m_data
+
m_size
;
}
const
ValueRef
*
cbegin
()
const
{
return
m_data
;
}
const
ValueRef
*
cend
()
const
{
return
m_data
+
m_size
;
}
size_t
size
()
const
{
return
m_size
;
}
ValueRef
&
at
(
size_t
idx
)
{
mgb_assert
(
idx
<
m_size
);
return
m_data
[
idx
];
}
const
ValueRef
&
at
(
size_t
idx
)
const
{
mgb_assert
(
idx
<
m_size
);
return
m_data
[
idx
];
}
ValueRef
&
operator
[](
size_t
idx
)
{
return
m_data
[
idx
];
}
const
ValueRef
&
operator
[](
size_t
idx
)
const
{
return
m_data
[
idx
];
}
ValueRef
*
data
()
{
return
m_data
;
}
const
ValueRef
*
data
()
const
{
return
m_data
;
}
bool
empty
()
const
{
return
m_size
==
0
;
}
ValueRef
&
front
()
{
mgb_assert
(
m_size
>
1
);
return
m_data
[
0
];
}
ValueRef
&
back
()
{
mgb_assert
(
m_size
>
1
);
return
m_data
[
m_size
-
1
];
}
};
template
<
typename
TIterator
>
ValueRefList
::
ValueRefList
(
TIterator
begin
,
TIterator
end
)
:
ValueRefList
(
end
-
begin
)
{
for
(
size_t
i
=
0
;
i
<
m_size
;
++
i
)
{
m_data
[
i
]
=
*
(
begin
+
i
);
}
}
inline
ValueRefList
::
ValueRefList
(
ValueRef
item
)
:
m_data
(
inline_storage
()),
m_size
(
1
)
{
new
(
m_data
)
ValueRef
();
m_data
[
0
]
=
std
::
move
(
item
);
}
/*class ValueRefList : public SmallVector<ValueRef, 1> {
public:
using SmallVector::SmallVector;
};*/
}
// namespace imperative
}
// namespace mgb
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录