Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d2b67c2a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d2b67c2a
编写于
1月 14, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dispatch): implement trace
GitOrigin-RevId: f8d3005732dad0f941d963e8e529f1c11d2d3ca5
上级
39ac606b
变更
2
展开全部
隐藏空白更改
内联
并排
Showing
2 changed file
with
1027 addition
and
0 deletion
+1027
-0
imperative/src/impl/transformations/trace.cpp
imperative/src/impl/transformations/trace.cpp
+679
-0
imperative/src/include/megbrain/imperative/transformations/trace.h
...e/src/include/megbrain/imperative/transformations/trace.h
+348
-0
未找到文件。
imperative/src/impl/transformations/trace.cpp
0 → 100644
浏览文件 @
d2b67c2a
此差异已折叠。
点击以展开。
imperative/src/include/megbrain/imperative/transformations/trace.h
0 → 100644
浏览文件 @
d2b67c2a
/**
* \file imperative/src/include/megbrain/imperative/trace.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <chrono>
#include <future>
#include <variant>
#include "megbrain/gopt/inference.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/utils/box.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/serializer.h"
namespace
mgb
::
imperative
{
struct
TraceResult
{
struct
SeqItem
{
std
::
shared_ptr
<
OpDef
>
op
;
SmallVector
<
size_t
>
inputs
;
SmallVector
<
size_t
>
outputs
;
};
struct
VarInfo
{
enum
Kind
{
External
,
// End point of traced graph, its value is received from
// environment
Constant
,
// Also end point, but its value is constant in all executions,
// so we don't need to get from env every time, just capture it
Internal
,
// Not end point, produced by some op (or just forwarded) from
// op_seq
};
size_t
id
;
DType
dtype
;
CompNode
device
;
// if exists, assert equal when meet
ValueRef
bound_data
;
std
::
string
mark
;
std
::
string
name
;
Kind
kind
;
bool
value_required
=
false
;
bool
data_required
=
false
;
bool
shape_required
=
false
;
TensorShape
shape
;
};
using
VarKind
=
VarInfo
::
Kind
;
std
::
vector
<
SeqItem
>
seq
;
std
::
vector
<
VarInfo
>
vars
;
/**
* \brief dump to mgb computing graph
*
* \param graph mgb computing graph
* \param inputs (input_id, input_name, input_shape)
* \param outputs (output_id, outupt_name)
* \param prefer_input_names
* \return VarNodeArray output nodes
*/
VarNodeArray
dump
(
ComputingGraph
&
graph
,
std
::
vector
<
std
::
tuple
<
size_t
,
std
::
string
,
TensorShape
>>
inputs
,
std
::
vector
<
std
::
pair
<
size_t
,
std
::
string
>>
outputs
,
bool
prefer_input_names
);
};
/**
* \brief mark an var as arg/kwarg/output
*
*/
class
TraceMarkVar
:
public
OperatorImpl
<
TraceMarkVar
,
Operator
::
IdentityLike
>
{
private:
std
::
string
m_mark
;
public:
TraceMarkVar
(
std
::
string
mark
)
:
m_mark
(
mark
)
{}
std
::
string
mark
()
const
{
return
m_mark
;
}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"TraceMarkVar{mark=%s}"
,
imperative
::
quoted
(
m_mark
).
c_str
());
}
};
class
TracingInfo
{
private:
ValueRef
m_value
=
{};
size_t
m_id
=
0
;
public:
TracingInfo
()
=
default
;
TracingInfo
(
ValueRef
value
,
size_t
id
)
:
m_value
(
value
),
m_id
(
id
)
{}
ValueRef
value
()
const
{
return
m_value
;
}
size_t
id
()
const
{
return
m_id
;
}
};
class
TracingValue
final
:
public
MixinValueImpl
<
TracingValue
,
TracingInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"TracingValue{
\"
id
\"
=%zu,
\"
value
\"
=%s}"
,
id
(),
value
().
to_string
().
c_str
());
}
void
on_watch
()
override
{
value
().
watch
();
}
void
on_unwatch
()
override
{
value
().
unwatch
();
}
};
class
TracedInfo
{
private:
size_t
m_id
=
0
;
public:
TracedInfo
()
=
default
;
TracedInfo
(
size_t
id
)
:
m_id
(
id
)
{}
size_t
id
()
const
{
return
m_id
;
}
};
class
TracedValue
final
:
public
MixinValueImpl
<
TracedValue
,
TracedInfo
>
{
public:
using
MixinValueImpl
::
MixinValueImpl
;
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"TracedValue{
\"
id
\"
=%zu}"
,
id
());
}
};
/**
* \brief trace operation sequence to TraceResult
*
* TracingTransformation records and forwards all operations to next layer,
* as if it's transparent. When execution ends, it exports an operation sequence,
* which is usually used to build CompiledTransformation.
*/
class
TracingTransformation
final
:
public
Transformation
{
public:
using
VarInfo
=
TraceResult
::
VarInfo
;
using
VarKind
=
VarInfo
::
Kind
;
private:
std
::
vector
<
TraceResult
::
SeqItem
>
m_seq
;
std
::
vector
<
TraceResult
::
VarInfo
>
m_vars
;
std
::
vector
<
TracingValue
::
weak_ref_t
>
m_weak_vars
;
bool
m_capture_as_const
=
false
;
bool
m_record_input_shapes
=
false
;
public:
TracingTransformation
(
bool
capture_as_const
,
bool
record_input_shapes
)
:
m_capture_as_const
(
capture_as_const
),
m_record_input_shapes
(
record_input_shapes
)
{}
/**
* \brief record values for trace
*
* \param value value to be traced
* \param capture whether capture value or not
* \param kind External, Constant or Internal
* \return TypedValueRef<TracingValue> traced value
*/
TypedValueRef
<
TracingValue
>
record_var
(
ValueRef
value
,
bool
capture
,
VarKind
kind
)
{
size_t
id
=
m_vars
.
size
();
auto
wrapped_value
=
TracingValue
::
make
(
value
,
id
);
m_vars
.
push_back
({
id
,
*
value
.
dtype
(),
*
value
.
device
()});
auto
&
var
=
m_vars
.
back
();
if
(
capture
)
{
var
.
bound_data
=
value
;
}
var
.
kind
=
kind
;
if
(
m_record_input_shapes
&&
kind
!=
VarKind
::
Internal
)
{
var
.
shape
=
value
.
shape
()
->
as_tensor_shape
();
}
if
(
auto
name
=
value
.
name
())
{
var
.
name
=
*
name
;
}
m_weak_vars
.
push_back
(
wrapped_value
);
return
wrapped_value
;
}
ValueRef
unwrap_var
(
ValueRef
value
)
{
if
(
auto
*
tracing_value
=
value
.
as
<
TracingValue
>
())
{
return
tracing_value
->
value
();
}
return
value
;
}
std
::
vector
<
ValueRef
>
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
if
(
auto
*
tracing_value
=
value
.
as
<
TracingValue
>
())
{
return
tracing_value
->
value
();
}
return
value
;
}
std
::
string
name
()
const
override
{
return
"TracingTransformation"
;
}
void
on_unregister
()
noexcept
override
;
TraceResult
get_result
()
{
return
{
m_seq
,
m_vars
};
}
};
class
TraceError
:
public
std
::
exception
{
private:
std
::
string
m_message
;
public:
TraceError
(
std
::
string
reason
)
{
m_message
=
ssprintf
(
"trace error because %s"
,
reason
.
c_str
());
}
const
char
*
what
()
const
noexcept
override
{
return
m_message
.
c_str
();
}
};
/**
* \brief boost with traced result from TracingTransformation
*
* CompiledTransformation is built with an operation sequence. It compiles a megbrain
* graph with the sequence and handle operation requests with this graph. Besides that,
* it also checks that if current operation is same as previous one in seq.
*/
class
CompiledTransformation
final
:
public
Transformation
{
public:
using
VarInfo
=
TraceResult
::
VarInfo
;
using
VarKind
=
VarInfo
::
Kind
;
struct
VarAccessor
{
VarNode
*
node
;
std
::
function
<
TensorShape
()
>
shape_getter
;
std
::
function
<
DeviceTensorND
()
>
data_getter
;
std
::
function
<
HostTensorND
()
>
value_getter
;
std
::
function
<
void
(
DeviceTensorND
)
>
data_setter
;
};
private:
std
::
vector
<
TraceResult
::
SeqItem
>
m_seq
;
std
::
vector
<
TraceResult
::
VarInfo
>
m_vars
;
std
::
vector
<
VarAccessor
>
m_var_accessors
;
size_t
m_pc
=
0
;
std
::
shared_ptr
<
ComputingGraph
>
m_graph
;
std
::
unique_ptr
<
cg
::
AsyncExecutable
>
m_executable
;
std
::
vector
<
TracedValue
::
weak_ref_t
>
m_weak_values
;
std
::
thread
m_graph_executor
;
std
::
function
<
bool
(
ValueRef
,
ValueRef
)
>
m_value_comparator
;
bool
m_input_shape_static
;
std
::
mutex
m_mutex
;
std
::
exception_ptr
m_graph_exc
;
std
::
vector
<
std
::
shared_ptr
<
BoxBase
>>
m_boxes
;
ComputingGraph
::
OutputSpec
m_output_spec
;
public:
CompiledTransformation
(
TraceResult
result
,
bool
input_shape_static
)
:
m_seq
(
result
.
seq
),
m_vars
(
result
.
vars
),
m_input_shape_static
(
input_shape_static
)
{
m_graph
=
ComputingGraph
::
make
();
options
().
no_force_inplace
=
true
;
options
().
async_exec_level
=
0b100
;
}
ComputingGraph
&
graph
()
{
return
*
m_graph
;
}
ComputingGraph
::
Options
&
options
()
{
return
m_graph
->
options
();
}
/**
* \brief Set the value comparator object (usually from python)
*
* \param comparator
*/
void
set_value_comparator
(
std
::
function
<
bool
(
ValueRef
,
ValueRef
)
>
comparator
)
{
m_value_comparator
=
comparator
;
}
void
compile
();
void
recompile
();
void
assert_tensor_equal
(
ValueRef
lhs
,
ValueRef
rhs
);
/**
* \brief handle input for trace
*
* 1. For external, set input value to data_setter;
* 2. For const, do nothing;
* 3. For internal, assert var id;
* *. Always assert data equals if there are data bound.
*
* \param id
* \param value
*/
void
trace_input
(
size_t
id
,
ValueRef
value
);
/**
* \brief make a placeholder for output.
*
* \param id trace_id
* \return TracedValue::ref_t output placeholder, would be reset to real value when
* trace exits
*/
TracedValue
::
ref_t
trace_output
(
size_t
id
);
TraceResult
::
SeqItem
&
next_instruction
();
std
::
vector
<
ValueRef
>
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
void
on_unregister
()
noexcept
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
<
TracedValue
>
());
return
value
;
}
std
::
string
name
()
const
override
{
return
"CompiledTransformation"
;
}
void
execute
();
void
wait
();
std
::
exception_ptr
set_exception
(
std
::
exception_ptr
exc
)
noexcept
;
template
<
typename
T
>
std
::
shared_ptr
<
Box
<
T
>>
make_box
()
{
auto
box
=
Box
<
T
>::
make
();
m_boxes
.
push_back
(
box
);
return
box
;
}
};
}
// namespace mgb::imperative
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录