Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d2b67c2a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d2b67c2a
编写于
1月 14, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dispatch): implement trace
GitOrigin-RevId: f8d3005732dad0f941d963e8e529f1c11d2d3ca5
上级
39ac606b
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
1027 addition
and
0 deletion
+1027
-0
imperative/src/impl/transformations/trace.cpp
imperative/src/impl/transformations/trace.cpp
+679
-0
imperative/src/include/megbrain/imperative/transformations/trace.h
...e/src/include/megbrain/imperative/transformations/trace.h
+348
-0
未找到文件。
imperative/src/impl/transformations/trace.cpp
0 → 100644
浏览文件 @
d2b67c2a
/**
* \file imperative/src/impl/transformations/trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/transformations/trace.h"
#include <chrono>
#include <exception>
#include "megbrain/gopt/inference.h"
#include "megbrain/graph/helper.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/serializer.h"
#include "../event_pool.h"
#define trace_assert(_cond, _msg...) \
do { \
if (mgb_unlikely(!(_cond))) { \
auto exc = std::make_exception_ptr(TraceError(ssprintf(_msg))); \
set_exception(exc); \
std::rethrow_exception(exc); \
} \
} while (0)
namespace
mgb
{
namespace
imperative
{
VarNodeArray
TraceResult
::
dump
(
ComputingGraph
&
graph
,
std
::
vector
<
std
::
tuple
<
size_t
,
std
::
string
,
TensorShape
>>
inputs
,
std
::
vector
<
std
::
pair
<
size_t
,
std
::
string
>>
outputs
,
bool
prefer_input_names
)
{
// var -> VarNode
std
::
vector
<
VarNode
*>
nodes
(
vars
.
size
(),
nullptr
);
// make h2d node for each input
for
(
auto
&&
[
input
,
name
,
shape
]
:
inputs
)
{
auto
&
var
=
vars
[
input
];
auto
&
node
=
nodes
[
input
];
// TODO: cambricon CompNode
auto
host
=
std
::
make_shared
<
HostTensorND
>
(
CompNode
::
load
(
"xpux"
),
shape
,
var
.
dtype
);
OperatorNodeConfig
config
;
// if prefer_input_names, prefer names from dump args
// else prefer names got from trace procedure
if
(
prefer_input_names
&&
!
name
.
empty
())
{
config
.
name
(
name
);
}
else
if
(
!
var
.
name
.
empty
())
{
config
.
name
(
var
.
name
);
}
else
if
(
!
name
.
empty
())
{
config
.
name
(
name
);
}
node
=
opr
::
Host2DeviceCopy
::
make
(
graph
,
host
,
{},
config
).
node
();
}
// make const node for each constant
for
(
size_t
i
=
0
;
i
<
vars
.
size
();
++
i
)
{
auto
&
var
=
vars
[
i
];
auto
&
node
=
nodes
[
i
];
if
(
!
node
)
{
if
(
var
.
kind
!=
VarKind
::
Internal
)
{
if
(
!
var
.
bound_data
)
{
continue
;
}
if
(
!
var
.
name
.
empty
())
{
node
=
opr
::
ImmutableTensor
::
make
(
graph
,
var
.
bound_data
.
numpy
()
->
as_nd
(),
{
var
.
name
})
.
node
();
}
else
{
node
=
opr
::
ImmutableTensor
::
make
(
graph
,
var
.
bound_data
.
numpy
()
->
as_nd
())
.
node
();
}
}
}
}
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
cg
::
OperatorNodeBase
*>>
name2ops
;
// iterate over opr_seq
for
(
auto
&&
item
:
seq
)
{
auto
&&
[
op
,
inputs
,
outputs
]
=
item
;
VarNodeArray
input_nodes
;
for
(
auto
&&
input
:
inputs
)
{
auto
&
node
=
nodes
[
input
];
input_nodes
.
push_back
(
node
);
}
VarNodeArray
output_nodes
;
if
(
op
)
{
if
(
auto
*
bn
=
op
->
try_cast_final
<
BatchNorm
>
())
{
mgb_assert
(
bn
->
fwd_mode
==
BatchNorm
::
FwdMode
::
INFERENCE
,
"can not dump BatchNorm in training mode, maybe you forget to "
"do model.eval()?"
);
}
output_nodes
=
OpDef
::
apply_on_var_node
(
*
op
,
input_nodes
);
name2ops
[
output_nodes
[
0
]
->
owner_opr
()
->
name
()].
push_back
(
output_nodes
[
0
]
->
owner_opr
());
}
else
{
// no opr, just forward VarNode
mgb_assert
(
inputs
.
size
()
==
outputs
.
size
(),
"output size not equals to input size when forwarding"
);
output_nodes
=
input_nodes
;
}
mgb_assert
(
output_nodes
.
size
()
==
outputs
.
size
(),
"output size mismatch"
);
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
auto
output
=
outputs
[
i
];
auto
&
var
=
vars
[
output
];
auto
&
node
=
nodes
[
output
];
mgb_assert
(
var
.
kind
==
VarKind
::
Internal
,
"output node should be internal"
);
if
(
!
node
)
{
node
=
output_nodes
[
i
];
}
if
(
!
var
.
name
.
empty
())
{
node
->
name
(
var
.
name
);
}
}
}
for
(
auto
&&
[
name
,
ops
]
:
name2ops
)
{
if
(
ops
.
size
()
<=
1
)
{
continue
;
}
// ops.size() > 1, need dedup (rename op)
for
(
size_t
i
=
0
;
i
<
ops
.
size
();
++
i
)
{
auto
&
op
=
ops
[
i
];
auto
new_name
=
ssprintf
(
"%s[%zu]"
,
name
.
c_str
(),
i
);
for
(
auto
&&
output
:
op
->
output
())
{
auto
output_name
=
output
->
name
();
auto
pos
=
output_name
.
find
(
name
);
if
(
pos
!=
std
::
string
::
npos
)
{
output_name
.
replace
(
pos
,
name
.
length
(),
new_name
);
}
output
->
name
(
output_name
);
}
op
->
name
(
new_name
);
}
}
VarNodeArray
output_nodes
;
for
(
auto
&&
[
output
,
name
]
:
outputs
)
{
mgb_assert
(
output
<
vars
.
size
(),
"invalid output id %zu"
,
output
);
mgb_assert
(
nodes
[
output
],
"output node invalid"
);
if
(
!
name
.
empty
())
{
nodes
[
output
]
->
name
(
name
);
}
output_nodes
.
push_back
(
nodes
[
output
]);
}
return
output_nodes
;
}
std
::
vector
<
ValueRef
>
TracingTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
op_value
=
op
.
as
<
ApplyOp
>
())
{
SmallVector
<
ValueRef
>
unwrapped_inputs
;
SmallVector
<
TracingValue
::
ref_t
>
wrapped_inputs
;
SmallVector
<
size_t
>
input_ids
;
for
(
auto
input
:
inputs
)
{
auto
tracing_value
=
input
.
as_ref
<
TracingValue
>
();
if
(
!
tracing_value
)
{
tracing_value
=
record_var
(
input
,
m_capture_as_const
,
VarKind
::
External
);
}
unwrapped_inputs
.
push_back
(
tracing_value
->
value
());
wrapped_inputs
.
push_back
(
tracing_value
);
input_ids
.
push_back
(
tracing_value
->
id
());
}
// TODO: remove OpDef::set_scope
auto
scopes
=
Transformation
::
scopes
();
std
::
string
scopes_join
;
for
(
auto
&&
scope
:
scopes
)
{
if
(
!
scopes_join
.
empty
())
{
scopes_join
.
push_back
(
'.'
);
}
scopes_join
.
append
(
scope
);
}
const_cast
<
OpDef
&>
(
op_value
->
op
()).
set_scope
(
scopes_join
);
auto
unwrapped_outputs
=
imperative
::
apply
(
op
,
unwrapped_inputs
);
std
::
vector
<
ValueRef
>
wrapped_outputs
;
SmallVector
<
size_t
>
output_ids
;
for
(
auto
&&
output
:
unwrapped_outputs
)
{
auto
wrapped_output
=
record_var
(
output
,
false
,
VarKind
::
Internal
);
wrapped_outputs
.
push_back
(
wrapped_output
);
output_ids
.
push_back
(
wrapped_output
->
id
());
}
m_seq
.
push_back
({
op_value
->
op
().
shared_from_this
(),
input_ids
,
output_ids
});
return
wrapped_outputs
;
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
auto
outputs
=
imperative
::
apply
(
op
,
inputs
);
if
(
create_tensor
->
kind
()
==
CreateTensor
::
NoTrace
)
{
return
outputs
;
}
bool
is_const
=
create_tensor
->
kind
()
==
CreateTensor
::
Const
;
auto
wrapped_input
=
record_var
(
outputs
[
0
],
is_const
||
m_capture_as_const
,
is_const
?
VarKind
::
Constant
:
VarKind
::
External
);
auto
wrapped_output
=
record_var
(
outputs
[
0
],
false
,
VarKind
::
Internal
);
auto
input_id
=
wrapped_input
->
id
();
auto
output_id
=
wrapped_output
->
id
();
m_seq
.
push_back
({{},
{
input_id
},
{
output_id
}});
return
{
wrapped_output
};
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
auto
unwrapped_input
=
unwrap_var
(
inputs
[
0
]);
auto
outputs
=
imperative
::
apply
(
op
,
unwrapped_input
);
if
(
auto
*
tracing_value
=
inputs
[
0
].
as
<
TracingValue
>
())
{
auto
&
var_info
=
m_vars
[
tracing_value
->
id
()];
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
Shape
:
// TODO: reduce h2d when data or value is available
var_info
.
shape_required
=
true
;
break
;
case
GetAttr
::
Data
:
var_info
.
data_required
=
true
;
break
;
case
GetAttr
::
Value
:
var_info
.
value_required
=
true
;
break
;
default:
break
;
}
}
return
outputs
;
}
else
if
(
auto
*
trace_mark_var
=
op
.
as
<
TraceMarkVar
>
())
{
mgb_assert
(
inputs
.
size
()
==
1
,
"TraceMarkVar expects exactly one input"
);
auto
input
=
inputs
[
0
];
auto
tracing_var
=
input
.
as_ref
<
TracingValue
>
();
if
(
!
tracing_var
)
{
bool
is_input
=
trace_mark_var
->
mark
().
substr
(
0
,
4
)
==
"arg_"
||
trace_mark_var
->
mark
().
substr
(
0
,
6
)
==
"kwarg_"
;
if
(
is_input
)
{
tracing_var
=
record_var
(
input
,
false
,
VarKind
::
External
);
}
else
{
tracing_var
=
record_var
(
input
,
m_capture_as_const
,
VarKind
::
External
);
}
}
else
{
input
=
tracing_var
->
value
();
}
auto
output
=
record_var
(
input
,
false
,
VarKind
::
Internal
);
m_vars
[
output
->
id
()].
mark
=
trace_mark_var
->
mark
();
m_seq
.
push_back
({{},
{
tracing_var
->
id
()},
{
output
->
id
()}});
return
{
output
};
}
else
if
(
auto
*
trace_name_var
=
op
.
as
<
RenameValue
>
())
{
mgb_assert
(
inputs
.
size
()
==
1
,
"RenameValue expects exactly one input"
);
auto
input
=
inputs
[
0
];
auto
tracing_var
=
input
.
as_ref
<
TracingValue
>
();
if
(
!
tracing_var
)
{
tracing_var
=
record_var
(
input
,
m_capture_as_const
,
VarKind
::
External
);
}
else
{
input
=
tracing_var
->
value
();
}
auto
output
=
record_var
(
input
,
false
,
VarKind
::
Internal
);
m_vars
[
output
->
id
()].
name
=
trace_name_var
->
name
();
m_seq
.
push_back
({{},
{
tracing_var
->
id
()},
{
output
->
id
()}});
return
{
output
};
}
else
if
(
op
.
is
<
GetName
>
())
{
mgb_assert
(
inputs
.
size
()
==
1
,
"GetName expects exactly one input"
);
auto
input
=
inputs
[
0
];
if
(
auto
tracing_var
=
input
.
as_ref
<
TracingValue
>
())
{
auto
name
=
m_vars
[
tracing_var
->
id
()].
name
;
if
(
!
name
.
empty
())
{
return
{
StringValue
::
make
(
name
)};
}
else
{
return
{
ValueRef
()};
}
}
return
imperative
::
apply
(
op
,
inputs
);
}
else
{
// TODO: handle DTRCommand and ...
return
op
.
fallback
(
inputs
);
}
}
void
TracingTransformation
::
on_unregister
()
noexcept
{
for
(
auto
&&
weak_var
:
m_weak_vars
)
{
if
(
auto
tracing_value
=
weak_var
.
lock
())
{
auto
&
var_info
=
m_vars
[
tracing_value
->
id
()];
var_info
.
data_required
=
true
;
tracing_value
.
reset
(
tracing_value
->
value
());
}
}
m_weak_vars
.
clear
();
}
void
CompiledTransformation
::
compile
()
{
// these ops require seq order, so we link them to an mm_io_link to ensure order
static
std
::
unordered_set
<
Typeinfo
*>
mm_io_ops
=
{
CollectiveComm
::
typeinfo
(),
RemoteSend
::
typeinfo
(),
RemoteRecv
::
typeinfo
()};
mgb_assert
(
!
m_executable
,
"already compiled"
);
// FIXME: mm_io_link and io_links should be merged
SymbolVarArray
io_links
;
SymbolVar
mm_io_link
;
auto
make_input
=
[
&
](
VarInfo
*
var_info
)
{
mgb_assert
(
var_info
->
kind
==
VarKind
::
External
,
"input node should be external"
);
VarAccessor
accessor
;
auto
box
=
make_box
<
DeviceTensorND
>
();
// TODO: attach ref count, release early
auto
outputs
=
opr
::
InputCallback
::
make
(
*
m_graph
,
[
box
]
{
return
box
->
take_value
();
},
var_info
->
device
,
var_info
->
dtype
,
var_info
->
shape
,
io_links
,
m_input_shape_static
);
// attach input_callback to io_links
accessor
.
node
=
outputs
[
0
].
node
();
io_links
=
{
outputs
[
1
]};
accessor
.
data_setter
=
[
box
](
DeviceTensorND
data
)
{
box
->
try_set_value
(
data
);
};
return
accessor
;
};
auto
make_output
=
[
&
](
TraceResult
::
VarInfo
*
var_info
,
SymbolVar
node
)
{
VarAccessor
accessor
;
accessor
.
node
=
node
.
node
();
if
(
var_info
->
shape_required
)
{
// TODO: use static infer manager for some vars?
auto
box
=
make_box
<
TensorShape
>
();
auto
callback
=
[
box
](
DeviceTensorND
data
)
{
box
->
try_set_value
(
data
.
shape
());
};
SymbolVarArray
inputs
=
io_links
;
inputs
.
insert
(
inputs
.
begin
(),
node
);
auto
output
=
opr
::
OutputCallback
::
make
({
callback
,
true
,
false
},
inputs
);
io_links
=
{
output
};
accessor
.
shape_getter
=
[
box
]()
->
TensorShape
{
return
box
->
get_value
();
};
}
if
(
var_info
->
data_required
)
{
auto
box
=
make_box
<
DeviceTensorND
>
();
auto
callback
=
[
box
](
DeviceTensorND
data
)
{
box
->
try_set_value
(
data
);
};
SymbolVarArray
inputs
=
io_links
;
inputs
.
insert
(
inputs
.
begin
(),
node
);
auto
output
=
opr
::
OutputCallback
::
make
({
callback
,
false
,
false
},
inputs
);
io_links
=
{
output
};
accessor
.
data_getter
=
[
box
]()
->
DeviceTensorND
{
return
box
->
get_value
();
};
}
if
(
var_info
->
value_required
)
{
struct
ValueWithEvent
{
HostTensorND
value
;
CompNode
::
Event
*
event
=
nullptr
;
};
auto
box
=
make_box
<
ValueWithEvent
>
();
auto
event
=
EventPool
::
without_timer
().
alloc_shared
(
var_info
->
device
);
auto
callback
=
[
box
,
event
](
DeviceTensorND
data
)
{
HostTensorND
host_val
;
host_val
.
copy_from
(
data
);
if
(
data
.
comp_node
()
!=
CompNode
::
default_cpu
())
{
mgb_assert
(
data
.
comp_node
()
==
event
->
comp_node
());
event
->
record
();
box
->
try_set_value
({
host_val
,
event
.
get
()});
}
else
{
box
->
try_set_value
({
host_val
});
}
};
SymbolVarArray
inputs
=
io_links
;
inputs
.
insert
(
inputs
.
begin
(),
node
);
auto
output
=
opr
::
OutputCallback
::
make
({
callback
,
false
,
true
},
inputs
);
io_links
=
{
output
};
accessor
.
value_getter
=
[
box
]()
->
HostTensorND
{
auto
&&
[
value
,
event
]
=
box
->
get_value
();
if
(
event
)
{
event
->
host_wait
();
}
return
value
;
};
}
return
accessor
;
};
auto
make_const
=
[
&
](
TraceResult
::
VarInfo
*
var_info
)
{
VarAccessor
accessor
;
mgb_assert
(
var_info
->
kind
==
VarKind
::
Constant
,
"const node should be constant"
);
HostTensorND
host_val
=
var_info
->
bound_data
.
numpy
()
->
as_nd
();
accessor
.
node
=
opr
::
ImmutableTensor
::
make
(
*
m_graph
,
host_val
).
node
();
return
accessor
;
};
std
::
vector
<
VarAccessor
>
var_accessors
(
m_vars
.
size
());
for
(
auto
&&
item
:
m_seq
)
{
bool
require_link
=
bool
(
item
.
op
)
&&
mm_io_ops
.
count
(
item
.
op
->
dyn_typeinfo
());
VarNodeArray
input_vars
;
for
(
auto
&&
input
:
item
.
inputs
)
{
auto
&
var
=
m_vars
[
input
];
if
(
!
var_accessors
[
input
].
node
)
{
switch
(
var
.
kind
)
{
case
VarKind
::
External
:
var_accessors
[
input
]
=
make_input
(
&
var
);
break
;
case
VarKind
::
Constant
:
var_accessors
[
input
]
=
make_const
(
&
var
);
break
;
default:
mgb_throw
(
AssertionError
,
"internal node should be valid when used as input"
);
}
}
input_vars
.
push_back
(
var_accessors
[
input
].
node
);
}
if
(
require_link
&&
mm_io_link
.
node
())
{
mgb_assert
(
!
input_vars
.
empty
(),
"io-mm operator should have at least one input"
);
input_vars
[
0
]
=
opr
::
VirtualDep
::
make
({
SymbolVar
(
input_vars
[
0
]),
mm_io_link
})
.
node
();
}
VarNodeArray
output_vars
;
if
(
item
.
op
)
{
output_vars
=
OpDef
::
apply_on_var_node
(
*
item
.
op
,
input_vars
);
}
else
{
// forward inputs to outputs
mgb_assert
(
item
.
inputs
.
size
()
==
item
.
outputs
.
size
(),
"output size not equals to input size when forwarding"
);
for
(
auto
&&
input_var
:
input_vars
)
{
output_vars
.
push_back
(
input_var
);
}
}
if
(
require_link
)
{
mgb_assert
(
!
item
.
outputs
.
empty
(),
"io-mm operator should have at least one output"
);
mm_io_link
=
SymbolVar
(
output_vars
[
0
]);
}
// init output accessors
for
(
size_t
i
=
0
;
i
<
output_vars
.
size
();
++
i
)
{
auto
output
=
item
.
outputs
[
i
];
auto
&
node
=
output_vars
[
i
];
auto
&
var
=
m_vars
[
output
];
var_accessors
[
output
]
=
make_output
(
&
var
,
node
);
}
}
ComputingGraph
::
OutputSpec
output_specs
;
// avoid input/output/callback from being optimized
for
(
auto
&&
io_link
:
io_links
)
{
output_specs
.
push_back
({
io_link
,
{}});
}
// avoid remote io ops from being optimized
if
(
mm_io_link
.
node
())
{
output_specs
.
push_back
({
mm_io_link
,
{}});
}
{
// set_priority_to_id
// workaround for having mm_io_link and io_links separated
auto
on_opr
=
[](
mgb
::
cg
::
OperatorNodeBase
*
opr
)
{
if
(
opr
->
node_prop
().
attribute
().
priority
==
0
)
{
opr
->
node_prop
().
attribute
().
priority
=
opr
->
id
();
}
};
mgb
::
cg
::
DepOprIter
dep_iter
{
on_opr
};
for
(
const
auto
&
output_spec
:
output_specs
)
{
dep_iter
.
add
(
output_spec
.
first
);
}
}
m_executable
=
m_graph
->
compile
(
output_specs
);
m_var_accessors
=
var_accessors
;
m_output_spec
=
output_specs
;
}
void
CompiledTransformation
::
recompile
()
{
mgb_assert
(
m_executable
);
m_executable
=
m_graph
->
compile
(
m_output_spec
);
}
void
CompiledTransformation
::
assert_tensor_equal
(
ValueRef
lhs
,
ValueRef
rhs
)
{
trace_assert
(
m_value_comparator
(
lhs
,
rhs
),
"tensors not equals"
);
}
void
CompiledTransformation
::
trace_input
(
size_t
id
,
ValueRef
value
)
{
try
{
auto
&
var
=
m_vars
[
id
];
auto
&
var_accessor
=
m_var_accessors
[
id
];
switch
(
var
.
kind
)
{
case
VarKind
::
External
:
{
trace_assert
(
!
value
.
is
<
TracedValue
>
(),
"expect external node, got internal"
);
if
(
var
.
bound_data
)
{
assert_tensor_equal
(
var
.
bound_data
,
value
);
}
else
{
DType
dtype
=
*
value
.
dtype
();
CompNode
device
=
*
value
.
device
();
trace_assert
(
var
.
dtype
==
dtype
,
"dtype mismatch: %s vs %s"
,
var
.
dtype
.
name
(),
dtype
.
name
());
trace_assert
(
var
.
device
==
device
,
"comp_node mismatch: %s vs %s"
,
var
.
device
.
to_string
().
c_str
(),
device
.
to_string
().
c_str
());
}
var_accessor
.
data_setter
(
value
.
dev_tensor
()
->
as_nd
());
break
;
}
case
VarKind
::
Constant
:
{
mgb_assert
(
var
.
bound_data
,
"const var without data bound"
);
assert_tensor_equal
(
var
.
bound_data
,
value
);
break
;
}
case
VarKind
::
Internal
:
{
trace_assert
(
value
.
is
<
TracedValue
>
(),
"expect internal node, got external"
);
auto
&
traced_value
=
value
.
cast
<
TracedValue
>
();
trace_assert
(
traced_value
.
id
()
==
id
,
"input id mismatch"
);
break
;
}
}
}
catch
(
TraceError
&
)
{
throw
;
}
catch
(...)
{
mgb_assert
(
false
,
"unexpected error"
);
}
}
TracedValue
::
ref_t
CompiledTransformation
::
trace_output
(
size_t
id
)
{
auto
traced_value
=
TracedValue
::
make
(
id
);
m_weak_values
.
push_back
(
traced_value
);
return
traced_value
;
}
TraceResult
::
SeqItem
&
CompiledTransformation
::
next_instruction
()
{
trace_assert
(
m_pc
<
m_seq
.
size
(),
"too many instructions"
);
return
m_seq
[
m_pc
++
];
}
std
::
vector
<
ValueRef
>
CompiledTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
*
op_value
=
op
.
as
<
ApplyOp
>
())
{
auto
&
item
=
next_instruction
();
SmallVector
<
ValueRef
>
unwrapped_inputs
;
SmallVector
<
ValueRef
>
wrapped_inputs
;
trace_assert
(
inputs
.
size
()
==
item
.
inputs
.
size
(),
"input size mismatch"
);
trace_assert
(
op_value
->
op
().
is_same
(
*
item
.
op
),
"operator mismatch"
);
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
trace_input
(
item
.
inputs
[
i
],
inputs
[
i
]);
}
std
::
vector
<
ValueRef
>
outputs
;
for
(
auto
&&
output_id
:
item
.
outputs
)
{
outputs
.
push_back
(
trace_output
(
output_id
));
}
return
outputs
;
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
if
(
create_tensor
->
kind
()
==
CreateTensor
::
NoTrace
)
{
return
imperative
::
apply
(
op
,
inputs
);
}
auto
&
item
=
next_instruction
();
trace_assert
(
item
.
op
==
nullptr
,
"operator mismatch"
);
auto
input_id
=
item
.
inputs
[
0
];
auto
output_id
=
item
.
outputs
[
0
];
auto
tensor
=
imperative
::
apply
(
op
,
inputs
)[
0
];
trace_input
(
input_id
,
tensor
);
return
{
trace_output
(
output_id
)};
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
if
(
auto
*
traced_value
=
inputs
[
0
].
as
<
TracedValue
>
())
{
ValueRef
output
;
auto
&
var
=
m_vars
[
traced_value
->
id
()];
auto
&
var_accessor
=
m_var_accessors
[
traced_value
->
id
()];
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
Shape
:
trace_assert
(
var_accessor
.
shape_getter
,
"shape unreadable"
);
output
=
ShapeValue
::
make
(
ValueShape
::
from
(
var_accessor
.
shape_getter
()));
break
;
case
GetAttr
::
Data
:
trace_assert
(
var_accessor
.
data_getter
,
"data unreadable"
);
output
=
DeviceValue
::
make
(
var_accessor
.
data_getter
());
break
;
case
GetAttr
::
Value
:
trace_assert
(
var_accessor
.
value_getter
,
"value unreadable"
);
output
=
HostValue
::
make
(
var_accessor
.
value_getter
());
break
;
case
GetAttr
::
DType
:
output
=
DTypeValue
::
make
(
var
.
dtype
);
break
;
case
GetAttr
::
Device
:
output
=
CompNodeValue
::
make
(
var
.
device
);
default:
break
;
}
return
{
output
};
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
}
else
if
(
auto
*
trace_mark_var
=
op
.
as
<
TraceMarkVar
>
())
{
auto
&
item
=
next_instruction
();
trace_assert
(
item
.
op
==
nullptr
,
"operator mismatch"
);
trace_assert
(
item
.
inputs
.
size
()
==
1
,
"inputs size mismatch"
);
trace_assert
(
item
.
outputs
.
size
()
==
1
,
"inputs output mismatch"
);
trace_input
(
item
.
inputs
[
0
],
inputs
[
0
]);
trace_assert
(
trace_mark_var
->
mark
()
==
m_vars
[
item
.
outputs
[
0
]].
mark
,
"mark mismatch"
);
return
{
trace_output
(
item
.
outputs
[
0
])};
}
else
if
(
auto
*
trace_name_var
=
op
.
as
<
RenameValue
>
())
{
auto
&
item
=
next_instruction
();
trace_assert
(
item
.
op
==
nullptr
,
"operator mismatch"
);
trace_assert
(
item
.
inputs
.
size
()
==
1
,
"inputs size mismatch"
);
trace_assert
(
item
.
outputs
.
size
()
==
1
,
"outputs size mismatch"
);
trace_input
(
item
.
inputs
[
0
],
inputs
[
0
]);
trace_assert
(
trace_name_var
->
name
()
==
m_vars
[
item
.
outputs
[
0
]].
name
,
"name mismatch"
);
return
{
trace_output
(
item
.
outputs
[
0
])};
}
else
{
return
op
.
fallback
(
inputs
);
}
}
void
CompiledTransformation
::
on_unregister
()
noexcept
{
// resolve pending values
for
(
auto
&&
weak_value
:
m_weak_values
)
{
if
(
auto
traced_value
=
weak_value
.
lock
())
{
auto
&
var_accessor
=
m_var_accessors
[
traced_value
->
id
()];
auto
value
=
([
&
]()
->
ValueRef
{
try
{
trace_assert
(
var_accessor
.
data_getter
,
"data unreadable"
);
auto
dev_value
=
DeviceValue
::
make
(
var_accessor
.
data_getter
());
return
imperative
::
apply
(
CreateTensor
(
CreateTensor
::
Common
,
dev_value
->
device
(),
dev_value
->
dtype
(),
dev_value
->
shape
()),
DeviceStorage
::
make
(
dev_value
->
storage
()))[
0
];
}
catch
(...)
{
set_exception
(
std
::
current_exception
());
return
ErrorValue
::
make
(
"trace exit failed"
);
}
})();
traced_value
.
reset
(
value
);
}
}
m_weak_values
.
clear
();
}
void
CompiledTransformation
::
execute
()
{
mgb_assert
(
m_executable
!=
nullptr
);
m_graph_executor
=
std
::
thread
([
&
]
{
try
{
m_executable
->
execute
();
m_executable
->
wait
();
}
catch
(...)
{
auto
exc
=
std
::
current_exception
();
set_exception
(
exc
);
}
});
}
void
CompiledTransformation
::
wait
()
{
try
{
trace_assert
(
m_pc
==
m_seq
.
size
(),
"mismature end"
);
}
catch
(...)
{
}
mgb_assert
(
m_executable
!=
nullptr
);
m_graph_executor
.
join
();
m_graph_executor
=
{};
for
(
auto
&&
box
:
m_boxes
)
{
box
->
reset
();
}
m_pc
=
0
;
std
::
exception_ptr
graph_exc
;
std
::
swap
(
m_graph_exc
,
graph_exc
);
if
(
graph_exc
)
{
// graph with exception cannot be reused
recompile
();
std
::
rethrow_exception
(
graph_exc
);
}
}
std
::
exception_ptr
CompiledTransformation
::
set_exception
(
std
::
exception_ptr
exc
)
noexcept
{
MGB_LOCK_GUARD
(
m_mutex
);
if
(
m_graph_exc
)
{
return
m_graph_exc
;
}
for
(
auto
&&
box
:
m_boxes
)
{
box
->
try_set_exception
(
exc
);
}
m_graph_exc
=
exc
;
return
m_graph_exc
;
}
}
// namespace imperative
}
// namespace mgb
imperative/src/include/megbrain/imperative/transformations/trace.h
0 → 100644
浏览文件 @
d2b67c2a
/**
 * \file imperative/src/include/megbrain/imperative/transformations/trace.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <chrono>
#include <future>
#include <utility>
#include <variant>

#include "megbrain/gopt/inference.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/utils/box.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/serializer.h"
namespace
mgb
::
imperative
{
struct
TraceResult
{
struct
SeqItem
{
std
::
shared_ptr
<
OpDef
>
op
;
SmallVector
<
size_t
>
inputs
;
SmallVector
<
size_t
>
outputs
;
};
struct
VarInfo
{
enum
Kind
{
External
,
// End point of traced graph, its value is received from
// environment
Constant
,
// Also end point, but its value is constant in all executions,
// so we don't need to get from env every time, just capture it
Internal
,
// Not end point, produced by some op (or just forwarded) from
// op_seq
};
size_t
id
;
DType
dtype
;
CompNode
device
;
// if exists, assert equal when meet
ValueRef
bound_data
;
std
::
string
mark
;
std
::
string
name
;
Kind
kind
;
bool
value_required
=
false
;
bool
data_required
=
false
;
bool
shape_required
=
false
;
TensorShape
shape
;
};
using
VarKind
=
VarInfo
::
Kind
;
std
::
vector
<
SeqItem
>
seq
;
std
::
vector
<
VarInfo
>
vars
;
/**
* \brief dump to mgb computing graph
*
* \param graph mgb computing graph
* \param inputs (input_id, input_name, input_shape)
* \param outputs (output_id, outupt_name)
* \param prefer_input_names
* \return VarNodeArray output nodes
*/
VarNodeArray
dump
(
ComputingGraph
&
graph
,
std
::
vector
<
std
::
tuple
<
size_t
,
std
::
string
,
TensorShape
>>
inputs
,
std
::
vector
<
std
::
pair
<
size_t
,
std
::
string
>>
outputs
,
bool
prefer_input_names
);
};
/**
* \brief mark an var as arg/kwarg/output
*
*/
class
TraceMarkVar
:
public
OperatorImpl
<
TraceMarkVar
,
Operator
::
IdentityLike
>
{
private:
std
::
string
m_mark
;
public:
TraceMarkVar
(
std
::
string
mark
)
:
m_mark
(
mark
)
{}
std
::
string
mark
()
const
{
return
m_mark
;
}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"TraceMarkVar{mark=%s}"
,
imperative
::
quoted
(
m_mark
).
c_str
());
}
};
class
TracingInfo
{
private:
ValueRef
m_value
=
{};
size_t
m_id
=
0
;
public:
TracingInfo
()
=
default
;
TracingInfo
(
ValueRef
value
,
size_t
id
)
:
m_value
(
value
),
m_id
(
id
)
{}
ValueRef
value
()
const
{
return
m_value
;
}
size_t
id
()
const
{
return
m_id
;
}
};
class
TracingValue
final
:
public
MixinValueImpl
<
TracingValue
,
TracingInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"TracingValue{
\"
id
\"
=%zu,
\"
value
\"
=%s}"
,
id
(),
value
().
to_string
().
c_str
());
}
void
on_watch
()
override
{
value
().
watch
();
}
void
on_unwatch
()
override
{
value
().
unwatch
();
}
};
class
TracedInfo
{
private:
size_t
m_id
=
0
;
public:
TracedInfo
()
=
default
;
TracedInfo
(
size_t
id
)
:
m_id
(
id
)
{}
size_t
id
()
const
{
return
m_id
;
}
};
class
TracedValue
final
:
public
MixinValueImpl
<
TracedValue
,
TracedInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"TracedValue{
\"
id
\"
=%zu}"
,
id
());
}
};
/**
* \brief trace operation sequence to TraceResult
*
* TracingTransformation records and forwards all operations to next layer,
* as if it's transparent. When execution ends, it exports an operation sequence,
* which is usually used to build CompiledTransformation.
*/
class
TracingTransformation
final
:
public
Transformation
{
public:
using
VarInfo
=
TraceResult
::
VarInfo
;
using
VarKind
=
VarInfo
::
Kind
;
private:
std
::
vector
<
TraceResult
::
SeqItem
>
m_seq
;
std
::
vector
<
TraceResult
::
VarInfo
>
m_vars
;
std
::
vector
<
TracingValue
::
weak_ref_t
>
m_weak_vars
;
bool
m_capture_as_const
=
false
;
bool
m_record_input_shapes
=
false
;
public:
TracingTransformation
(
bool
capture_as_const
,
bool
record_input_shapes
)
:
m_capture_as_const
(
capture_as_const
),
m_record_input_shapes
(
record_input_shapes
)
{}
/**
* \brief record values for trace
*
* \param value value to be traced
* \param capture whether capture value or not
* \param kind External, Constant or Internal
* \return TypedValueRef<TracingValue> traced value
*/
TypedValueRef
<
TracingValue
>
record_var
(
ValueRef
value
,
bool
capture
,
VarKind
kind
)
{
size_t
id
=
m_vars
.
size
();
auto
wrapped_value
=
TracingValue
::
make
(
value
,
id
);
m_vars
.
push_back
({
id
,
*
value
.
dtype
(),
*
value
.
device
()});
auto
&
var
=
m_vars
.
back
();
if
(
capture
)
{
var
.
bound_data
=
value
;
}
var
.
kind
=
kind
;
if
(
m_record_input_shapes
&&
kind
!=
VarKind
::
Internal
)
{
var
.
shape
=
value
.
shape
()
->
as_tensor_shape
();
}
if
(
auto
name
=
value
.
name
())
{
var
.
name
=
*
name
;
}
m_weak_vars
.
push_back
(
wrapped_value
);
return
wrapped_value
;
}
ValueRef
unwrap_var
(
ValueRef
value
)
{
if
(
auto
*
tracing_value
=
value
.
as
<
TracingValue
>
())
{
return
tracing_value
->
value
();
}
return
value
;
}
std
::
vector
<
ValueRef
>
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
if
(
auto
*
tracing_value
=
value
.
as
<
TracingValue
>
())
{
return
tracing_value
->
value
();
}
return
value
;
}
std
::
string
name
()
const
override
{
return
"TracingTransformation"
;
}
void
on_unregister
()
noexcept
override
;
TraceResult
get_result
()
{
return
{
m_seq
,
m_vars
};
}
};
class
TraceError
:
public
std
::
exception
{
private:
std
::
string
m_message
;
public:
TraceError
(
std
::
string
reason
)
{
m_message
=
ssprintf
(
"trace error because %s"
,
reason
.
c_str
());
}
const
char
*
what
()
const
noexcept
override
{
return
m_message
.
c_str
();
}
};
/**
* \brief boost with traced result from TracingTransformation
*
* CompiledTransformation is built with an operation sequence. It compiles a megbrain
* graph with the sequence and handle operation requests with this graph. Besides that,
* it also checks that if current operation is same as previous one in seq.
*/
class
CompiledTransformation
final
:
public
Transformation
{
public:
using
VarInfo
=
TraceResult
::
VarInfo
;
using
VarKind
=
VarInfo
::
Kind
;
struct
VarAccessor
{
VarNode
*
node
;
std
::
function
<
TensorShape
()
>
shape_getter
;
std
::
function
<
DeviceTensorND
()
>
data_getter
;
std
::
function
<
HostTensorND
()
>
value_getter
;
std
::
function
<
void
(
DeviceTensorND
)
>
data_setter
;
};
private:
std
::
vector
<
TraceResult
::
SeqItem
>
m_seq
;
std
::
vector
<
TraceResult
::
VarInfo
>
m_vars
;
std
::
vector
<
VarAccessor
>
m_var_accessors
;
size_t
m_pc
=
0
;
std
::
shared_ptr
<
ComputingGraph
>
m_graph
;
std
::
unique_ptr
<
cg
::
AsyncExecutable
>
m_executable
;
std
::
vector
<
TracedValue
::
weak_ref_t
>
m_weak_values
;
std
::
thread
m_graph_executor
;
std
::
function
<
bool
(
ValueRef
,
ValueRef
)
>
m_value_comparator
;
bool
m_input_shape_static
;
std
::
mutex
m_mutex
;
std
::
exception_ptr
m_graph_exc
;
std
::
vector
<
std
::
shared_ptr
<
BoxBase
>>
m_boxes
;
ComputingGraph
::
OutputSpec
m_output_spec
;
public:
CompiledTransformation
(
TraceResult
result
,
bool
input_shape_static
)
:
m_seq
(
result
.
seq
),
m_vars
(
result
.
vars
),
m_input_shape_static
(
input_shape_static
)
{
m_graph
=
ComputingGraph
::
make
();
options
().
no_force_inplace
=
true
;
options
().
async_exec_level
=
0b100
;
}
ComputingGraph
&
graph
()
{
return
*
m_graph
;
}
ComputingGraph
::
Options
&
options
()
{
return
m_graph
->
options
();
}
/**
* \brief Set the value comparator object (usually from python)
*
* \param comparator
*/
void
set_value_comparator
(
std
::
function
<
bool
(
ValueRef
,
ValueRef
)
>
comparator
)
{
m_value_comparator
=
comparator
;
}
void
compile
();
void
recompile
();
void
assert_tensor_equal
(
ValueRef
lhs
,
ValueRef
rhs
);
/**
* \brief handle input for trace
*
* 1. For external, set input value to data_setter;
* 2. For const, do nothing;
* 3. For internal, assert var id;
* *. Always assert data equals if there are data bound.
*
* \param id
* \param value
*/
void
trace_input
(
size_t
id
,
ValueRef
value
);
/**
* \brief make a placeholder for output.
*
* \param id trace_id
* \return TracedValue::ref_t output placeholder, would be reset to real value when
* trace exits
*/
TracedValue
::
ref_t
trace_output
(
size_t
id
);
TraceResult
::
SeqItem
&
next_instruction
();
std
::
vector
<
ValueRef
>
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
void
on_unregister
()
noexcept
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
<
TracedValue
>
());
return
value
;
}
std
::
string
name
()
const
override
{
return
"CompiledTransformation"
;
}
void
execute
();
void
wait
();
std
::
exception_ptr
set_exception
(
std
::
exception_ptr
exc
)
noexcept
;
template
<
typename
T
>
std
::
shared_ptr
<
Box
<
T
>>
make_box
()
{
auto
box
=
Box
<
T
>::
make
();
m_boxes
.
push_back
(
box
);
return
box
;
}
};
}
// namespace mgb::imperative
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录