Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2a063f8e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2a063f8e
编写于
8月 09, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(subgraph): fix scope mismatch of subgraph content
GitOrigin-RevId: 6e23456250aa70c4cbdd71ecd9cfa6c19270a316
上级
3206af9d
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
92 addition
and
43 deletion
+92
-43
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+14
-14
imperative/python/megengine/functional/math.py
imperative/python/megengine/functional/math.py
+19
-6
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+15
-6
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+9
-4
imperative/src/impl/ops/utility.cpp
imperative/src/impl/ops/utility.cpp
+18
-9
imperative/src/impl/subgraph_detail.cpp
imperative/src/impl/subgraph_detail.cpp
+2
-1
imperative/src/include/megbrain/imperative/ops/utility.h
imperative/src/include/megbrain/imperative/ops/utility.h
+15
-3
未找到文件。
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
2a063f8e
...
@@ -227,19 +227,19 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
...
@@ -227,19 +227,19 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
gopt_level
=
None
# disable jit and compile
gopt_level
=
None
# disable jit and compile
binary_ops
=
{
binary_ops
=
{
"+"
:
builtin
.
Elemwise
(
mode
=
"add"
),
"+"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"add"
),
"-"
:
builtin
.
Elemwise
(
mode
=
"sub"
),
"-"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"sub"
),
"*"
:
builtin
.
Elemwise
(
mode
=
"mul"
),
"*"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"mul"
),
"/"
:
builtin
.
Elemwise
(
mode
=
"true_div"
),
"/"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"true_div"
),
"//"
:
builtin
.
Elemwise
(
mode
=
"floor_div"
),
"//"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"floor_div"
),
"**"
:
builtin
.
Elemwise
(
mode
=
"pow"
),
"**"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"pow"
),
"√"
:
builtin
.
Elemwise
(
mode
=
"expm1"
),
"√"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"expm1"
),
"max"
:
builtin
.
Elemwise
(
mode
=
"max"
),
"max"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"max"
),
"additive"
:
builtin
.
Elemwise
(
mode
=
"add"
),
"additive"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"add"
),
}
}
unary_ops
=
{
unary_ops
=
{
"-"
:
builtin
.
Elemwise
(
mode
=
"negate"
),
"-"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"negate"
),
}
}
def
decorator
(
func
):
def
decorator
(
func
):
...
@@ -248,9 +248,9 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
...
@@ -248,9 +248,9 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
def
apply_expr
(
op
,
*
args
):
def
apply_expr
(
op
,
*
args
):
if
isinstance
(
op
,
str
):
if
isinstance
(
op
,
str
):
if
len
(
args
)
==
2
:
if
len
(
args
)
==
2
:
op
=
binary_ops
[
op
]
op
=
binary_ops
[
op
]
()
elif
len
(
args
)
==
1
:
elif
len
(
args
)
==
1
:
op
=
unary_ops
[
op
]
op
=
unary_ops
[
op
]
()
return
builder
.
apply
(
op
,
args
,
1
)[
0
]
return
builder
.
apply
(
op
,
args
,
1
)[
0
]
def
apply_const
(
value
,
dtype
=
dtype
,
device
=
device
):
def
apply_const
(
value
,
dtype
=
dtype
,
device
=
device
):
...
@@ -261,8 +261,8 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
...
@@ -261,8 +261,8 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
builder
.
outputs
(
outputs
)
builder
.
outputs
(
outputs
)
builder
.
outputs_has_grad
(
outputs_has_grad
)
builder
.
outputs_has_grad
(
outputs_has_grad
)
if
gopt_level
is
None
:
if
gopt_level
is
None
:
return
builder
.
get
()
return
lambda
:
builder
.
get
()
else
:
else
:
return
builder
.
compile
(
gopt_level
)
return
lambda
:
builder
.
compile
(
gopt_level
)
return
decorator
return
decorator
imperative/python/megengine/functional/math.py
浏览文件 @
2a063f8e
...
@@ -767,6 +767,19 @@ def matinv(inp: Tensor) -> Tensor:
...
@@ -767,6 +767,19 @@ def matinv(inp: Tensor) -> Tensor:
return
result
return
result
class
_Hashable
:
def
__init__
(
self
,
value
)
->
None
:
self
.
value
=
value
def
__hash__
(
self
)
->
int
:
return
hash
(
str
(
self
.
value
))
def
__eq__
(
self
,
o
:
object
)
->
bool
:
if
not
isinstance
(
o
,
_Hashable
):
return
False
return
self
.
value
==
o
.
value
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
_get_extentedMatrixMulOp
(
def
_get_extentedMatrixMulOp
(
device
,
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
,
device
,
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
,
...
@@ -833,7 +846,7 @@ def _get_extentedMatrixMulOp(
...
@@ -833,7 +846,7 @@ def _get_extentedMatrixMulOp(
transposeB
=
transpose_b
,
transposeB
=
transpose_b
,
compute_mode
=
compute_mode
,
compute_mode
=
compute_mode
,
format
=
format
,
format
=
format
,
strategy
=
strategy
,
strategy
=
strategy
.
value
,
)
)
result
=
f
(
op
,
inp1
,
inp2
)
result
=
f
(
op
,
inp1
,
inp2
)
result_shape
=
f
(
GetVarShape
(),
result
)
result_shape
=
f
(
GetVarShape
(),
result
)
...
@@ -954,7 +967,7 @@ def _get_extentedBatchedMatrixMulOp(
...
@@ -954,7 +967,7 @@ def _get_extentedBatchedMatrixMulOp(
transposeB
=
transpose_b
,
transposeB
=
transpose_b
,
compute_mode
=
compute_mode
,
compute_mode
=
compute_mode
,
format
=
format
,
format
=
format
,
strategy
=
strategy
,
strategy
=
strategy
.
value
,
)
)
result
=
f
(
op
,
inp1
,
inp2
)
result
=
f
(
op
,
inp1
,
inp2
)
...
@@ -1051,9 +1064,9 @@ def matmul(
...
@@ -1051,9 +1064,9 @@ def matmul(
transpose_b
,
transpose_b
,
compute_mode
,
compute_mode
,
format
,
format
,
strategy
=
get_execution_strategy
(
),
strategy
=
_Hashable
(
get_execution_strategy
()
),
)
)
(
result
,)
=
apply
(
extentedMatrixMulOp
,
inp1
,
inp2
)
(
result
,)
=
apply
(
extentedMatrixMulOp
()
,
inp1
,
inp2
)
return
result
return
result
else
:
# dispath to BatchedMatrixMul
else
:
# dispath to BatchedMatrixMul
extentedBatchedMatrixMulOp
=
_get_extentedBatchedMatrixMulOp
(
extentedBatchedMatrixMulOp
=
_get_extentedBatchedMatrixMulOp
(
...
@@ -1065,9 +1078,9 @@ def matmul(
...
@@ -1065,9 +1078,9 @@ def matmul(
transpose_b
,
transpose_b
,
compute_mode
,
compute_mode
,
format
,
format
,
strategy
=
get_execution_strategy
(
),
strategy
=
_Hashable
(
get_execution_strategy
()
),
)
)
(
result
,)
=
apply
(
extentedBatchedMatrixMulOp
,
inp1
,
inp2
)
(
result
,)
=
apply
(
extentedBatchedMatrixMulOp
()
,
inp1
,
inp2
)
return
result
return
result
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
2a063f8e
...
@@ -1328,7 +1328,7 @@ def sync_batch_norm(
...
@@ -1328,7 +1328,7 @@ def sync_batch_norm(
syncbn_split_stats
,
syncbn_split_stats
,
)
=
_get_sync_bn_ops
(
_device
,
_dtype
,
eps_mode
,
_ndim
,
_channels
)
)
=
_get_sync_bn_ops
(
_device
,
_dtype
,
eps_mode
,
_ndim
,
_channels
)
reduce_shape
,
reduce_size
,
channel_x1s
,
channel_x2s
=
apply
(
syncbn_stage0
,
inp
)
reduce_shape
,
reduce_size
,
channel_x1s
,
channel_x2s
=
apply
(
syncbn_stage0
()
,
inp
)
eps
=
convert_single_value
(
eps
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
eps
=
convert_single_value
(
eps
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
...
@@ -1338,19 +1338,28 @@ def sync_batch_norm(
...
@@ -1338,19 +1338,28 @@ def sync_batch_norm(
if
training
:
if
training
:
if
is_distributed
():
if
is_distributed
():
# reduce all nodes' data to calculate mean and variance
# reduce all nodes' data to calculate mean and variance
(
stat
,)
=
apply
(
syncbn_concat_stats
,
reduce_size
,
channel_x1s
,
channel_x2s
)
(
stat
,)
=
apply
(
syncbn_concat_stats
(),
reduce_size
,
channel_x1s
,
channel_x2s
)
stat
=
all_reduce_sum
(
stat
,
group
)
stat
=
all_reduce_sum
(
stat
,
group
)
reduce_size
,
channel_x1s
,
channel_x2s
=
apply
(
syncbn_split_stats
,
stat
)
reduce_size
,
channel_x1s
,
channel_x2s
=
apply
(
syncbn_split_stats
()
,
stat
)
outvar
,
channel_mean
,
*
_
=
apply
(
outvar
,
channel_mean
,
*
_
=
apply
(
syncbn_stage1
,
inp
,
reduce_size
,
channel_x1s
,
channel_x2s
,
eps
,
weight
,
bias
syncbn_stage1
(),
inp
,
reduce_size
,
channel_x1s
,
channel_x2s
,
eps
,
weight
,
bias
,
)
)
else
:
else
:
assert
running_var
is
not
None
and
running_mean
is
not
None
assert
running_var
is
not
None
and
running_mean
is
not
None
channel_mean
=
running_mean
channel_mean
=
running_mean
channel_var
=
running_var
channel_var
=
running_var
outvar
,
*
_
=
apply
(
outvar
,
*
_
=
apply
(
syncbn_stage1_inference
,
inp
,
channel_mean
,
channel_var
,
eps
,
weight
,
bias
syncbn_stage1_inference
()
,
inp
,
channel_mean
,
channel_var
,
eps
,
weight
,
bias
)
)
# outvar = output * weight + bias
# outvar = output * weight + bias
...
@@ -1362,7 +1371,7 @@ def sync_batch_norm(
...
@@ -1362,7 +1371,7 @@ def sync_batch_norm(
if
training
and
running_var
is
not
None
and
running_mean
is
not
None
:
if
training
and
running_var
is
not
None
and
running_mean
is
not
None
:
momentum
=
convert_single_value
(
momentum
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
momentum
=
convert_single_value
(
momentum
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
running_mean
[...],
running_var
[...]
=
apply
(
running_mean
[...],
running_var
[...]
=
apply
(
syncbn_stage2
,
syncbn_stage2
()
,
running_mean
,
running_mean
,
running_var
,
running_var
,
momentum
,
momentum
,
...
...
imperative/python/src/ops.cpp
浏览文件 @
2a063f8e
...
@@ -482,9 +482,15 @@ void init_ops(py::module m) {
...
@@ -482,9 +482,15 @@ void init_ops(py::module m) {
struct
PySubgraphBuilder
{
struct
PySubgraphBuilder
{
explicit
PySubgraphBuilder
(
std
::
string
name
)
:
name
{
name
}{}
explicit
PySubgraphBuilder
(
std
::
string
name
)
:
name
{
name
}{}
std
::
string
name
;
std
::
string
name
;
Subgraph
graph
;
std
::
shared_ptr
<
Subgraph
>
graph_storage
=
std
::
make_shared
<
Subgraph
>
();
std
::
shared_ptr
<
UniqueKey
>
graph_key
=
std
::
make_shared
<
UniqueKey
>
();
Subgraph
&
graph
=
*
graph_storage
;
mgb
::
SmallVector
<
bool
>
output_grad_mask
;
mgb
::
SmallVector
<
bool
>
output_grad_mask
;
Subgraph
::
var_t
next_var
=
1
;
Subgraph
::
var_t
next_var
=
1
;
std
::
shared_ptr
<
OpDef
>
build
()
const
{
return
SubgraphOp
::
make
(
name
,
graph_storage
,
output_grad_mask
,
graph_key
);
}
};
};
py
::
class_
<
PySubgraphBuilder
>
(
m
,
"SubgraphBuilder"
)
py
::
class_
<
PySubgraphBuilder
>
(
m
,
"SubgraphBuilder"
)
...
@@ -518,10 +524,9 @@ void init_ops(py::module m) {
...
@@ -518,10 +524,9 @@ void init_ops(py::module m) {
self
.
output_grad_mask
=
outputs_has_grad
;
self
.
output_grad_mask
=
outputs_has_grad
;
})
})
.
def
(
"get"
,
[](
PySubgraphBuilder
&
self
){
.
def
(
"get"
,
[](
PySubgraphBuilder
&
self
){
return
(
std
::
shared_ptr
<
OpDef
>
)
SubgraphOp
::
make
(
self
.
name
,
self
.
graph
,
self
.
output_grad_mask
);
return
(
std
::
shared_ptr
<
OpDef
>
)
self
.
build
(
);
})
})
.
def
(
"compile"
,
[](
PySubgraphBuilder
&
self
,
int
gopt_level
){
.
def
(
"compile"
,
[](
PySubgraphBuilder
&
self
,
int
gopt_level
){
auto
op
=
SubgraphOp
::
make
(
self
.
name
,
self
.
graph
,
self
.
output_grad_mask
);
return
(
std
::
shared_ptr
<
OpDef
>
)
CompiledOp
::
make
(
self
.
build
(),
gopt_level
);
return
(
std
::
shared_ptr
<
OpDef
>
)
CompiledOp
::
make
(
op
,
gopt_level
);
});
});
}
}
imperative/src/impl/ops/utility.cpp
浏览文件 @
2a063f8e
...
@@ -181,7 +181,7 @@ OP_TRAIT_REG(Identity, Identity)
...
@@ -181,7 +181,7 @@ OP_TRAIT_REG(Identity, Identity)
namespace
{
namespace
subgraph
{
namespace
{
namespace
subgraph
{
EncodedSubraph
make_forward_graph
(
const
OpDef
&
def
,
SmallVector
<
LogicalTensorDesc
>
inputs
)
{
EncodedSubraph
make_forward_graph
(
const
OpDef
&
def
,
SmallVector
<
LogicalTensorDesc
>
inputs
)
{
return
EncodedSubraph
::
make
(
def
.
cast_final_safe
<
SubgraphOp
>
().
graph
);
return
EncodedSubraph
::
make
(
*
def
.
cast_final_safe
<
SubgraphOp
>
().
graph
);
}
}
EncodedSubraph
make_backward_graph
(
EncodedSubraph
make_backward_graph
(
...
@@ -197,16 +197,19 @@ EncodedSubraph make_backward_graph(
...
@@ -197,16 +197,19 @@ EncodedSubraph make_backward_graph(
}
}
}
}
auto
bgraph
=
subgraph_detail
::
make_backward_graph
(
def
,
inputs
,
input_requires_grad
,
output_has_grad
);
auto
bgraph
=
subgraph_detail
::
make_backward_graph
(
def
,
inputs
,
input_requires_grad
,
output_has_grad
);
return
EncodedSubraph
::
make_single
(
SubgraphOp
::
make
(
op
.
name
+
"Grad"
,
bgraph
.
graph
),
bgraph
.
input_mask
,
bgraph
.
output_mask
);
return
EncodedSubraph
::
make_single
(
SubgraphOp
::
make
(
op
.
name
+
"Grad"
,
std
::
make_shared
<
Subgraph
>
(
bgraph
.
graph
)),
bgraph
.
input_mask
,
bgraph
.
output_mask
);
}
}
std
::
vector
<
std
::
pair
<
const
char
*
,
std
::
string
>>
props
(
const
OpDef
&
def
)
{
std
::
vector
<
std
::
pair
<
const
char
*
,
std
::
string
>>
props
(
const
OpDef
&
def
)
{
auto
&
op
=
def
.
cast_final_safe
<
SubgraphOp
>
();
auto
&
op
=
def
.
cast_final_safe
<
SubgraphOp
>
();
return
{
return
{
{
"name"
,
op
.
name
},
{
"name"
,
op
.
name
},
{
"inputs"
,
mgb
::
imperative
::
to_string
(
op
.
graph
.
inputs
)},
{
"inputs"
,
mgb
::
imperative
::
to_string
(
op
.
graph
->
inputs
)},
{
"exprs"
,
mgb
::
imperative
::
to_string
(
op
.
graph
.
exprs
)},
{
"exprs"
,
mgb
::
imperative
::
to_string
(
op
.
graph
->
exprs
)},
{
"outputs"
,
mgb
::
imperative
::
to_string
(
op
.
graph
.
outputs
)},
{
"outputs"
,
mgb
::
imperative
::
to_string
(
op
.
graph
->
outputs
)},
};
};
}
}
...
@@ -222,7 +225,7 @@ std::string make_name(const OpDef& def) {
...
@@ -222,7 +225,7 @@ std::string make_name(const OpDef& def) {
auto
hash
(
const
OpDef
&
def
)
{
auto
hash
(
const
OpDef
&
def
)
{
auto
&
op
=
def
.
cast_final_safe
<
SubgraphOp
>
();
auto
&
op
=
def
.
cast_final_safe
<
SubgraphOp
>
();
if
(
!
op
.
graph_key
)
{
if
(
!
op
.
graph_key
)
{
return
(
size_t
)
reinterpret_cast
<
uintptr_t
>
(
&
op
.
graph
);
return
(
size_t
)
reinterpret_cast
<
uintptr_t
>
(
op
.
graph
.
get
()
);
}
}
return
op
.
graph_key
->
hash
();
return
op
.
graph_key
->
hash
();
}
}
...
@@ -238,7 +241,7 @@ auto is_same_st(const OpDef& def, const OpDef& another) {
...
@@ -238,7 +241,7 @@ auto is_same_st(const OpDef& def, const OpDef& another) {
if
(
has_graph_key
)
{
if
(
has_graph_key
)
{
graph_same
=
rhs
.
graph_key
&&
lhs
.
graph_key
->
is_same
(
*
rhs
.
graph_key
);
graph_same
=
rhs
.
graph_key
&&
lhs
.
graph_key
->
is_same
(
*
rhs
.
graph_key
);
}
else
{
}
else
{
graph_same
=
!
rhs
.
graph_key
&&
&
lhs
.
graph
==
&
rhs
.
graph
;
graph_same
=
!
rhs
.
graph_key
&&
lhs
.
graph
.
get
()
==
rhs
.
graph
.
get
()
;
}
}
return
graph_same
;
return
graph_same
;
}
}
...
@@ -354,7 +357,9 @@ auto apply_on_physical_tensor(
...
@@ -354,7 +357,9 @@ auto apply_on_physical_tensor(
auto
apply_on_var_node
(
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
const
VarNodeArray
&
inputs
)
{
return
OpDef
::
apply_on_var_node
(
*
def
.
cast_final_safe
<
CompiledOp
>
().
op
,
inputs
);
auto
&
op
=
def
.
cast_final_safe
<
CompiledOp
>
();
op
.
op
->
set_scope
(
op
.
scope
());
return
OpDef
::
apply_on_var_node
(
*
op
.
op
,
inputs
);
}
}
auto
infer_output_attrs_fallible
(
auto
infer_output_attrs_fallible
(
...
@@ -397,7 +402,9 @@ EncodedSubraph make_backward_graph(
...
@@ -397,7 +402,9 @@ EncodedSubraph make_backward_graph(
if
(
backward_graph
.
graph
.
is_single
())
{
if
(
backward_graph
.
graph
.
is_single
())
{
bgraph_op
=
backward_graph
.
graph
.
as_single
();
bgraph_op
=
backward_graph
.
graph
.
as_single
();
}
else
{
}
else
{
bgraph_op
=
SubgraphOp
::
make
(
name
+
"Grad"
,
backward_graph
.
graph
,
grad_outputs_has_grad
,
key
);
bgraph_op
=
SubgraphOp
::
make
(
name
+
"Grad"
,
std
::
make_shared
<
Subgraph
>
(
backward_graph
.
graph
),
grad_outputs_has_grad
,
key
);
}
}
auto
compiled_op
=
CompiledOp
::
make
(
bgraph_op
,
op
.
gopt_level
);
auto
compiled_op
=
CompiledOp
::
make
(
bgraph_op
,
op
.
gopt_level
);
auto
encoded_graph
=
EncodedSubraph
::
make_single
(
compiled_op
,
backward_graph
.
input_mask
,
backward_graph
.
output_mask
);
auto
encoded_graph
=
EncodedSubraph
::
make_single
(
compiled_op
,
backward_graph
.
input_mask
,
backward_graph
.
output_mask
);
...
@@ -431,6 +438,8 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
...
@@ -431,6 +438,8 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
.
fallback
();
.
fallback
();
}}
}}
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
UniqueKey
);
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
SubgraphOp
);
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
SubgraphOp
);
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
BackwardOpKey
);
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
BackwardOpKey
);
...
...
imperative/src/impl/subgraph_detail.cpp
浏览文件 @
2a063f8e
...
@@ -28,7 +28,8 @@ VarNodeArray apply_on_var_node(
...
@@ -28,7 +28,8 @@ VarNodeArray apply_on_var_node(
for
(
auto
&&
input
:
inputs
)
{
for
(
auto
&&
input
:
inputs
)
{
input_descs
.
push_back
({
TensorLayout
{
input
->
dtype
()},
input
->
comp_node
()});
input_descs
.
push_back
({
TensorLayout
{
input
->
dtype
()},
input
->
comp_node
()});
}
}
auto
apply_functor
=
[](
const
std
::
shared_ptr
<
OpDef
>&
op
,
const
VarNodeArray
&
inputs
,
size_t
nr_outputs
){
auto
apply_functor
=
[
&
](
const
std
::
shared_ptr
<
OpDef
>&
op
,
const
VarNodeArray
&
inputs
,
size_t
nr_outputs
){
op
->
set_scope
(
def
.
scope
());
return
OpDef
::
apply_on_var_node
(
*
op
,
inputs
);
return
OpDef
::
apply_on_var_node
(
*
op
,
inputs
);
};
};
auto
const_functor
=
[
&
](
const
TensorPtr
&
value
)
{
auto
const_functor
=
[
&
](
const
TensorPtr
&
value
)
{
...
...
imperative/src/include/megbrain/imperative/ops/utility.h
浏览文件 @
2a063f8e
...
@@ -48,16 +48,28 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
...
@@ -48,16 +48,28 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
};
};
struct
UniqueKey
final
:
Hashable
{
public:
size_t
hash
()
const
override
{
return
reinterpret_cast
<
uintptr_t
>
(
this
);
}
protected:
bool
is_same_st
(
const
Hashable
&
rhs
)
const
override
{
return
this
==
&
rhs
.
cast_final_safe
<
UniqueKey
>
();
}
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
};
struct
SubgraphOp
final
:
OpDefImplBase
<
SubgraphOp
>
{
struct
SubgraphOp
final
:
OpDefImplBase
<
SubgraphOp
>
{
std
::
string
name
;
std
::
string
name
;
Subgraph
graph
;
std
::
shared_ptr
<
Subgraph
>
graph
;
SmallVector
<
bool
>
output_grad_mask
;
SmallVector
<
bool
>
output_grad_mask
;
std
::
shared_ptr
<
Hashable
>
graph_key
;
std
::
shared_ptr
<
Hashable
>
graph_key
;
SubgraphOp
()
=
default
;
SubgraphOp
()
=
default
;
SubgraphOp
(
std
::
string
name
,
Subgraph
graph
,
SmallVector
<
bool
>
output_grad_mask
=
{},
std
::
shared_ptr
<
Hashable
>
key
=
nullptr
)
SubgraphOp
(
std
::
string
name
,
std
::
shared_ptr
<
Subgraph
>
graph
,
SmallVector
<
bool
>
output_grad_mask
=
{},
std
::
shared_ptr
<
Hashable
>
key
=
nullptr
)
:
name
{
name
},
graph
{
graph
},
output_grad_mask
{
output_grad_mask
},
graph_key
{
std
::
move
(
key
)}{
:
name
{
name
},
graph
{
graph
},
output_grad_mask
{
output_grad_mask
},
graph_key
{
std
::
move
(
key
)}{
if
(
this
->
output_grad_mask
.
empty
())
{
if
(
this
->
output_grad_mask
.
empty
())
{
this
->
output_grad_mask
.
resize
(
graph
.
outputs
.
size
(),
true
);
this
->
output_grad_mask
.
resize
(
graph
->
outputs
.
size
(),
true
);
}
}
}
}
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录