MegEngine / MegEngine, commit b11d4430 (authored Jun 08, 2023 by Megvii Engine Team)
feat(mge/jit): support trace withouthost mode
GitOrigin-RevId: 09b29e3dac44a4e4330f2ceb10da7d55df772466
Parent commit: 3116e9f7
27 changed files with 1569 additions and 40 deletions (+1569 −40)
imperative/python/megengine/core/_trace_option.py  +13 −0
imperative/python/megengine/jit/__init__.py  +2 −0
imperative/python/megengine/jit/partial_tracing.py  +224 −0
imperative/python/megengine/jit/tracing.py  +394 −2
imperative/python/megengine/jit/xla_backend.py  +201 −0
imperative/python/megengine/module/module.py  +2 −0
imperative/python/src/graph_rt.cpp  +2 −4
imperative/python/src/graph_rt.h  +1 −1
imperative/python/src/tensor.cpp  +321 −12
imperative/python/src/tensor.h  +2 −0
imperative/python/test/unit/jit/test_tracing.py  +91 −1
imperative/src/impl/basic_operators.cpp  +4 −0
imperative/src/impl/basic_values.cpp  +4 −0
imperative/src/impl/op_def.cpp  +4 −0
imperative/src/impl/ops/opr_attr.cpp  +15 −2
imperative/src/impl/transformations/eval.cpp  +5 −0
imperative/src/impl/transformations/grad.cpp  +26 −1
imperative/src/impl/transformations/lazy.cpp  +3 −0
imperative/src/impl/transformations/trace.cpp  +105 −5
imperative/src/impl/value.cpp  +4 −0
imperative/src/include/megbrain/imperative/basic_operators.h  +8 −0
imperative/src/include/megbrain/imperative/basic_values.h  +19 −0
imperative/src/include/megbrain/imperative/op_def.h  +2 −0
imperative/src/include/megbrain/imperative/ops/opr_attr.h  +6 −3
imperative/src/include/megbrain/imperative/transformations/grad.h  +46 −1
imperative/src/include/megbrain/imperative/transformations/trace.h  +61 −5
imperative/src/include/megbrain/imperative/value.h  +4 −3
imperative/python/megengine/core/_trace_option.py

@@ -8,6 +8,8 @@ _use_symbolic_shape = False
 if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
     _use_symbolic_shape = True

+_use_xla_backend = False
+

 def use_symbolic_shape() -> bool:
     r"""Returns whether tensor.shape returns a tensor instead of a tuple"""
@@ -22,4 +24,15 @@ def set_symbolic_shape(option: bool):
     return _org


+def use_xla_backend() -> bool:
+    return _use_xla_backend
+
+
+def set_use_xla_backend(option: bool) -> bool:
+    global _use_xla_backend
+    _org = _use_xla_backend
+    _use_xla_backend = option
+    return _org
+
+
 set_cpp_use_symbolic_shape(use_symbolic_shape)
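A minimal usage sketch of the two helpers added above; the save-and-restore pattern mirrors how xla_trace.setup_env/unset_env uses them later in this commit, and the surrounding code is illustrative rather than part of the diff:

    from megengine.core._trace_option import set_use_xla_backend, use_xla_backend

    # temporarily route execution through the XLA backend, then restore the old flag
    prev = set_use_xla_backend(True)
    try:
        assert use_xla_backend()
        # ... run the traced computation here ...
    finally:
        set_use_xla_backend(prev)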
imperative/python/megengine/jit/__init__.py

 # -*- coding: utf-8 -*-
 from .dtr_config import DTRConfig
 from .graph_opt_config import GraphOptimizationConfig
+from .partial_tracing import partial_trace
 from .sublinear_memory_config import SublinearMemoryConfig
 from .tracing import TraceError, exclude_from_trace, trace
+from .xla_backend import xla_trace
imperative/python/megengine/jit/partial_tracing.py  (new file, mode 100644)

from collections import OrderedDict
from typing import Sequence

from ..core._imperative_rt.core2 import add_backward_callback as _add_backward_callback
from ..core._imperative_rt.core2 import get_grad_slot, get_handle_id
from ..tensor import Tensor
from .tracing import trace
from .xla_backend import xla_trace


def _process_fwd_bwd_trace_result(fwd, bwd, inp_grad_map, out_grad_map):
    # partial_trace will record op sequences for forward/backward respectively, and get two TraceResult objects after tracing.
    # But the inputs/outputs of backward graph are unknown. This function will determine the inputs and outputs of the backward graph
    # var.handle_id is id of value ref. It's used to find the tensors used in both forward and backward calculation.
    # inp_grad_map, key: handle id of forward inputs, value: handle id of grads of forward inputs.
    # out_grad_map, key: handle id of foward outputs, value: handle id of grads of forward outputs.
    fwd_features = set([t.handle_id for t in fwd._trace.vars])
    bwd_features = set([t.handle_id for t in bwd._trace.vars])
    keep_vars = fwd_features.intersection(
        bwd_features
    )  # some intermediate vars produced by forward, and will be used in backward.
    current = max(fwd.out_list) + 1
    saved_feature_map = OrderedDict()
    saved_featrues = []
    # mark keep_vars as forward outputs
    for var in fwd._trace.vars:
        if (
            var.handle_id in keep_vars
            and var.data_required
            and len(var.out_mark) == 0
            and var.kind not in ["const", "external"]
        ):
            keep_vars.remove(var.handle_id)
            fwd._trace.mark_output(current, var.id)
            saved_feature_map[var.handle_id] = current
            saved_featrues.append(current)
            current += 1
    fwd.keeped_activation = saved_featrues

    bwd_inp_idx = 0
    bwd_out_idx = 0
    bwd_dys = []
    bwd_inps = [-1] * len(saved_feature_map)
    saved_feature_handle_id = list(saved_feature_map.keys())
    dy_ids = list(out_grad_map.values())  # handle_id of grad of forward output
    inp_grad_ids = list(inp_grad_map.values())  # handle_id of grad of forward input
    bwd_dys = [-1] * len(dy_ids)
    bwd_outputs = [-1] * len(inp_grad_ids)
    # dy_ids + saved_feature_map are backward inputs
    # inp_grad_ids are backward outputs
    # mark inputs/outputs for backward
    for var in bwd._trace.vars:
        if var.handle_id in dy_ids and var.kind == "external":
            bwd._trace.mark_input(bwd_inp_idx, var.id)
            idx = dy_ids.index(var.handle_id)
            bwd_dys[idx] = bwd_inp_idx
            bwd_inp_idx += 1
        elif var.handle_id in saved_feature_map and var.kind == "external":
            bwd._trace.mark_input(bwd_inp_idx, var.id)
            bwd_inps[saved_feature_handle_id.index(var.handle_id)] = bwd_inp_idx
            bwd_inp_idx += 1
        if var.handle_id in inp_grad_ids and var.data_required:
            bwd_outputs[inp_grad_ids.index(var.handle_id)] = bwd_out_idx
            bwd._trace.mark_output(bwd_out_idx, var.id)
            bwd_out_idx += 1
    # assert -1 not in bwd_dys
    assert -1 not in bwd_inps
    for var in fwd._trace.vars:
        if not var.out_mark:
            var.data_required = False
    # assert -1 not in bwd_outputs
    bwd.setup_io_without_trace(bwd_dys + bwd_inps, bwd_outputs)
    bwd.setup_without_host()

    def check_external(trace_obj):
        for var in trace_obj.vars:
            if var.kind == "external" and not var.inp_mark:
                raise RuntimeError("have unknown input in trace result")

    check_external(fwd)
    check_external(bwd)


JIT_BACKEND = {"default": trace, "xla": xla_trace}


def partial_trace(func=None, *, backend="default", without_host=True, **trace_options):
    assert backend in JIT_BACKEND
    assert without_host, "partial_trace only support without_host mode currently!"

    def wrapper(func):
        trace_obj = JIT_BACKEND[backend](
            func, without_host=without_host, **trace_options
        )
        trace_options["capture_as_const"] = False
        backward_trace_obj = JIT_BACKEND[backend](
            None, without_host=without_host, **trace_options
        )
        backward_trace_obj.check_external = (
            False  # check if there are unknown external vars after tracing.
        )
        trace_obj.overall = False  # if trace overall train step
        backward_trace_obj.overall = False
        trace_obj._trace.remove_unused_data_required = False
        backward_trace_obj._trace.remove_unused_data_required = False
        inp_grad_maps = OrderedDict()  # x, dx map
        out_grad_maps = OrderedDict()  # y, dy map
        traced = False  # if wrapped function has been traced
        compiled = False  # if wrapped function has been compiled
        custom_autodiff = None
        outdef = None  # treedef of forward return value
        from ..core.autodiff.grad import Function

        class CustomAutodiff(Function):
            def __init__(self, fwd, bwd):
                self.fwd = fwd
                self.bwd = bwd
                del fwd.outdef
                self.keeped_features = []

            def forward(self, *args):
                rst = self.fwd(*args)
                keeped_features = rst[-1]
                if not isinstance(keeped_features, Sequence):
                    keeped_features = tuple([keeped_features])
                else:
                    keeped_features = tuple(keeped_features)
                self.keeped_features = keeped_features
                return rst[0]

            def get_keeped_features(self):
                rst = self.keeped_features
                del self.keeped_features
                return rst

            def backward(self, *output_grads):
                output_grads = tuple([i for i in output_grads if i is not None])
                return self.bwd(*(output_grads + self.get_keeped_features()))

        class CustomFwd:
            def __init__(self, fwd, bwd):
                self.fwd = fwd
                self.bwd = bwd

            def __call__(self, *args):
                rst = self.fwd(*args)
                if self.fwd.keeped_activation:
                    keeped_features = rst[-1]
                    if not isinstance(keeped_features, Sequence):
                        keeped_features = tuple([keeped_features])
                    else:
                        keeped_features = tuple(keeped_features)
                    self.keeped_features = keeped_features
                    return rst[0]
                else:
                    return rst

        def wrapped_func(*args, **kwargs):
            from ..traced_module.pytree import tree_flatten
            from ..module import Module

            nonlocal traced
            nonlocal compiled
            nonlocal custom_autodiff
            nonlocal outdef

            if not traced:
                traced = True
                fargs = trace_obj.flatten_inputs(*args, **kwargs)
                for t in fargs:
                    inp_grad_maps[t] = get_grad_slot(t)
                del fargs

                def exit_trace():
                    backward_trace_obj._trace.exit()
                    new_dict = {}
                    for k, v in inp_grad_maps.items():
                        if v is not None:
                            new_dict[get_handle_id(k)] = get_handle_id(v.grad)
                        else:
                            new_dict[get_handle_id(k)] = -1
                    inp_grad_maps.clear()
                    inp_grad_maps.update(new_dict)

                _add_backward_callback(exit_trace)
                ret = trace_obj(*args)
                rlist, outdef = tree_flatten(ret)
                for t in rlist:
                    out_grad_maps[t] = get_grad_slot(t)

                def enter_trace():
                    new_dict = {}
                    for k, v in out_grad_maps.items():
                        if v is not None:
                            new_dict[get_handle_id(k)] = get_handle_id(v.grad)
                    out_grad_maps.clear()
                    out_grad_maps.update(new_dict)
                    backward_trace_obj._trace.enter()

                _add_backward_callback(enter_trace)
                return ret
            elif not compiled:
                if custom_autodiff is None:
                    _process_fwd_bwd_trace_result(
                        trace_obj, backward_trace_obj, inp_grad_maps, out_grad_maps
                    )
                    if len(backward_trace_obj._trace.ops) > 0:
                        custom_autodiff = CustomAutodiff(trace_obj, backward_trace_obj)
                    else:
                        custom_autodiff = CustomFwd(trace_obj, backward_trace_obj)
                fargs = trace_obj.flatten_inputs(*args, **kwargs)
                del args
                del kwargs
                if outdef is None:
                    return custom_autodiff(*fargs)
                else:
                    return outdef.unflatten(custom_autodiff(*fargs))

        return wrapped_func

    if func is None:
        return wrapper
    else:
        return wrapper(func)
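A minimal usage sketch of partial_trace, adapted from the test_partial_trace_fwd_bwd test added to test_tracing.py later in this commit; the Scale module here is hypothetical and only illustrates how the decorator is attached:

    import numpy as np
    import megengine.functional as F
    from megengine import Parameter, tensor
    from megengine.jit import partial_trace
    from megengine.module import Module

    class Scale(Module):  # illustrative module, mirrors the test's Simple module
        def __init__(self):
            super().__init__()
            self.a = Parameter([2.0], dtype=np.float32)

        @partial_trace  # only this method is traced; surrounding Python code still runs eagerly
        def forward(self, x):
            return F.exp(x * self.a)

    m = Scale()
    out = m(tensor([1.0]))  # first call runs the Python body and records the forward trace
    out = m(tensor([1.0]))  # later calls replay the captured graph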
imperative/python/megengine/jit/tracing.py

@@ -9,6 +9,7 @@ import pickle
 import re
 import struct
 import sys
+from collections import OrderedDict, defaultdict
 from typing import Any, Sequence

 import cv2
@@ -16,9 +17,22 @@ import numpy as np
 from .. import tensor
 from ..core import _imperative_rt as rt
-from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
+from ..core._imperative_rt import (
+    CompNode,
+    GraphProfiler,
+    GraphProfiler2,
+    SerializationMetadata,
+)
 from ..core._imperative_rt.core2 import Tensor as RawTensor
-from ..core._imperative_rt.core2 import Trace, TraceError, name_tensor  # skip_tracing,
+from ..core._imperative_rt.core2 import Trace, TraceError  # skip_tracing,
+from ..core._imperative_rt.core2 import add_backward_callback as _add_backward_callback
+from ..core._imperative_rt.core2 import (
+    get_marked_input_tensor,
+    get_marked_output_tensor,
+    get_marked_tensor,
+    marked_input_tensor,
+    name_tensor,
+)
 from ..core._imperative_rt.graph import _set_priority_to_id
 from ..core._imperative_rt.ops import (
     AssertEqual,
@@ -31,6 +45,7 @@ from ..core._imperative_rt.ops import (
 from ..core._trace_option import set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
 from ..logger import get_logger
+from ..tensor import Tensor
 from ..utils import comp_graph_tools as cgtools
 from ..utils.naming import AutoNaming
 from ..utils.profiler import is_profiling
@@ -94,8 +109,13 @@ class trace:
         opt_level: optimization level for compiling trace. Default: 2
         graph_opt_config: configuration for graph optimization. Default: None
         symbolic_shape: whether to use symbolic shape for tracing. Default: True
+        without_host: if True, will run python code of wrapped function on the first call,
+            and run the compiled graph/function on subsequent calls. if False, will run python code every time.
+            Default: False
     """

+    third_party_backend = False
+
     def __new__(cls, *args, **kwargs):
         if not args:
             return functools.partial(cls, **kwargs)
@@ -113,6 +133,7 @@ class trace:
         opt_level: int = 2,
         graph_opt_config: GraphOptimizationConfig = None,
         symbolic_shape: bool = True,
+        without_host: bool = False,
     ):
         self.__wrapped__ = function
         self._capture_as_const = capture_as_const or record_only
@@ -150,6 +171,7 @@ class trace:
                 graph_options["graph_opt.jit_config.fuse_reduce"] = mapping[
                     graph_opt_config.jit_fuse_reduce
                 ]
         if sublinear_memory_config is not None:
             graph_options["enable_sublinear_memory_opt"] = True
             graph_options[
@@ -186,8 +208,114 @@ class trace:
         self._trace.profile = profiling
         self._trace.array_comparator = array_comparator
         self._trace.record_input_shapes = _input_node_use_static_shape()
+        self._trace.without_host = without_host
+        self.check_external = True
+        self.traced = False
+        self.input_num = 0
+        self.output_num = 0
+        self.arg_list = []
+        self.out_list = []
+        self.overall = True
+        # forward keeped activation
+        self.keeped_activation = []
+
+        self.third_party_backend_compiled = False
+
+    @property
+    def check_external(self):
+        return self._trace.check_external
+
+    @check_external.setter
+    def check_external(self, flag):
+        self._trace.check_external = flag
+
+    @property
+    def without_host(self):
+        return self._trace.without_host
+
+    def flatten_inputs(self, *args, **kwargs):
+        from ..traced_module.pytree import tree_flatten
+        from ..module import Module
+
+        tensor_args = []
+        modules = []
+        fargs, _ = tree_flatten((args, kwargs))
+        for a in fargs:
+            if isinstance(a, RawTensor):
+                tensor_args.append(a)
+            elif isinstance(a, Module):
+                modules.append(a)
+        for m in modules:
+            tensor_args.extend(list(m.parameters()))
+        return tensor_args
+
+    def compile(self):
+        raise NotImplementedError
+
+    def execute(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def setup_env(self):
+        pass
+
+    def unset_env(self):
+        pass
+
+    def compile_and_exec(self, *args, **kwargs):
+        if not self.third_party_backend_compiled:
+            self.compile()
+            self.third_party_backend_compiled = True
+        return self.execute(*args, **kwargs)
+
+    def convert_optimizer_state_to_tensor(self, *args, **kwargs):
+        from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS
+        from ..optimizer import Optimizer
+        from ..tensor import Tensor
+
+        if Optimizer not in SUPPORTED_LEAF_CLS:
+            SUPPORTED_LEAF_CLS.append(Optimizer)
+        args, _ = tree_flatten((args, kwargs))
+        for arg in args:
+            if isinstance(arg, Optimizer):
+                arg._disable_type_convert = False
+                for param_group in arg.param_groups:
+                    for k, v in param_group.items():
+                        if not isinstance(v, (Tensor, Sequence)):
+                            param_group[k] = Tensor(v)
+                        elif isinstance(v, Sequence) and not isinstance(v[0], Tensor):
+                            new_v = []
+                            for i in range(len(v)):
+                                new_v.append(Tensor(v[i]))
+                            param_group[k] = new_v
+
+    def setup_io_without_trace(self, inputs, outputs):
+        self.traced = True
+        self.arg_list = [i for i in inputs if i != -1]
+        self.out_list = outputs
+        self.input_num = len(self.arg_list)
+        self.output_num = len([i for i in outputs if i != -1])
+
+    def setup_without_host(self):
+        self.inp_modules = set()
+        self.module_tensors = set()
+        self.tensor_to_attr = dict()
+        self.attr_to_key = dict()
+        self.update_param_dict = dict()
+        self.update_opt_param_dict = dict()
+        self.capture_optimizer_state = set()
+        self.opt_param_dict = dict()
+
     def __call__(self, *args, **kwargs):
+        if not self.without_host:
+            return self.trace_normal(*args, **kwargs)
+        elif self.overall:
+            return self.trace_without_host_overall(*args, **kwargs)
+        else:
+            return self.trace_without_host(*args, **kwargs)
+
+    def trace_normal(self, *args, **kwargs):
         global active_trace
         symbolic_shape = None
         outputs = None
@@ -214,6 +342,270 @@ class trace:
             raise
         return outputs

+    def trace_without_host(self, *args, **kwargs):
+        from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS
+        from ..module import Module
+        from ..utils.module_utils import get_expand_structure
+        from ..tensor import Tensor
+        from ..optimizer import Optimizer
+
+        assert self.without_host and not self.overall
+        global active_trace
+        symbolic_shape = None
+        outputs = None
+        if self.traced and self.third_party_backend:
+            return self.compile_and_exec(*args, **kwargs)
+        try:
+            active_trace = self
+            self._trace.enter()
+            if self._trace.compiled():
+                arglist = self.flatten_inputs(*args, **kwargs)
+                idx = 0
+                inp_dict = {}
+                for a in arglist:
+                    if isinstance(a, RawTensor):
+                        inp_dict[self.arg_list[idx]] = a
+                        idx += 1
+                self._trace.put_datas(inp_dict)
+                outlist = []
+                for i in self.out_list:
+                    if i == -1:
+                        if not hasattr(self, "outdef"):
+                            outlist.append(None)
+                    else:
+                        outlist.append(self._trace.get_data(i))
+                keep_vars = []
+                for i in self.keeped_activation:
+                    keep_vars.append(self._trace.get_data(i))
+
+                outputs = (
+                    self.outdef.unflatten(outlist)
+                    if hasattr(self, "outdef")
+                    else outlist
+                )
+                if keep_vars:
+                    return outputs, keep_vars
+                else:
+                    return outputs
+
+            arg_list = self.flatten_inputs(*args, **kwargs)
+            for i, arg in enumerate(arg_list):
+                arg_list[i]._reset(get_marked_input_tensor(self.input_num, arg))
+                self.arg_list.append(self.input_num)
+                self.input_num += 1
+            del arg_list
+            symbolic_shape = set_symbolic_shape(self._symbolic_shape)
+            if self.third_party_backend:
+                self.setup_env()
+            outputs = self.__wrapped__(*args, **kwargs)
+        finally:
+            handling_exc = sys.exc_info() != (None,) * 3
+            active_trace = None
+            if symbolic_shape is not None:
+                symbolic_shape = set_symbolic_shape(symbolic_shape)
+                assert symbolic_shape == self._symbolic_shape
+            if self.third_party_backend:
+                self.unset_env()
+            if (
+                self._capture_as_const
+                and (outputs is not None)
+                and not self.without_host
+            ):
+                self._process_outputs(outputs)
+            if not self._trace.compiled():
+                outlist, self.outdef = tree_flatten(outputs)
+                for i, out in enumerate(outlist):
+                    assert isinstance(out, RawTensor), type(out)
+                    outlist[i] = get_marked_output_tensor(self.output_num, out)
+                    del out
+                    self.out_list.append(self.output_num)
+                    self.output_num += 1
+                outputs = self.outdef.unflatten(outlist)
+            try:
+                # may raise TraceError
+                self._trace.exit()
+            except Exception as e:
+                if isinstance(e, TraceError):
+                    if not handling_exc:
+                        raise
+                else:
+                    self._trace.set_execption(str(e))
+                    raise
+        self.traced = True
+        return outputs
+
+    def trace_without_host_overall(self, *args, **kwargs):
+        # record overall train step include forward, backward, param update in a single trace object
+        from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS
+        from ..module import Module
+        from ..utils.module_utils import get_expand_structure
+        from ..tensor import Tensor
+        from ..optimizer import Optimizer
+
+        assert self.without_host
+        global active_trace
+        symbolic_shape = None
+        outputs = None
+        if self.traced and self.third_party_backend:
+            return self.compile_and_exec(*args, **kwargs)
+        try:
+            active_trace = self
+            if not self.traced:
+                self.convert_optimizer_state_to_tensor(*args, **kwargs)
+            self._trace.enter()
+            if self._trace.compiled():
+                arglist, _ = tree_flatten((args, kwargs))
+                idx = 0
+                inp_dict = {}
+                for a in arglist:
+                    if isinstance(a, RawTensor):
+                        inp_dict[self.arg_list[idx]] = a
+                        idx += 1
+                for t, key in self.opt_param_dict.items():
+                    inp_dict[key] = t
+                self._trace.put_datas(inp_dict)
+                for attr, key in self.attr_to_key.items():
+                    param = get_expand_structure(attr[0], attr[1])
+                    self._trace.put_data(key, param)
+                outlist = []
+                for i in self.out_list:
+                    if i == -1:
+                        if not hasattr(self, "outdef"):
+                            outlist.append(None)
+                    else:
+                        outlist.append(self._trace.get_data(i))
+
+                for attr, key in self.update_param_dict.items():
+                    param = get_expand_structure(attr[0], attr[1])
+                    param._reset(self._trace.get_data(key))
+                for state, key in self.update_opt_param_dict.items():
+                    state._reset(self._trace.get_data(key))
+                keep_vars = []
+                for i in self.keeped_activation:
+                    keep_vars.append(self._trace.get_data(i))
+
+                outputs = (
+                    self.outdef.unflatten(outlist)
+                    if hasattr(self, "outdef")
+                    else outlist
+                )
+                if keep_vars:
+                    return outputs, keep_vars
+                else:
+                    return outputs
+            self.setup_without_host()
+
+            def get_attr_hook(obj, attr):
+                rst = object.__getattribute__(obj, attr)
+                if isinstance(rst, RawTensor):
+                    assert rst in self.tensor_to_attr
+                    attr = self.tensor_to_attr[rst]
+                    if attr not in self.attr_to_key:
+                        self.attr_to_key[attr] = self.input_num
+                        self.input_num += 1
+                        marked_input_tensor(self.attr_to_key[attr], rst)
+                return rst
+
+            origin_reset = Tensor._reset
+            self.update_param_num = 0
+
+            def tensor_wrapper_resethook(obj, other):
+                if obj in self.tensor_to_attr:
+                    attr = self.tensor_to_attr[obj]
+                    other = get_marked_output_tensor(self.output_num, other)
+                    self.update_param_num += 1
+                    self.update_param_dict[attr] = self.output_num
+                    self.output_num += 1
+                elif obj in self.capture_optimizer_state:
+                    other = get_marked_output_tensor(self.output_num, other)
+                    self.update_opt_param_dict[obj] = self.output_num
+                    self.output_num += 1
+                origin_reset(obj, other)
+
+            arg_list, self.argdef = tree_flatten((args, kwargs))
+            for i, arg in enumerate(arg_list):
+                if isinstance(arg, Module):
+                    for k, v in arg.named_tensors():
+                        if v not in self.tensor_to_attr:
+                            self.tensor_to_attr[v] = (arg, k)
+                    self.inp_modules.add(arg)
+                elif isinstance(arg, RawTensor):
+                    arg_list[i] = get_marked_input_tensor(self.input_num, arg)
+                    self.arg_list.append(self.input_num)
+                    self.input_num += 1
+                elif isinstance(arg, Optimizer):
+                    opt_params, _ = tree_flatten(arg.state_dict(keep_var=True))
+                    for p in opt_params:
+                        if isinstance(p, Tensor):
+                            self.capture_optimizer_state.add(p)
+            self.opt_param_dict = {}
+            for t in self.capture_optimizer_state:
+                if t not in self.tensor_to_attr:  # not module parameter
+                    mark_param = get_marked_input_tensor(self.input_num, t)
+                    self.opt_param_dict[t] = self.input_num
+                    t[...] = mark_param
+                    self.input_num += 1
+            args, kwargs = self.argdef.unflatten(arg_list)
+            Module.__getattribute__ = get_attr_hook
+            Tensor._reset = tensor_wrapper_resethook
+            symbolic_shape = set_symbolic_shape(self._symbolic_shape)
+            if self.third_party_backend:
+                self.setup_env()
+            outputs = self.__wrapped__(*args, **kwargs)
+            del arg_list
+            del args
+            del kwargs
+            Module.__getattribute__ = object.__getattribute__
+            Tensor._reset = origin_reset
+            for attr, key in self.attr_to_key.items():
+                param = get_expand_structure(attr[0], attr[1])
+        finally:
+            handling_exc = sys.exc_info() != (None,) * 3
+            active_trace = None
+            if symbolic_shape is not None:
+                symbolic_shape = set_symbolic_shape(symbolic_shape)
+                assert symbolic_shape == self._symbolic_shape
+            if self.third_party_backend:
+                self.unset_env()
+            if (
+                self._capture_as_const
+                and (outputs is not None)
+                and not self.without_host
+            ):
+                self._process_outputs(outputs)
+            if not self._trace.compiled():
+                outlist, self.outdef = tree_flatten(outputs)
+                for i, out in enumerate(outlist):
+                    assert isinstance(out, RawTensor)
+                    outlist[i] = get_marked_output_tensor(self.output_num, out)
+                    del out
+                    self.out_list.append(self.output_num)
+                    self.output_num += 1
+                outputs = self.outdef.unflatten(outlist)
+            try:
+                # may raise TraceError
+                self._trace.exit()
+            except Exception as e:
+                if isinstance(e, TraceError):
+                    if not handling_exc:
+                        raise
+                else:
+                    self._trace.set_execption(str(e))
+                    raise
+        self.traced = True
+        return outputs
+
+    @property
+    def ops(self):
+        return self._trace.ops
+
+    @property
+    def vars(self):
+        return self._trace.vars
+
     def _process_inputs(self, *args, **kwargs):
         for i, arg in enumerate(args):
             assert isinstance(
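A minimal sketch of the new without_host mode, adapted from the test_trace_without_host test added later in this commit; the tensor values are illustrative:

    from megengine import tensor
    from megengine.jit import trace

    @trace(symbolic=False, without_host=True)
    def fwd(a, b, c):
        # every tensor the function reads must be an argument; tensors captured
        # from the enclosing scope are reported as unknown external inputs
        x, y = a + b, a + c
        return [x * y, x / y]

    a, b, c = tensor([1.0]), tensor([2.0]), tensor([3.0])
    eager = fwd(a, b, c)   # first call runs the Python body and records the trace
    replay = fwd(a, b, c)  # subsequent calls execute the recorded graph directly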
imperative/python/megengine/jit/xla_backend.py  (new file, mode 100644)

from collections import OrderedDict, defaultdict

from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._trace_option import set_use_xla_backend
from ..device import get_default_device
from ..utils.dlpack import from_dlpack, to_dlpack
from .tracing import trace

try:
    from ..xla.lib import xla_client as xc
except ImportError:
    pass


class xla_trace(trace):
    r"""Wraps a callable, and provides accelerated evaluation compiled by xla.

    Currently it is an experimental feature.
    Refer to :class:`~.jit.tracing.trace` for more information.

    Examples:

        .. code-block:: python

            import numpy as np
            from basecls.models.resnet import resnet18
            from megengine.autodiff.grad_manager import GradManager
            from megengine.jit import xla_trace
            from megengine.optimizer import Adam

            model = resnet18()
            gm = GradManager()
            opt = Adam(model.parameters(), lr=1e-4)
            gm.attach(model.parameters())

            # Only tensors in wrapped func args/kwargs will be treated as graph inputs,
            # and other tensors will be captured as const value.
            # Module, optimizer, and train data/label should be arguments of the wrapped function.
            @xla_trace(capture_as_const=True)
            def train_step(model, opt, data, label):
                with gm:
                    pred = model(data)
                    loss = F.loss.cross_entropy(pred, label)
                    gm.backward(loss)
                opt.step().clear_grad()
                return loss
    """

    third_party_backend = True

    def __init__(self, function, *, without_host=True, symbolic_shape=False, **kwargs):
        assert without_host, "xla trace only support without host mode"
        assert not symbolic_shape, "xla doesn't support dynamic shape currently"
        super().__init__(
            function, without_host=without_host, symbolic_shape=symbolic_shape, **kwargs
        )

    def setup_env(self):
        self.orig_use_xla = set_use_xla_backend(True)

    def unset_env(self):
        set_use_xla_backend(self.orig_use_xla)

    def compile(self):
        from ..xla import build_xla
        from ..traced_module.pytree import SUPPORTED_LEAF_TYPE, register_supported_type
        from ..utils.module_utils import get_expand_structure
        from ..xla.device import get_xla_backend_and_device
        from ..tensor import Tensor

        assert self.traced
        if self.overall:
            for attr, _ in self.attr_to_key.items():
                param = get_expand_structure(attr[0], attr[1])
                param._reset(param.to("cpux"))
            for tensor, _ in self.opt_param_dict.items():
                tensor._reset(tensor.to("cpux"))
        self.xla_exec, self.inp_ids, self.out_ids = build_xla(
            self, return_with_io=True, return_device_array=True
        )
        id2inpidx = defaultdict(list)
        id2outidx = defaultdict(list)
        for idx, id in enumerate(self.inp_ids):
            id2inpidx[id].append(idx)
        for idx, id in enumerate(self.out_ids):
            id2outidx[id].append(idx)
        self.inpkey2idx = {}
        self.outkey2idx = {}
        if self.input_num == len(set(self.inp_ids)) - 1:
            self.has_randomstate = True
            self.random_seed = Tensor([[1, 2], [3, 4]], dtype="int32")
        else:
            assert self.input_num == len(set(self.inp_ids))
            self.has_randomstate = False
        inpmark2id = dict()
        outmark2id = dict()
        for var in self.vars:
            if var.kind == "external":
                for mark in var.inp_mark:
                    inpmark2id[mark] = var.id
            elif var.data_required and var.out_mark:
                for mark in var.out_mark:
                    outmark2id[mark] = var.id
        for k, v in inpmark2id.items():
            for idx in id2inpidx[v]:
                self.inpkey2idx[k] = idx
        for k, v in outmark2id.items():
            for idx in id2outidx[v]:
                self.outkey2idx[k] = idx

    def prepare_xla_inputs(self, tensors):
        from ..utils.module_utils import get_expand_structure

        inp_count = 0
        inp_list = [0] * self.input_num
        for idx, t in enumerate(tensors):
            inp = self.inpkey2idx[self.arg_list[idx]]
            inp_list[inp] = t
            inp_count += 1
        if self.overall:
            for attr, key in self.attr_to_key.items():
                param = get_expand_structure(attr[0], attr[1])
                inp = self.inpkey2idx[key]
                inp_list[inp] = param
                inp_count += 1
            for tensor, k in self.opt_param_dict.items():
                inp = self.inpkey2idx[k]
                inp_list[inp] = tensor
                inp_count += 1
        assert inp_count == self.input_num
        if self.has_randomstate:
            inp_list.append(self.random_seed)
        return inp_list

    def to_dlpack(self, x, take_ownership: bool = True):
        return xc._xla.buffer_to_dlpack_managed_tensor(x, take_ownership=take_ownership)

    def execute(self, *args, **kwargs):
        from ..traced_module.pytree import tree_flatten
        from ..tensor import Tensor
        from ..utils.module_utils import get_expand_structure

        inputs, _ = tree_flatten((args, kwargs))
        arrays = []
        cn = CompNode(get_default_device())
        stream = dict(self.xla_exec.backend.get_compute_compnode())
        device_kind, device_id, stream_id = cn.physical_locator

        xla_stream = stream[device_id]
        xla_comp_cn = "gpu{}:{}".format(device_id, xla_stream)
        for t in inputs:
            if isinstance(t, RawTensor):
                assert cn == t.device
                arrays.append(t.to(xla_comp_cn, _borrow=True))
        arrays = self.prepare_xla_inputs(arrays)
        outputs = self.xla_exec(*arrays)
        return_vals = []
        for i in self.out_list:
            if i == -1:
                if not hasattr(self, "outdef"):
                    return_vals.append(None)
            else:
                return_vals.append(outputs[self.outkey2idx[i]])
        keeped_features = []
        for i in self.keeped_activation:
            capsule = self.to_dlpack(outputs[self.outkey2idx[i]])
            t = from_dlpack(capsule, xla_stream).to(cn, _borrow=True)
            keeped_features.append(t)
        out_tensors = []
        for array in return_vals:
            if array is not None:
                capsule = self.to_dlpack(array)
                t = from_dlpack(capsule, xla_stream)
                out_tensors.append(t.to(cn, _borrow=True))
            else:
                out_tensors.append(array)
        if self.overall:
            for attr, key in self.update_param_dict.items():
                param = get_expand_structure(attr[0], attr[1])
                xla_array = outputs[self.outkey2idx[key]]
                capsule = self.to_dlpack(xla_array)
                param._reset(from_dlpack(capsule).to(cn, _borrow=True))

            for state, key in self.update_opt_param_dict.items():
                xla_array = outputs[self.outkey2idx[key]]
                capsule = self.to_dlpack(xla_array)
                state._reset(from_dlpack(capsule).to(cn, _borrow=True))
        rst = (
            self.outdef.unflatten(out_tensors)
            if hasattr(self, "outdef")
            else out_tensors
        )
        if keeped_features:
            return rst, keeped_features
        else:
            return rst
imperative/python/megengine/module/module.py

@@ -49,6 +49,8 @@ def _access_structure(obj, key, callback=None):
             cur = cur[k]
         else:
             cur = getattr(cur, k)
+    if callable is None:
+        return cur
     return callback(parent, k, cur)
imperative/python/src/graph_rt.cpp

@@ -115,11 +115,9 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
 }

 py::object Py_Varnode = py::none();
+const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
+        std::make_unique<mgb::OprFootprint>()};

 void init_graph_rt(py::module m) {
-    static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
-            std::make_unique<mgb::OprFootprint>()};
-
     def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
     def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
imperative/python/src/graph_rt.h

@@ -10,7 +10,7 @@
 namespace py = pybind11;

 extern py::object Py_Varnode;
+extern const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr;
 template <typename T>
 class GraphNodePtr {
     std::shared_ptr<mgb::cg::ComputingGraph> m_graph;
imperative/python/src/tensor.cpp
(diff collapsed in the page view; contents not shown)
imperative/python/src/tensor.h

@@ -74,6 +74,7 @@ public:
     inline ValueRef data() const { return m_data.unwrap(); }
     bool is_scalar() { return data().is_scalar(); }
     inline std::string name() { return m_name; }
+    inline size_t value_id() { return m_data.id(); }
     inline void set_name(std::string name) {
         m_name = name;
         if (!name.empty()) {
@@ -128,6 +129,7 @@ public:
     void reset(PyObject*);
     PyObject* detach();
     PyObject* isscalar();
+    PyObject* value_id();
     PyObject* _dev_tensor();
     void _drop();
     PyObject* varnode();
imperative/python/test/unit/jit/test_tracing.py

@@ -18,7 +18,13 @@ from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.utils import isscalar
 from megengine.functional import exp, log
-from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace
+from megengine.jit import (
+    GraphOptimizationConfig,
+    TraceError,
+    exclude_from_trace,
+    partial_trace,
+    trace,
+)
 from megengine.module import Module
 from megengine.random import normal, uniform
 from megengine.utils.naming import AutoNaming
@@ -803,3 +809,87 @@ def test_dump_without_output_error():
         str(e)
         == "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]"
     )
+
+
+@pytest.mark.parametrize("trace_mode", [False, True])
+def test_trace_without_host(trace_mode):
+    @trace(symbolic=trace_mode, without_host=True)
+    def fwd(a, b, c):
+        x = a + b
+        y = a + c
+        z = x * y
+        z1 = x / y
+        return [z, z1]
+
+    a = tensor([1.0])
+    b = tensor([2.0])
+    c = tensor([3.0])
+    rst = fwd(a, b, c)
+    for _ in range(2):
+        trace_rst = fwd(a, b, c)
+        np.testing.assert_equal(rst[0], trace_rst[0])
+        np.testing.assert_equal(rst[1], trace_rst[1])
+
+
+def test_trace_without_error():
+    const = tensor([8.0])
+
+    @trace(symbolic=False, without_host=True)
+    def fwd(a, b, c):
+        x = a + b
+        y = a + c
+        z = x * y
+        z1 = x / y + const
+        return [z, z1]
+
+    try:
+        a = tensor([1.0])
+        b = tensor([2.0])
+        c = tensor([3.0])
+        fwd(a, b, c)
+    except Exception as e:
+        assert str(e) == "have some unknown input tensors in trace result"
+    else:
+        assert False
+
+
+def test_partial_trace_fwd_bwd():
+    class Simple(Module):
+        def __init__(self):
+            super().__init__()
+            self.a = Parameter([1.0], dtype=np.float32)
+            self.b = Parameter([2.0], dtype=np.float32)
+
+        @partial_trace
+        def forward(self, x):
+            x = x * self.a + x / self.b
+            x = F.exp(x)
+            return x
+
+        def clear_grad(self):
+            self.a.grad = None
+            self.b.grad = None
+
+    @partial_trace
+    def fwd_only(a, b):
+        return a * b + a / b
+
+    m = Simple()
+    gm = GradManager()
+    gm.attach(m.parameters())
+
+    def func(x):
+        with gm:
+            x = x * 3
+            x = m(x)
+            x = x * 2
+            gm.backward(x)
+        a = m.a.grad
+        b = m.b.grad
+        m.clear_grad()
+        return fwd_only(a, b) + a + b
+
+    gt = func(tensor(1.0))
+    for _ in range(3):
+        out = func(tensor(1.0))
+        np.testing.assert_equal(gt.numpy(), out.numpy())
imperative/src/impl/basic_operators.cpp

@@ -105,6 +105,10 @@ std::string IsScalar::to_string() const {
     return "IsScalar";
 }

+std::string GetId::to_string() const {
+    return "GetId";
+}
+
 std::string GetFormat::to_string() const {
     return "GetFormat{}";
 }
imperative/src/impl/basic_values.cpp

@@ -15,6 +15,10 @@ std::string BoolValue::to_string() const {
     return (*this) ? "true" : "false";
 }

+std::string IntegerValue::to_string() const {
+    return std::to_string((int)*this);
+}
+
 std::string HostStorage::to_string() const {
     return ssprintf("HostStorage{device=%s}", comp_node().to_string().c_str());
 }
imperative/src/impl/op_def.cpp

@@ -142,6 +142,10 @@ const std::string OpDef::make_name() const {
     return m_scope + "." + trait()->make_name(*this);
 }

+const std::string OpDef::type_name() const {
+    return trait()->name;
+}
+
 static thread_local OpDef::allocator_t local_allocator;

 void OpDef::set_allocator(allocator_t allocator) {
imperative/src/impl/ops/opr_attr.cpp

@@ -13,6 +13,10 @@ namespace imperative {
 namespace {
 class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
+public:
+    bool strict = true;
+
+private:
     const OprAttr::Param& m_param;
     size_t m_pos = 0;
     ComputingGraph* m_graph;
@@ -40,7 +44,8 @@ public:
             m_graph(graph) {}

     ~OprParamsLoadContext() {
-        mgb_assert(m_pos == m_param.size(), "param not fully consumed");
+        if (strict)
+            mgb_assert(m_pos == m_param.size(), "param not fully consumed");
     }

     ComputingGraph& graph() override { return *m_graph; }
@@ -126,7 +131,9 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
     if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
         policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr);
     }
-    return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config());
+    return OprAttr::make(
+            registry->name, std::move(ctx.m_param), policy, opr->config(),
+            opr->dyn_typeinfo());
 }

 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
@@ -168,6 +175,12 @@ size_t OprAttr::hash() const {
             config.hash());
 }

+std::shared_ptr<json::Value> OprAttr::mgb_param(OprFootprint* footprint) {
+    OprParamsLoadContext ctx{param, nullptr};
+    ctx.strict = false;
+    return footprint->get_serial_param_json(mgb_opr_type, ctx);
+};
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr);

 }  // namespace imperative
imperative/src/impl/transformations/eval.cpp

@@ -130,6 +130,11 @@ ValueRefList InterpreterTransformation::apply_transformation(
         } else {
             return {ValueRef()};
         }
+    } else if (op.is<GetId>()) {
+        auto& val = inputs[0].cast(m_value_type);
+        int64_t id = val.id();
+        return {IntegerValue::make(id)};
     } else if (op.is<DupTensor>()) {
         auto& input = inputs[0].cast(m_value_type);
         DeviceTensorND dev_tensor;
imperative/src/impl/transformations/grad.cpp

@@ -7,6 +7,7 @@
 #include "megbrain/imperative/profiler.h"
 #include "megbrain/imperative/resource_manager.h"

+#include <range/v3/all.hpp>
 namespace mgb {
 namespace imperative {
@@ -226,7 +227,7 @@ void GradKey::backward() {
             if constexpr (std::is_same_v<T, std::monostate>) {
                 mgb_throw(AssertionError, "invalid backward");
             } else {
-                mgb_assert(grad_fn->m_slots.size() > 0);
+                // mgb_assert(grad_fn->m_slots.size() > 0);
                 SmallVector<ValueRef> grads(grad_fn->m_slots.size());
                 auto iter = grads.begin();
                 for (auto&& slot : grad_fn->m_slots) {
@@ -419,6 +420,23 @@ ValueRefList GradTransformation::apply_transformation(
         mgb_assert(!grad_fn->m_slots.empty());
         m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
         return outputs;
+    } else if (auto* igc = op.as<InsertGradCallback>()) {
+        auto grad_fn = LocalPtr<GradFn>::make();
+        auto& backward =
+                std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
+        auto id = inputs[0];
+        backward.m_backward = [id, callback = igc->callback()](
+                                      Span<ValueRef> inputs) -> SmallVector<ValueRef> {
+            callback({&id, (size_t)1});
+            return {};
+        };
+        m_key->m_side_effects.push_back(grad_fn);
+        m_key->m_tape.push_back({grad_fn, nullptr});
+        auto next_id = IntegerValue::make((int)id.cast<IntegerValue>() + 1);
+        auto prev_count =
+                imperative::apply(InsertGradCallback(igc->callback()), next_id)[0];
+        auto count = IntegerValue::make((int)prev_count.cast<IntegerValue>() + 1);
+        return {count};
     } else if (op.is<CreateTensor>()) {
         return imperative::apply(op, inputs);
     } else if (auto* attach_grad = op.as<AttachGrad>()) {
@@ -514,6 +532,13 @@ ValueRefList GradTransformation::apply_transformation(
             }
         }
         return imperative::apply(op, inputs);
+    } else if (op.is<GetGradSlot>()) {
+        mgb_assert(inputs.size() == 1);
+        if (auto&& grad_value = as_grad_value(inputs[0])) {
+            return {GradSlotValue::make(grad_value->slot())};
+        } else {
+            return {};
+        }
     } else if (op.kind() == Operator::IdentityLike) {
         mgb_assert(inputs.size() == 1);
         if (auto&& grad_value = as_grad_value(inputs[0])) {
imperative/src/impl/transformations/lazy.cpp

@@ -53,6 +53,9 @@ ValueRefList LazyEvalTransformation::apply_transformation(
             outputs[i] = record_var(output_nodes[i]);
         }
         return outputs;
+    } else if (op.is<GetId>()) {
+        int64_t id = inputs[0].id();
+        return {IntegerValue::make(id)};
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
         auto&& args = create_tensor->parse(inputs);
         auto get_dev_val = [&] {
imperative/src/impl/transformations/trace.cpp
浏览文件 @
b11d4430
...
@@ -83,7 +83,7 @@ VarNodeArray TraceResult::dump(
...
@@ -83,7 +83,7 @@ VarNodeArray TraceResult::dump(
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
cg
::
OperatorNodeBase
*>>
name2ops
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
cg
::
OperatorNodeBase
*>>
name2ops
;
// iterate over opr_seq
// iterate over opr_seq
for
(
auto
&&
item
:
seq
)
{
for
(
auto
&&
item
:
seq
)
{
auto
&&
[
op
,
inputs
,
outputs
]
=
item
;
auto
&&
[
op
,
inputs
,
outputs
,
type
]
=
item
;
VarNodeArray
input_nodes
;
VarNodeArray
input_nodes
;
for
(
auto
&&
input
:
inputs
)
{
for
(
auto
&&
input
:
inputs
)
{
auto
&
node
=
nodes
[
input
];
auto
&
node
=
nodes
[
input
];
...
@@ -207,7 +207,8 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -207,7 +207,8 @@ ValueRefList TracingTransformation::apply_transformation(
auto
wrapped_output
=
record_var
(
outputs
[
0
],
as_const
,
VarKind
::
Internal
);
auto
wrapped_output
=
record_var
(
outputs
[
0
],
as_const
,
VarKind
::
Internal
);
auto
input_id
=
wrapped_input
->
id
();
auto
input_id
=
wrapped_input
->
id
();
auto
output_id
=
wrapped_output
->
id
();
auto
output_id
=
wrapped_output
->
id
();
m_seq
.
push_back
({{},
{
input_id
},
{
output_id
}});
m_seq
.
push_back
({{},
{
input_id
},
{
output_id
},
OpKind
::
CreateTensor
});
return
{
wrapped_output
};
return
{
wrapped_output
};
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
auto
unwrapped_input
=
unwrap_var
(
inputs
[
0
]);
auto
unwrapped_input
=
unwrap_var
(
inputs
[
0
]);
...
@@ -246,7 +247,30 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -246,7 +247,30 @@ ValueRefList TracingTransformation::apply_transformation(
}
}
auto
output
=
record_var
(
input
,
false
,
VarKind
::
Internal
);
auto
output
=
record_var
(
input
,
false
,
VarKind
::
Internal
);
m_vars
[
output
->
id
()].
mark
=
trace_mark_var
->
mark
();
m_vars
[
output
->
id
()].
mark
=
trace_mark_var
->
mark
();
m_seq
.
push_back
({{},
{
tracing_var
->
id
()},
{
output
->
id
()}});
m_seq
.
push_back
(
{{},
{
tracing_var
->
id
()},
{
output
->
id
()},
OpKind
::
TraceMarkVar
});
return
{
output
};
}
else
if
(
auto
*
iomarker
=
op
.
as
<
IOMarkVar
>
())
{
mgb_assert
(
inputs
.
size
()
==
1
,
"IOMarkVar expects exactly one input"
);
auto
input
=
inputs
[
0
];
auto
tracing_var
=
input
.
as_ref
(
m_value_type
);
if
(
!
tracing_var
)
{
bool
is_input
=
iomarker
->
kind
()
==
IOMarkVar
::
Kind
::
Input
;
if
(
is_input
)
{
tracing_var
=
record_var
(
input
,
false
,
VarKind
::
External
);
}
else
{
tracing_var
=
record_var
(
input
,
m_capture_as_const
,
VarKind
::
External
);
}
}
else
{
input
=
tracing_var
->
value
();
}
auto
output
=
record_var
(
input
,
false
,
VarKind
::
Internal
);
if
(
iomarker
->
kind
()
==
IOMarkVar
::
Kind
::
Input
)
m_vars
[
tracing_var
->
id
()].
inp_marker
.
insert
(
iomarker
->
mark
());
else
m_vars
[
output
->
id
()].
out_marker
.
insert
(
iomarker
->
mark
());
m_seq
.
push_back
({{},
{
tracing_var
->
id
()},
{
output
->
id
()},
OpKind
::
IOMarkVar
});
return
{
output
};
return
{
output
};
}
else
if
(
auto
*
trace_name_var
=
op
.
as
<
RenameValue
>
())
{
}
else
if
(
auto
*
trace_name_var
=
op
.
as
<
RenameValue
>
())
{
mgb_assert
(
inputs
.
size
()
==
1
,
"RenameValue expects exactly one input"
);
mgb_assert
(
inputs
.
size
()
==
1
,
"RenameValue expects exactly one input"
);
...
@@ -259,7 +283,7 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -259,7 +283,7 @@ ValueRefList TracingTransformation::apply_transformation(
}
}
auto
output
=
record_var
(
input
,
false
,
VarKind
::
Internal
);
auto
output
=
record_var
(
input
,
false
,
VarKind
::
Internal
);
m_vars
[
output
->
id
()].
name
=
trace_name_var
->
name
();
m_vars
[
output
->
id
()].
name
=
trace_name_var
->
name
();
m_seq
.
push_back
({{},
{
tracing_var
->
id
()},
{
output
->
id
()}});
m_seq
.
push_back
({{},
{
tracing_var
->
id
()},
{
output
->
id
()}
,
OpKind
::
Rename
});
return
{
output
};
return
{
output
};
}
else
if
(
op
.
is
<
GetName
>
())
{
}
else
if
(
op
.
is
<
GetName
>
())
{
mgb_assert
(
inputs
.
size
()
==
1
,
"GetName expects exactly one input"
);
mgb_assert
(
inputs
.
size
()
==
1
,
"GetName expects exactly one input"
);
...
@@ -279,6 +303,78 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -279,6 +303,78 @@ ValueRefList TracingTransformation::apply_transformation(
}
}
}
}
void
TracingTransformation
::
postprocess_trace_result
()
{
std
::
unordered_map
<
size_t
,
size_t
>
identity_oi_map
,
identity_io_map
;
for
(
auto
&&
op
:
m_seq
)
{
if
(
op
.
op
==
nullptr
&&
op
.
inputs
.
size
()
==
1
&&
op
.
outputs
.
size
()
==
1
)
{
identity_oi_map
[
op
.
outputs
[
0
]]
=
op
.
inputs
[
0
];
identity_io_map
[
op
.
inputs
[
0
]]
=
op
.
outputs
[
0
];
}
}
for
(
auto
&&
op
:
m_seq
)
{
if
(
op
.
kind
==
OpKind
::
IOMarkVar
)
{
auto
&&
inpvar
=
m_vars
[
op
.
inputs
[
0
]];
auto
&&
outvar
=
m_vars
[
op
.
outputs
[
0
]];
if
(
inpvar
.
inp_marker
.
size
()
>
0
)
{
auto
id
=
inpvar
.
id
;
if
(
inpvar
.
kind
!=
VarKind
::
External
)
{
while
(
identity_oi_map
.
find
(
id
)
!=
identity_oi_map
.
end
())
{
id
=
identity_oi_map
[
id
];
}
if
(
m_vars
[
id
].
kind
==
VarKind
::
External
)
{
for
(
auto
mark
:
inpvar
.
inp_marker
)
{
mgb_assert
(
inpmark_to_id
.
find
(
mark
)
==
inpmark_to_id
.
end
()
||
inpmark_to_id
[
mark
]
==
id
,
"two nodes have same mark"
);
inpmark_to_id
[
mark
]
=
id
;
m_vars
[
id
].
inp_marker
.
insert
(
mark
);
}
inpvar
.
inp_marker
.
clear
();
}
}
else
{
for
(
auto
mark
:
inpvar
.
inp_marker
)
{
mgb_assert
(
inpmark_to_id
.
find
(
mark
)
==
inpmark_to_id
.
end
()
||
inpmark_to_id
[
mark
]
==
id
,
"two nodes have same mark"
);
inpmark_to_id
[
mark
]
=
id
;
}
}
}
else
{
mgb_assert
(
outvar
.
out_marker
.
size
()
>
0
);
auto
id
=
outvar
.
id
;
if
(
!
outvar
.
data_required
)
{
while
(
identity_io_map
.
find
(
id
)
!=
identity_io_map
.
end
())
{
id
=
identity_io_map
[
id
];
}
if
(
m_vars
[
id
].
data_required
)
{
for
(
auto
mark
:
outvar
.
out_marker
)
{
mgb_assert
(
outmark_to_id
.
find
(
mark
)
==
outmark_to_id
.
end
()
||
outmark_to_id
[
mark
]
==
id
,
"two nodes have same mark"
);
outmark_to_id
[
mark
]
=
id
;
m_vars
[
id
].
out_marker
.
insert
(
mark
);
}
outvar
.
out_marker
.
clear
();
}
}
else
{
for
(
auto
mark
:
outvar
.
out_marker
)
{
mgb_assert
(
outmark_to_id
.
find
(
mark
)
==
outmark_to_id
.
end
()
||
outmark_to_id
[
mark
]
==
id
,
"two nodes have same mark"
);
outmark_to_id
[
mark
]
=
id
;
}
}
}
}
}
}
 void TracingTransformation::on_unregister() noexcept {
     for (auto&& weak_var : m_weak_vars) {
         if (auto tracing_value = weak_var.lock()) {
...
@@ -526,7 +622,10 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
                     var.device->to_string().c_str(),
                     device.to_string().c_str());
         }
-        var_accessor.data_setter(value.dev_tensor()->as_nd());
+        if (m_setted_extern.find(id) == m_setted_extern.end()) {
+            var_accessor.data_setter(value.dev_tensor()->as_nd());
+            m_setted_extern.insert(id);
+        }
         break;
     }
     case VarKind::Constant: {
...
@@ -732,6 +831,7 @@ void CompiledTransformation::wait() {
     m_pc = 0;
     std::exception_ptr graph_exc;
     std::swap(m_graph_exc, graph_exc);
+    m_setted_extern.clear();
     if (graph_exc) {
         // graph with exception cannot be reused
         recompile();
...
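The guard around data_setter means an external input is uploaded at most once per replay, and wait() clears the bookkeeping so the next replay uploads again. A standalone sketch of that pattern (illustrative, not the MegEngine API):

#include <cstdio>
#include <set>

static std::set<size_t> setted_extern;  // ids whose data has been written this replay

void set_data_once(size_t id) {
    if (setted_extern.find(id) == setted_extern.end()) {
        std::printf("uploading data for var %zu\n", id);
        setted_extern.insert(id);
    }
}

int main() {
    set_data_once(7);
    set_data_once(7);       // skipped: already uploaded in this replay
    setted_extern.clear();  // what wait() now does between replays
    set_data_once(7);       // uploaded again for the next replay
    return 0;
}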
imperative/src/impl/value.cpp  View file @ b11d4430
@@ -127,6 +127,10 @@ bool ValueRef::watching() const {
     return this->storage()->m_watching;
 }
 
+int ValueRef::handle_id() const {
+    return imperative::apply(GetId(), *this)[0].cast<IntegerValue>();
+}
+
 ValueRef ValueRef::make(ValueRef::storage_t storage) {
     if (recording_values) {
         recorded_values.push_back({storage});
...
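handle_id() is answered by dispatching a GetId query down the transformation stack; whichever layer can identify the backing interpreter handle replies, and GetId::fallback covers the rest. A standalone sketch of that dispatch shape (illustrative, not the MegEngine API):

#include <cstdio>
#include <functional>
#include <optional>
#include <vector>

using Layer = std::function<std::optional<int>()>;

int dispatch_get_id(const std::vector<Layer>& layers) {
    for (auto&& layer : layers) {
        if (auto answer = layer()) {
            return *answer;  // first layer that knows the handle wins
        }
    }
    return -1;  // fallback value, much like GetId::fallback returning an empty ValueRef
}

int main() {
    std::vector<Layer> layers{
            [] { return std::optional<int>{}; },    // e.g. a guard layer: no answer
            [] { return std::optional<int>{42}; },  // e.g. the eval layer: knows the id
    };
    std::printf("handle id = %d\n", dispatch_get_id(layers));
    return 0;
}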
imperative/src/include/megbrain/imperative/basic_operators.h  View file @ b11d4430
@@ -141,6 +141,14 @@ public:
     ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
 };
 
+class GetId final : public OperatorImpl<GetId, Operator::GetAttrLike> {
+public:
+    std::string to_string() const override;
+    std::string raw_type() const { return "GetId"; }
+
+    ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
+};
+
 /**
  * \brief return a value with new name
  *
...
imperative/src/include/megbrain/imperative/basic_values.h  View file @ b11d4430
@@ -48,6 +48,25 @@ public:
     std::string to_string() const override;
 };
 
+class Integer {
+private:
+    int64_t m_value;
+
+public:
+    Integer() = default;
+    Integer(int64_t value) : m_value(value) {}
+
+    operator int64_t() const { return m_value; }
+};
+
+// TODO: override factory method
+class IntegerValue final : public PrimitiveValue<IntegerValue, Integer> {
+public:
+    using PrimitiveValue::PrimitiveValue;
+
+    std::string to_string() const override;
+};
+
 class HostStorage final : public PrimitiveValue<HostStorage, HostTensorStorage> {
 public:
     using PrimitiveValue::PrimitiveValue;
...
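Integer is a thin wrapper so a plain int64_t can ride through the type-erased value system as IntegerValue, while the implicit conversion keeps call sites such as cast<IntegerValue>() readable. A standalone sketch of the wrapper pattern (illustrative, not the MegEngine API):

#include <cstdint>
#include <cstdio>

class Integer {
    int64_t m_value = 0;

public:
    Integer() = default;
    Integer(int64_t value) : m_value(value) {}
    operator int64_t() const { return m_value; }  // reads back as a plain integer
};

int main() {
    Integer handle{42};
    int64_t raw = handle;  // implicit conversion, as callers of cast<IntegerValue>() rely on
    std::printf("%lld\n", static_cast<long long>(raw));
    return 0;
}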
imperative/src/include/megbrain/imperative/op_def.h  View file @ b11d4430
@@ -80,6 +80,8 @@ public:
     const std::string make_name() const;
 
+    virtual const std::string type_name() const;
+
     void set_scope(const std::string& scope);
 
     virtual size_t hash() const;
...
imperative/src/include/megbrain/imperative/ops/opr_attr.h  View file @ b11d4430
@@ -2,6 +2,7 @@
 #include "megbrain/imperative/op_def.h"
 #include "megbrain/opr/param_defs.h"
+#include "megbrain/plugin/opr_footprint.h"
 
 namespace mgb {
 namespace imperative {
...
@@ -28,6 +29,7 @@ public:
     Type type;
     Param param;
+    Typeinfo* mgb_opr_type;
     megdnn::param::ExecutionPolicy policy;
     cg::OperatorNodeConfig config;
...
@@ -36,13 +38,14 @@ public:
     OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c)
             : type(t), param(p), config(c) {}
     OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps,
-            const cg::OperatorNodeConfig& c)
-            : type(t), param(p), policy(ps), config(c) {}
+            const cg::OperatorNodeConfig& c, Typeinfo* optype)
+            : type(t), param(p), policy(ps), config(c), mgb_opr_type(optype) {}
 
     std::string repr() const;
+    std::shared_ptr<json::Value> mgb_param(OprFootprint*);
 
     bool is_same_st(const Hashable& rhs) const override;
     size_t hash() const override;
+    const std::string type_name() const override { return type; }
 };
 
 }  // namespace imperative
...
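Together with the new virtual OpDef::type_name(), OprAttr can now report the concrete megbrain operator type it wraps instead of a generic name. A standalone sketch of that virtual-override shape (illustrative names, not the MegEngine API):

#include <cstdio>
#include <string>

struct OpDefLike {
    virtual ~OpDefLike() = default;
    virtual std::string type_name() const { return "OpDef"; }
};

struct OprAttrLike : OpDefLike {
    std::string type;  // e.g. the wrapped operator's type string (hypothetical value below)
    explicit OprAttrLike(std::string t) : type(std::move(t)) {}
    std::string type_name() const override { return type; }
};

int main() {
    OprAttrLike op{"ConvolutionForward"};
    const OpDefLike& base = op;
    std::printf("%s\n", base.type_name().c_str());  // prints the wrapped type via the base pointer
    return 0;
}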
imperative/src/include/megbrain/imperative/transformations/grad.h  View file @ b11d4430
@@ -107,7 +107,7 @@ private:
 public:
     std::string to_string() const;
+    ValueRef grad() const { return m_grad; }
 
     friend class GradKey;
     friend class GradSlotProducerPtr;
     friend class GradTransformation;
...
@@ -224,6 +224,7 @@ public:
 class GradKey : public std::enable_shared_from_this<GradKey> {
 private:
     std::string m_name;
+    std::vector<LocalPtr<GradFn>> m_side_effects;
     std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
     std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
     bool m_frozen = false;
...
@@ -253,6 +254,13 @@ public:
     }
 };
 
+class GradSlotValue final : public PrimitiveValue<GradSlotValue, GradSlotPtr> {
+public:
+    using PrimitiveValue::PrimitiveValue;
+
+    std::string to_string() const override { return ssprintf("GradSlot{}"); }
+};
+
 class GradTransformation final : public Transformation {
 private:
     ObjectType<GradValue> m_value_type{"GradValue"};
...
@@ -404,6 +412,28 @@ public:
     ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
 };
 
+class GetGradSlot : public OperatorImpl<GetGradSlot, Operator::GetAttrLike> {
+public:
+    GetGradSlot() = default;
+
+    std::string to_string() const override { return ssprintf("GetGradSlot{}"); }
+    std::string raw_type() const { return "GetGradSlot"; };
+
+    ValueRefList fallback(Span<ValueRef> inputs) const override { return {}; }
+};
+
+class InsertGradCallback : public OperatorImpl<InsertGradCallback, Operator::Other> {
+public:
+    GenericFunction m_callback;
+
+public:
+    InsertGradCallback(GenericFunction callback) : m_callback(callback) {}
+
+    GenericFunction callback() const { return m_callback; }
+
+    std::string to_string() const override { return ssprintf("InsertGradCallback{}"); }
+    std::string raw_type() const { return "InsertGradCallback"; }
+};
+
 class GetBackwardColsure
         : public OperatorImpl<GetBackwardColsure, Operator::GetAttrLike> {
 private:
...
@@ -420,4 +450,19 @@ public:
     std::string raw_type() const { return "GetBackwardClosure"; }
 };
 
+class GradTransformationGuard final : public Transformation {
+    ValueRefList apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override {
+        if (auto* igc = op.as<InsertGradCallback>()) {
+            auto count = IntegerValue::make(0);
+            return {count};
+        }
+        return imperative::apply(op, inputs);
+    }
+
+    ValueRef unwrap(ValueRef value) override { return value; };
+
+    std::string name() const override { return "GradTransformationGuard"; };
+};
+
 }  // namespace mgb::imperative
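GradTransformationGuard shows the intended dispatch contract: one operator kind (InsertGradCallback) is intercepted and answered locally, everything else is forwarded unchanged to the next transformation. A standalone sketch of that guard shape (illustrative, not the MegEngine API):

#include <cstdio>
#include <functional>
#include <string>

struct Request {
    std::string kind;
};

using Next = std::function<int(const Request&)>;

int guard(const Request& req, const Next& next) {
    if (req.kind == "InsertGradCallback") {
        return 0;  // swallow the request, like returning IntegerValue::make(0)
    }
    return next(req);  // forward, like imperative::apply(op, inputs)
}

int main() {
    Next next = [](const Request&) { return 7; };
    std::printf("%d %d\n", guard({"InsertGradCallback"}, next), guard({"ApplyOp"}, next));
    return 0;
}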
imperative/src/include/megbrain/imperative/transformations/trace.h  View file @ b11d4430
@@ -2,8 +2,8 @@
 #include <chrono>
 #include <future>
+#include <set>
 #include <variant>
 
 #include "megbrain/gopt/inference.h"
 #include "megbrain/imperative/dispatch.h"
 #include "megbrain/imperative/interpreter.h"
...
@@ -17,11 +17,21 @@ namespace mgb::imperative {
 struct TraceResult {
     struct SeqItem {
+        enum OpKind {
+            Unknown,
+            TraceMarkVar,
+            Rename,
+            IOMarkVar,
+            CreateTensor,
+        };
         std::shared_ptr<OpDef> op;
         SmallVector<size_t> inputs;
         SmallVector<size_t> outputs;
+        OpKind kind = OpKind::Unknown;
     };
 
+    using OpKind = SeqItem::OpKind;
+
     struct VarInfo {
         enum Kind {
             External,  // End point of traced graph, its value is received from
...
@@ -41,12 +51,14 @@ struct TraceResult {
         ValueRef bound_data;
         std::string mark;
         std::string name;
+        int handle_id;
 
         Kind kind;
         bool value_required = false;
         bool data_required = false;
         bool shape_required = false;
+        std::set<size_t> inp_marker;
+        std::set<size_t> out_marker;
 
         TensorShape shape;
     };
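Each var can now carry input/output marker ids, and the tracer keeps global inpmark_to_id / outmark_to_id tables with the invariant that a marker names exactly one var ("two nodes have same mark" otherwise). A standalone sketch of that registration rule (illustrative, not the MegEngine API):

#include <cassert>
#include <set>
#include <unordered_map>

static std::unordered_map<size_t, size_t> inpmark_to_id;  // marker -> var id

void register_input_mark(size_t mark, size_t var_id, std::set<size_t>& inp_marker) {
    // re-registering the same (mark, var) pair is fine; a different var is rejected
    assert(inpmark_to_id.find(mark) == inpmark_to_id.end() || inpmark_to_id[mark] == var_id);
    inpmark_to_id[mark] = var_id;
    inp_marker.insert(mark);
}

int main() {
    std::set<size_t> markers_of_var3;
    register_input_mark(/*mark=*/0, /*var_id=*/3, markers_of_var3);
    register_input_mark(/*mark=*/0, /*var_id=*/3, markers_of_var3);  // same pair: allowed
    return 0;
}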
...
@@ -91,6 +103,27 @@ public:
     std::string raw_type() const { return "TraceMarkVar"; }
 };
 
+class IOMarkVar : public OperatorImpl<IOMarkVar, Operator::IdentityLike> {
+public:
+    enum Kind {
+        Input,
+        Output,
+    };
+
+private:
+    size_t m_mark;
+    Kind m_kind;
+
+public:
+    IOMarkVar(size_t mark, Kind kind) : m_mark(mark), m_kind(kind) {}
+
+    size_t mark() const { return m_mark; }
+    Kind kind() const { return m_kind; }
+
+    std::string to_string() const override { return ssprintf("IOMarkVar"); }
+    std::string raw_type() const override { return "IOMarkVar"; }
+};
+
 class TracingValue final : public ObjectValue<TracingValue> {
 private:
     ValueRef m_value = {};
...
@@ -125,15 +158,22 @@ class TracingTransformation final : public Transformation {
 public:
     using VarInfo = TraceResult::VarInfo;
     using VarKind = VarInfo::Kind;
+    using OpKind = TraceResult::SeqItem::OpKind;
 
 private:
     std::vector<TraceResult::SeqItem> m_seq;
     std::vector<TraceResult::VarInfo> m_vars;
     std::vector<TracingValue::weak_ref_t> m_weak_vars;
+    std::unordered_map<size_t, size_t> extern_var_to_id;
     bool m_capture_as_const = false;
     bool m_record_input_shapes = false;
+    bool m_record_all_shapes = false;
     ObjectType<TracingValue> m_value_type{"TracingValue"};
 
+public:
+    std::unordered_map<size_t, size_t> inpmark_to_id;
+    std::unordered_map<size_t, size_t> outmark_to_id;
+
 public:
     TracingTransformation(bool capture_as_const, bool record_input_shapes)
             : m_capture_as_const(capture_as_const),
...
@@ -148,7 +188,14 @@ public:
      * \return TypedValueRef<TracingValue> traced value
      */
     TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) {
+        if (kind == VarKind::External &&
+            extern_var_to_id.find(value.id()) != extern_var_to_id.end()) {
+            return m_value_type.make(value, extern_var_to_id[value.id()]);
+        }
         size_t id = m_vars.size();
+        if (kind == VarKind::External) {
+            extern_var_to_id[value.id()] = id;
+        }
         auto wrapped_value = m_value_type.make(value, id);
         m_vars.push_back({id, value.dtype(), value.device()});
         auto& var = m_vars.back();
...
@@ -156,9 +203,12 @@ public:
             var.bound_data = value;
         }
         var.kind = kind;
-        if (m_record_input_shapes && kind != VarKind::Internal) {
+        if ((m_record_input_shapes && kind != VarKind::Internal) || m_record_all_shapes) {
             var.shape = value.shape()->as_tensor_shape();
         }
+        if (m_record_all_shapes) var.handle_id = value.handle_id();
         if (auto name = value.name()) {
             var.name = *name;
         }
...
@@ -185,8 +235,9 @@ public:
     std::string name() const override { return "TracingTransformation"; }
 
     void on_unregister() noexcept override;
+    void postprocess_trace_result();
 
     TraceResult get_result() { return {m_seq, m_vars}; }
+    void enable_record_all_shapes() { m_record_all_shapes = true; }
 };
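record_var now deduplicates external inputs: the first sighting of an external value allocates a var slot and remembers it in extern_var_to_id, and later sightings reuse that slot instead of recording a duplicate. A standalone sketch of the pattern (illustrative, not the MegEngine API):

#include <cstdio>
#include <unordered_map>
#include <vector>

static std::unordered_map<size_t, size_t> extern_var_to_id;  // value id -> var slot
static std::vector<size_t> vars;                             // recorded var slots

size_t record_external(size_t value_id) {
    auto it = extern_var_to_id.find(value_id);
    if (it != extern_var_to_id.end()) {
        return it->second;  // reuse the slot recorded for this value earlier
    }
    size_t id = vars.size();
    extern_var_to_id[value_id] = id;
    vars.push_back(id);
    return id;
}

int main() {
    size_t a = record_external(100);  // first sighting: slot 0
    size_t b = record_external(200);  // new value: slot 1
    size_t c = record_external(100);  // seen before: reuses slot 0
    std::printf("%zu %zu %zu\n", a, b, c);  // prints: 0 1 0
    return 0;
}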
 class TraceError : public std::exception {
...
@@ -211,6 +262,7 @@ class CompiledTransformation final : public Transformation {
 public:
     using VarInfo = TraceResult::VarInfo;
     using VarKind = VarInfo::Kind;
+    using OpKind = TraceResult::SeqItem::OpKind;
 
     struct VarAccessor {
         VarNode* node;
...
@@ -254,6 +306,7 @@ private:
     std::vector<TraceResult::SeqItem> m_seq;
     std::vector<TraceResult::VarInfo> m_vars;
     std::vector<VarAccessor> m_var_accessors;
+    std::unordered_map<std::string, size_t> mark2id;
     size_t m_pc = 0;
     std::shared_ptr<ComputingGraph> m_graph;
     std::unique_ptr<cg::AsyncExecutable> m_executable;
...
@@ -268,6 +321,7 @@ private:
     std::vector<std::shared_ptr<BoxBase>> m_boxes;
     ComputingGraph::OutputSpec m_output_spec;
     ObjectType<TracedValue> m_value_type{"TracedValue"};
+    std::set<size_t> m_setted_extern;
 
 public:
     CompiledTransformation(TraceResult result, bool input_shape_static)
...
@@ -360,8 +414,10 @@ public:
         return value;
     }
 
+    VarAccessor& get_accessor_by_id(size_t id) { return m_var_accessors[id]; }
+
     std::string name() const override { return "CompiledTransformation"; }
 
+    void set_pc_to_end() { m_pc = m_seq.size(); }
     void execute();
     void wait();
...
imperative/src/include/megbrain/imperative/value.h  View file @ b11d4430
@@ -222,6 +222,7 @@ public:
     TypedValueRef<DTypeValue> dtype() const;
     TypedValueRef<FormatValue> format() const;
     TypedValueRef<StringValue> name() const;
+    int handle_id() const;
     bool is_scalar() const;
 
     void watch() const;
...
@@ -298,7 +299,7 @@ protected:
 public:
     const IType& type() const { return *m_type; }
+    uint64_t id() const { return m_id; }
 
     static void register_value(ValueRef value);
     static ValueRef get_value_by_id(uint64_t id);
     static void begin_record_values();
...
@@ -538,11 +539,11 @@ public:
     const ValueRef* data() const { return m_data; }
     bool empty() const { return m_size == 0; }
     ValueRef& front() {
-        mgb_assert(m_size > 1);
+        mgb_assert(m_size >= 1);
         return m_data[0];
     }
     ValueRef& back() {
-        mgb_assert(m_size > 1);
+        mgb_assert(m_size >= 1);
         return m_data[m_size - 1];
     }
 };
...
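The front()/back() asserts were off by one: "m_size > 1" rejected a perfectly valid single-element list, while ">= 1" only rules out the empty case. A standalone sketch of the corrected invariant (illustrative, not the MegEngine API):

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> one_element{42};
    assert(one_element.size() >= 1);  // the old "> 1" style check would have fired here
    std::printf("%d %d\n", one_element.front(), one_element.back());
    return 0;
}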