MegEngine 天元 / MegEngine

Commit 5798f6ce
Authored Aug 02, 2021 by Megvii Engine Team
feat(subgraph): add OpMeth make_forward_graph
GitOrigin-RevId: 171301fc2be5f867d4d653bc9a3fb22a94c289e6
Parent: 48db45d1

Changes: 6 files changed, 335 insertions(+), 22 deletions(-)
imperative/src/impl/op_def.cpp                                  +13   -0
imperative/src/impl/op_trait.cpp                                +39  -13
imperative/src/impl/op_trait.h                                  +58   -9
imperative/src/impl/subgraph_detail.cpp                        +169   -0
imperative/src/include/megbrain/imperative/op_def.h              +5   -0
imperative/src/include/megbrain/imperative/subgraph_detail.h    +51   -0
imperative/src/impl/op_def.cpp

@@ -100,6 +100,19 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
     return def.trait()->props(def);
 }
 
+EncodedSubraph OpDef::make_forward_graph(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    using ForwardGraphCache =
+            OpMethResultCache<EncodedSubraph, SmallVector<bool>, SmallVector<bool>>;
+    thread_local ForwardGraphCache cache;
+    decltype(cache)::key_t cache_key{const_cast<OpDef&>(def).shared_from_this(), inputs};
+    auto iter = cache.find(cache_key);
+    if (iter == cache.end()) {
+        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)}).first;
+    }
+    return iter->second;
+}
+
 std::string OpDef::to_string() const {
     std::string builder = trait()->make_name(*this) + "{";
     for (auto&& [name, value] : props(*this)) {
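OpDef::make_forward_graph above memoizes one expanded graph per (op instance, input descriptors) pair, per thread, so repeated applications of the same composite op skip the rebuild. A minimal self-contained sketch of that memoization pattern, with std::map and plain ints standing in for OpMethResultCache, OpDef, and EncodedSubraph (illustrative names, not MegEngine API):

#include <map>
#include <string>
#include <utility>
#include <vector>

struct Graph {
    std::string repr;  // stand-in for an encoded forward graph
};

// Expensive to build; cache one result per (op id, input shapes), per thread.
Graph make_forward_graph_cached(int op_id, const std::vector<int>& input_shapes) {
    using Key = std::pair<int, std::vector<int>>;
    thread_local std::map<Key, Graph> cache;  // thread_local: no locking needed
    Key key{op_id, input_shapes};
    auto iter = cache.find(key);
    if (iter == cache.end()) {
        Graph g{"graph for op " + std::to_string(op_id)};  // the costly build
        iter = cache.insert({key, std::move(g)}).first;
    }
    return iter->second;  // subsequent hits return the cached graph
}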
imperative/src/impl/op_trait.cpp

@@ -16,6 +16,7 @@
 #include "megbrain/imperative/op_def.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/imperative/proxy_graph_detail.h"
+#include "megbrain/imperative/subgraph_detail.h"
 #include "megbrain/tensor.h"
 
 #include "./op_trait.h"

@@ -38,24 +39,45 @@ StaticData& static_data() {
     return data;
 }
 
-void OpMethFallback::impl(ApplyOnPhysicalTensor& func,
-                          op_meth_tag::ApplyOnPhysicalTensor) {
+void OpMethFallbackByProxyGraph::impl(ApplyOnPhysicalTensor& func,
+                                      op_meth_tag::ApplyOnPhysicalTensor) {
     func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
 }
-void OpMethFallback::impl(Execute& func, op_meth_tag::Execute) {
+void OpMethFallbackByProxyGraph::impl(Execute& func, op_meth_tag::Execute) {
     func.Base::operator=(proxy_graph_detail::execute);
 }
-void OpMethFallback::impl(InferOutputMemDesc& func,
-                          op_meth_tag::InferOutputMemDesc) {
+void OpMethFallbackByProxyGraph::impl(InferOutputMemDesc& func,
+                                      op_meth_tag::InferOutputMemDesc) {
     func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
 }
-void OpMethFallback::impl(InferOutputAttrsFallible& func,
-                          op_meth_tag::InferOutputAttrsFallible) {
+void OpMethFallbackByProxyGraph::impl(InferOutputAttrsFallible& func,
+                                      op_meth_tag::InferOutputAttrsFallible) {
     func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
 }
-void OpMethFallback::impl(GradMaker& func, op_meth_tag::GradMaker) {
+void OpMethFallbackByProxyGraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
     func.Base::operator=(proxy_graph_detail::make_backward_graph);
 }
+void OpMethFallbackFromSubgraph::impl(ApplyOnPhysicalTensor& func,
+                                      op_meth_tag::ApplyOnPhysicalTensor) {
+    func.Base::operator=(subgraph_detail::apply_on_physical_tensor);
+}
+void OpMethFallbackFromSubgraph::impl(InferOutputMemDesc& func,
+                                      op_meth_tag::InferOutputMemDesc) {
+    func.Base::operator=(subgraph_detail::infer_output_mem_desc);
+}
+void OpMethFallbackFromSubgraph::impl(ApplyOnVarNode& func,
+                                      op_meth_tag::ApplyOnVarNode) {
+    func.Base::operator=(subgraph_detail::apply_on_var_node);
+}
+void OpMethFallbackFromSubgraph::impl(InferOutputAttrsFallible& func,
+                                      op_meth_tag::InferOutputAttrsFallible) {
+    func.Base::operator=(subgraph_detail::infer_output_attrs_fallible);
+}
+void OpMethFallbackFromSubgraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
+    func.Base::operator=(subgraph_detail::make_backward_graph);
+}
 void OpMethFallback::impl(DecideDispatchMode& func,
                           op_meth_tag::DecideDispatchMode) {
     static auto decide_dispatch_mode =

@@ -99,16 +121,20 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
 }
 
 OpTraitRegistry& OpTraitRegistry::fallback() {
+    using Mode = detail::OpMethFallbackMode;
+    uint64_t mode = Mode::None;
+    if (trait->make_forward_graph) {
+        mode |= Mode::FromSubgraph;
+    }
     if (trait->apply_on_var_node) {
         // fallback to proxy graph impl
-        trait->apply_on_physical_tensor.allow_fallback = true;
-        trait->execute.allow_fallback = true;
-        trait->infer_output_mem_desc.allow_fallback = true;
-        trait->infer_output_attrs_fallible.allow_fallback = true;
-        trait->make_backward_graph.allow_fallback = true;
+        mode |= Mode::ByProxyGraph;
     }
-    trait->decide_dispatch_mode.allow_fallback = true;
-    trait->make_name.allow_fallback = true;
+    mode |= Mode::Default;
+#define SET_FALLBACK_MODE(meth) trait->meth.fallback_mode = mode;
+    FOR_EACH_OP_METH(SET_FALLBACK_MODE)
+#undef SET_FALLBACK_MODE
     return *this;
 }
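fallback() now records, once per trait, which fallback strategies are viable as a single bitmask shared by every OpMeth of the trait (via FOR_EACH_OP_METH), replacing the per-method allow_fallback booleans. A hedged sketch of how that mask is composed; the constants mirror detail::OpMethFallbackMode, while the two boolean flags and the print helper are illustrative:

#include <cstdint>
#include <cstdio>

// Mirrors detail::OpMethFallbackMode in op_trait.h.
constexpr uint64_t None = 0, Default = 1, ByProxyGraph = 2, FromSubgraph = 4;

int main() {
    bool has_make_forward_graph = true;  // trait registered a forward-graph maker?
    bool has_apply_on_var_node = true;   // trait can run through the proxy graph?

    uint64_t mode = None;
    if (has_make_forward_graph) mode |= FromSubgraph;
    if (has_apply_on_var_node) mode |= ByProxyGraph;
    mode |= Default;  // default fallbacks (dispatch mode, name) always apply

    // Every OpMeth of the trait receives the same mask; dispatch later peels
    // strategies off in FromSubgraph -> ByProxyGraph -> Default order.
    std::printf("fallback_mode = 0x%llx\n", (unsigned long long)mode);  // 0x7
    return 0;
}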
imperative/src/impl/op_trait.h

@@ -95,9 +95,18 @@ OpMethType(IsSame,
 OpMethType(MakeNameFunc,
            std::string(const OpDef&));
 
+OpMethType(GraphMaker,
+           decltype(OpDef::make_forward_graph));
+
 // clang-format on
 
 namespace detail {
+struct OpMethImplBase {
+    template <typename Tag, typename RType, typename... Args>
+    static void impl(thin_function<RType(Args...)>& func, Tag) {}
+};
+
 struct OpMethNotImpl {
     template <typename Tag, typename RType, typename... Args>
     static void impl(thin_function<RType(Args...)>& func, Tag) {

@@ -106,8 +115,15 @@ struct OpMethNotImpl {
         };
     }
 };
-struct OpMethFallback : public OpMethNotImpl {
-    using OpMethNotImpl::impl;
+struct OpMethFallback : OpMethImplBase {
+    using OpMethImplBase::impl;
+    static void impl(DecideDispatchMode& func,
+                     op_meth_tag::DecideDispatchMode);
+    static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
+};
+struct OpMethFallbackByProxyGraph : OpMethImplBase {
+    using OpMethImplBase::impl;
     static void impl(ApplyOnPhysicalTensor& func,
                      op_meth_tag::ApplyOnPhysicalTensor);
     static void impl(Execute& func, op_meth_tag::Execute);

@@ -115,18 +131,48 @@ struct OpMethFallback : public OpMethNotImpl {
     static void impl(InferOutputAttrsFallible& func,
                      op_meth_tag::InferOutputAttrsFallible);
     static void impl(GradMaker& func, op_meth_tag::GradMaker);
-    static void impl(DecideDispatchMode& func,
-                     op_meth_tag::DecideDispatchMode);
-    static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
+};
+struct OpMethFallbackFromSubgraph : OpMethImplBase {
+    using OpMethImplBase::impl;
+    static void impl(ApplyOnPhysicalTensor& func,
+                     op_meth_tag::ApplyOnPhysicalTensor);
+    static void impl(InferOutputMemDesc& func,
+                     op_meth_tag::InferOutputMemDesc);
+    static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode);
+    static void impl(InferOutputAttrsFallible& func,
+                     op_meth_tag::InferOutputAttrsFallible);
+    static void impl(GradMaker& func, op_meth_tag::GradMaker);
+};
+struct OpMethFallbackMode {
+    static constexpr uint64_t None = 0;
+    static constexpr uint64_t Default = 1;
+    static constexpr uint64_t ByProxyGraph = 2;
+    static constexpr uint64_t FromSubgraph = 4;
 };
 template <typename Tag, typename RType, typename... Args>
 struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
     using Base = thin_function<RType(Args...)>;
-    OpMeth() : Base{}, allow_fallback(false){};
+    OpMeth() : Base{} {};
     explicit OpMeth(const Base& base) { this->Base::operator=(base); }
     using Base::operator bool;
     RType operator()(Args... args) const {
         if (!this->Base::operator bool()) {
-            if (allow_fallback) {
-                OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{});
-            } else {
-                OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{});
+            uint64_t mode_mask = ~uint64_t(0);
+            auto match_mode = [&](uint64_t mode) {
+                if ((fallback_mode & mode_mask) & mode) {
+                    mode_mask &= ~mode;
+                    return true;
+                }
+                return false;
+            };
+            while (!this->Base::operator bool()) {
+                using Mode = OpMethFallbackMode;
+                if (match_mode(Mode::FromSubgraph)) {
+                    OpMethFallbackFromSubgraph::impl(*const_cast<OpMeth*>(this),
+                                                     Tag{});
+                } else if (match_mode(Mode::ByProxyGraph)) {
+                    OpMethFallbackByProxyGraph::impl(*const_cast<OpMeth*>(this),
+                                                     Tag{});
+                } else if (match_mode(Mode::Default)) {
+                    OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{});
+                } else {
+                    OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{});
+                }

@@ -134,7 +180,7 @@ struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
             }
         }
         return this->Base::operator()(std::forward<Args>(args)...);
     }
-    bool allow_fallback = false;
+    uint64_t fallback_mode = OpMethFallbackMode::None;
 };
 } // namespace detail

@@ -153,6 +199,7 @@ struct OpTrait {
     HashFunc hash;
     IsSame is_same_st;
     MakeNameFunc make_name;
+    GraphMaker make_forward_graph;
     OpTrait(const char* name);
     static OpTrait* find_by_name(const char* name);
     static OpTrait* find_by_typeinfo(Typeinfo* type);

@@ -173,7 +220,9 @@ struct OpTrait {
     cb(props) \
     cb(hash) \
     cb(is_same_st) \
-    cb(make_name)
+    cb(make_name) \
+    cb(make_forward_graph)
 // clang-format on
 
 struct OpTraitRegistry {
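The rewritten OpMeth::operator() installs an implementation lazily on the first call, trying strategies from most specific (FromSubgraph) to least (Default); match_mode clears each tried bit from the mask, so the while loop terminates even when a strategy declines to install anything. A simplified sketch of that loop around std::function (the real class is templated over a tag type, uses thin_function, and installs the fallback through a const_cast in a const call operator):

#include <cstdint>
#include <functional>
#include <stdexcept>

constexpr uint64_t Default = 1, ByProxyGraph = 2, FromSubgraph = 4;

struct Meth {
    std::function<int(int)> fn;  // empty until a fallback is installed
    uint64_t fallback_mode = FromSubgraph | ByProxyGraph | Default;

    int operator()(int x) {
        uint64_t mode_mask = ~uint64_t(0);
        auto match_mode = [&](uint64_t mode) {
            if ((fallback_mode & mode_mask) & mode) {
                mode_mask &= ~mode;  // never retry the same strategy
                return true;
            }
            return false;
        };
        while (!fn) {
            if (match_mode(FromSubgraph)) {
                // This strategy declines (leaves fn empty); loop tries the next.
            } else if (match_mode(ByProxyGraph)) {
                fn = [](int v) { return v * 2; };  // pretend proxy-graph impl
            } else if (match_mode(Default)) {
                fn = [](int v) { return v; };      // last-resort default
            } else {
                throw std::runtime_error("no fallback available");
            }
        }
        return fn(x);
    }
};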
imperative/src/impl/subgraph_detail.cpp (new file, 0 → 100644)

/**
 * \file imperative/src/impl/subgraph_detail.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/opr/io.h"
#include "megbrain/imperative/ops/autogen.h"

#include "./op_trait.h"

namespace mgb {
namespace imperative {
namespace subgraph_detail {

VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& input : inputs) {
        input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()});
    }
    auto apply_functor = [](const std::shared_ptr<OpDef>& op,
                            const VarNodeArray& inputs, size_t nr_outputs) {
        return OpDef::apply_on_var_node(*op, inputs);
    };
    auto const_functor = [&](const TensorPtr& value) {
        return opr::ImmutableTensor::make(*inputs[0]->owner_graph(),
                                          value->get_value())
                .node();
    };
    auto subgraph = def.trait()->make_forward_graph(def, input_descs);
    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
    return outputs;
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto subgraph = def.trait()->make_forward_graph(def, inputs);
    bool all_validated = true;
    auto apply_functor = [&](const std::shared_ptr<OpDef>& op,
                             const SmallVector<LogicalTensorDesc>& inputs,
                             size_t nr_outputs) {
        auto [outputs, validated] = OpDef::infer_output_attrs_fallible(*op, inputs);
        all_validated = all_validated && validated;
        return outputs;
    };
    auto const_functor = [&](const TensorPtr& value) {
        return LogicalTensorDesc{value->layout(), value->comp_node(),
                                 value->get_value().proxy_to_default_cpu()};
    };
    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
    return {outputs, all_validated};
}

SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef& def,
                                                SmallVector<TensorPtr> inputs) {
    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& input : inputs) {
        input_descs.push_back({input->layout(), input->comp_node()});
    }
    auto subgraph = def.trait()->make_forward_graph(def, input_descs);
    auto apply_functor = [](const std::shared_ptr<OpDef>& op,
                            const SmallVector<TensorPtr>& inputs,
                            size_t nr_outputs) {
        return OpDef::apply_on_physical_tensor(*op, inputs);
    };
    auto const_functor = [&](const TensorPtr& value) { return value; };
    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
    return outputs;
}

static EncodedSubraph make_backward_graph_from_forward(
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad, EncodedSubraph forward_graph) {
    using namespace std::placeholders;
    using var_t = Subgraph::var_t;
    using vars_t = Subgraph::vars_t;
    Subgraph::Builder<LogicalTensorDesc> builder(
            [](auto&& op, auto&& input_descs, size_t nr_outputs) {
                auto [descs, _] =
                        OpDef::infer_output_attrs_fallible(*op, input_descs);
                return descs;
            });
    auto accum_grad = [&](var_t lhs, var_t rhs) {
        return builder.write_expr(Elemwise::make(Elemwise::Mode::ADD),
                                  {lhs, rhs}, 1)[0];
    };
    GradContext<var_t> grad_context{accum_grad};
    auto input_vars = builder.write_inputs(inputs);
    auto outputs = forward_graph.apply(
            input_vars,
            std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3),
            [&](TensorPtr constant) {
                return builder.write_constant(
                        constant, {constant->layout(), constant->comp_node()});
            });
    size_t nr_outputs = outputs.size();
    auto apply_mask = [](auto&& values, SmallVector<bool> mask) {
        mgb_assert(mask.size() == values.size(), "");
        std::decay_t<decltype(values)> results;
        for (size_t i = 0; i < mask.size(); ++i) {
            if (mask[i]) {
                results.push_back(values[i]);
            }
        }
        return results;
    };
    grad_context.mark_require_grads(apply_mask(input_vars, input_requires_grad));
    builder.iterate([&](std::list<Subgraph::expr_t>::iterator iter) {
        grad_context.record_expr(iter->op, iter->inputs, iter->outputs);
    });
    auto output_descs = builder.get_descs(outputs);
    auto computed_outputs = builder.write_inputs(output_descs);
    auto output_grads = builder.write_inputs(output_descs);

    grad_context.backward(
            apply_mask(outputs, output_has_grad),
            apply_mask(output_grads, output_has_grad),
            [&](Subgraph::expr_t expr, vars_t output_grads) {
                auto bg = OpDef::make_backward_graph(
                        *expr.op, builder.get_descs(expr.inputs),
                        grad_context.get_require_grads(expr.inputs),
                        grad_context.get_has_grads(expr.outputs));
                if (bg.graph.empty()) {
                    return vars_t(expr.inputs.size(), 0);
                }
                vars_t grad_inputs;
                grad_inputs.insert(grad_inputs.end(), expr.inputs.begin(),
                                   expr.inputs.end());
                grad_inputs.insert(grad_inputs.end(), expr.outputs.begin(),
                                   expr.outputs.end());
                grad_inputs.insert(grad_inputs.end(), output_grads.begin(),
                                   output_grads.end());
                auto apply_functor = std::bind(&decltype(builder)::write_expr,
                                               &builder, _1, _2, _3);
                auto const_functor = [&](TensorPtr constant) {
                    return builder.write_constant(
                            constant,
                            {constant->layout(), constant->comp_node()});
                };
                return bg.apply(grad_inputs, apply_functor, const_functor);
            });
    builder.add_outputs(grad_context.get_grads(input_vars));
    for (size_t i = 0; i < nr_outputs; ++i) {
        builder.replace_var(outputs[i], computed_outputs[i]);
    }
    auto backward_graph = builder.encode();
    return backward_graph;
}

EncodedSubraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    auto forward_graph = OpDef::make_forward_graph(def, inputs);
    return make_backward_graph_from_forward(inputs, input_requires_grad,
                                            output_has_grad, forward_graph);
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    return {{}, {}};
}

}  // namespace subgraph_detail
}  // namespace imperative
}  // namespace mgb
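The three subgraph fallbacks above share one contract: describe the inputs, obtain the forward subgraph from the trait, then evaluate it by supplying an apply_functor (what "applying an inner op" means in the target domain) and a const_functor (how to materialize captured constants). A toy interpreter illustrating that contract; Expr and ToySubgraph are simplified stand-ins for the MegEngine Subgraph types, and const_functor is elided for brevity:

#include <vector>

struct Expr {
    int op;                   // opaque inner-op id
    std::vector<int> inputs;  // indices into the value table
};

struct ToySubgraph {
    std::vector<Expr> exprs;
    std::vector<int> outputs;  // indices of the graph outputs

    // Values 0..n-1 are the graph inputs; each expr appends one result
    // produced by apply_functor, which decides what applying an op means.
    template <typename T, typename ApplyFn>
    std::vector<T> apply(std::vector<T> values, ApplyFn&& apply_functor) const {
        for (auto&& expr : exprs) {
            std::vector<T> args;
            for (int idx : expr.inputs) args.push_back(values[idx]);
            values.push_back(apply_functor(expr.op, args));
        }
        std::vector<T> results;
        for (int idx : outputs) results.push_back(values[idx]);
        return results;
    }
};

In the commit, apply_on_var_node instantiates this contract with VarNode* values, infer_output_attrs_fallible with LogicalTensorDesc, and apply_on_physical_tensor with TensorPtr.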
imperative/src/include/megbrain/imperative/op_def.h

@@ -13,6 +13,7 @@
 #include "megbrain/graph.h"
 #include "megbrain/imperative/physical_tensor.h"
+#include "megbrain/imperative/subgraph.h"
 #include "megbrain/imperative/utils/to_string.h"

@@ -94,6 +95,10 @@ public:
     static std::vector<std::pair<const char*, std::string>> props(
             const OpDef& def);
 
+    static EncodedSubraph make_forward_graph(
+            const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
+
     const OpTrait* trait() const;
 
     std::string to_string() const;
imperative/src/include/megbrain/imperative/subgraph_detail.h (new file, 0 → 100644)

/**
 * \file imperative/src/include/megbrain/imperative/subgraph_detail.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/imperative/op_def.h"

namespace mgb {
namespace imperative {
namespace subgraph_detail {

SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef& def,
                                                SmallVector<TensorPtr> inputs);

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);

EncodedSubraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad);

cg::VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs);

EncodedSubraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad);

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems);

}  // namespace subgraph_detail
}  // namespace imperative
}  // namespace mgb
\ No newline at end of file