Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
634de590
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
634de590
编写于
11月 06, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/imperative): add valid flag of `infer_output_attrs_fallible`
GitOrigin-RevId: b2b32774eeb893503c25d3434fa6f2ba64f1c8c6
上级
50c4daac
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
81 addition
and
62 deletion
+81
-62
imperative/src/impl/interpreter_impl.cpp
imperative/src/impl/interpreter_impl.cpp
+15
-4
imperative/src/impl/op_def.cpp
imperative/src/impl/op_def.cpp
+1
-1
imperative/src/impl/ops/backward_graph.cpp
imperative/src/impl/ops/backward_graph.cpp
+20
-16
imperative/src/impl/ops/batch_norm.cpp
imperative/src/impl/ops/batch_norm.cpp
+2
-2
imperative/src/impl/ops/broadcast.cpp
imperative/src/impl/ops/broadcast.cpp
+3
-3
imperative/src/impl/ops/cond_take.cpp
imperative/src/impl/ops/cond_take.cpp
+5
-5
imperative/src/impl/ops/elemwise.cpp
imperative/src/impl/ops/elemwise.cpp
+4
-4
imperative/src/impl/ops/tensor_manip.cpp
imperative/src/impl/ops/tensor_manip.cpp
+3
-3
imperative/src/impl/profiler.cpp
imperative/src/impl/profiler.cpp
+4
-3
imperative/src/impl/proxy_graph.cpp
imperative/src/impl/proxy_graph.cpp
+15
-10
imperative/src/impl/proxy_graph.h
imperative/src/impl/proxy_graph.h
+2
-2
imperative/src/impl/proxy_graph_detail.cpp
imperative/src/impl/proxy_graph_detail.cpp
+2
-3
imperative/src/impl/proxy_graph_detail.h
imperative/src/impl/proxy_graph_detail.h
+2
-3
imperative/src/include/megbrain/imperative/op_def.h
imperative/src/include/megbrain/imperative/op_def.h
+1
-1
imperative/src/include/megbrain/imperative/ops/backward_graph.h
...tive/src/include/megbrain/imperative/ops/backward_graph.h
+2
-2
未找到文件。
imperative/src/impl/interpreter_impl.cpp
浏览文件 @
634de590
...
@@ -63,15 +63,17 @@ SmallVector<void*> ChannelImpl::apply_op(
...
@@ -63,15 +63,17 @@ SmallVector<void*> ChannelImpl::apply_op(
input_infos
.
push_back
(
info
);
input_infos
.
push_back
(
info
);
input_descs
.
push_back
(
info
->
desc
);
input_descs
.
push_back
(
info
->
desc
);
}
}
auto
output_descs
=
OpDef
::
infer_output_attrs_fallible
(
*
op
,
input_descs
);
auto
[
output_descs
,
validated
]
=
OpDef
::
infer_output_attrs_fallible
(
*
op
,
input_descs
);
ApplyOp
cmd
{
std
::
move
(
op
)};
ApplyOp
cmd
{
std
::
move
(
op
)};
cmd
.
inputs
=
std
::
move
(
input_infos
);
cmd
.
inputs
=
std
::
move
(
input_infos
);
cmd
.
outputs
.
reserve
(
output_descs
.
size
());
cmd
.
outputs
.
reserve
(
output_descs
.
size
());
SmallVector
<
void
*>
outputs
;
SmallVector
<
void
*>
outputs
;
bool
is_fallible
=
false
;
// FIXME: remove this check when op check is correct
bool
validated_bkp
=
true
;
for
(
auto
&&
desc
:
output_descs
)
{
for
(
auto
&&
desc
:
output_descs
)
{
if
(
desc
.
layout
.
ndim
==
0
)
{
if
(
desc
.
layout
.
ndim
==
0
)
{
is_fallible
=
tru
e
;
validated_bkp
=
fals
e
;
}
}
auto
info
=
alloc
();
auto
info
=
alloc
();
info
->
desc
=
desc
;
info
->
desc
=
desc
;
...
@@ -80,8 +82,14 @@ SmallVector<void*> ChannelImpl::apply_op(
...
@@ -80,8 +82,14 @@ SmallVector<void*> ChannelImpl::apply_op(
outputs
.
push_back
(
info
);
outputs
.
push_back
(
info
);
}
}
m_worker
.
add_task
(
std
::
move
(
cmd
));
m_worker
.
add_task
(
std
::
move
(
cmd
));
if
(
is_fallible
&&
m_async_level
<=
1
)
{
if
(
!
(
validated
&&
validated_bkp
)
&&
m_async_level
==
1
)
{
sync
();
}
else
if
(
m_async_level
==
0
)
{
sync
();
sync
();
// check device error
for
(
auto
&&
oup
:
cmd
.
outputs
)
{
oup
->
ptr
->
comp_node
().
sync
();
}
}
}
return
outputs
;
return
outputs
;
}
}
...
@@ -194,6 +202,9 @@ ChannelImpl::~ChannelImpl() {
...
@@ -194,6 +202,9 @@ ChannelImpl::~ChannelImpl() {
void
ChannelImpl
::
produce_tensor
(
TensorInfo
*
dest
,
TensorPtr
ptr
)
{
void
ChannelImpl
::
produce_tensor
(
TensorInfo
*
dest
,
TensorPtr
ptr
)
{
MGB_LOCK_GUARD
(
m_mutex
);
MGB_LOCK_GUARD
(
m_mutex
);
dest
->
value_fetched
=
ptr
->
value_fetched
();
dest
->
value_fetched
=
ptr
->
value_fetched
();
// update tensor desc for static infer
dest
->
desc
.
layout
=
ptr
->
layout
();
dest
->
desc
.
comp_node
=
ptr
->
comp_node
();
dest
->
ptr
=
std
::
move
(
ptr
);
dest
->
ptr
=
std
::
move
(
ptr
);
if
(
m_waitee
==
dest
)
{
if
(
m_waitee
==
dest
)
{
m_cv
.
notify_all
();
m_cv
.
notify_all
();
...
...
imperative/src/impl/op_def.cpp
浏览文件 @
634de590
...
@@ -42,7 +42,7 @@ cg::OperatorNodeBase* OpDef::apply_on_var_node(
...
@@ -42,7 +42,7 @@ cg::OperatorNodeBase* OpDef::apply_on_var_node(
return
def
.
trait
()
->
apply_on_var_node
(
def
,
inputs
);
return
def
.
trait
()
->
apply_on_var_node
(
def
,
inputs
);
}
}
SmallVector
<
LogicalTensorDesc
>
OpDef
::
infer_output_attrs_fallible
(
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
OpDef
::
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
return
def
.
trait
()
->
infer_output_attrs_fallible
(
def
,
inputs
);
return
def
.
trait
()
->
infer_output_attrs_fallible
(
def
,
inputs
);
...
...
imperative/src/impl/ops/backward_graph.cpp
浏览文件 @
634de590
...
@@ -24,12 +24,12 @@ BackwardGraph::InternalGraph::apply(
...
@@ -24,12 +24,12 @@ BackwardGraph::InternalGraph::apply(
inputs
);
inputs
);
}
}
SmallVector
<
LogicalTensorDesc
>
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
BackwardGraph
::
InternalGraph
::
infer_attrs
(
BackwardGraph
::
InternalGraph
::
infer_attrs
(
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
const
{
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
const
{
using
TensorAttr
=
LogicalTensorDesc
;
using
TensorAttr
=
LogicalTensorDesc
;
ThinHashMap
<
size_t
,
TensorAttr
>
node2attr
;
ThinHashMap
<
size_t
,
TensorAttr
>
node2attr
;
auto
&&
input_nodes
=
this
->
inputs
;
auto
&&
input_nodes
=
this
->
inputs
;
auto
&&
output_nodes
=
this
->
outputs
;
mgb_assert
(
inputs
.
size
()
==
input_nodes
.
size
());
mgb_assert
(
inputs
.
size
()
==
input_nodes
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
node2attr
[
input_nodes
[
i
]]
=
inputs
[
i
];
node2attr
[
input_nodes
[
i
]]
=
inputs
[
i
];
...
@@ -41,25 +41,29 @@ BackwardGraph::InternalGraph::infer_attrs(
...
@@ -41,25 +41,29 @@ BackwardGraph::InternalGraph::infer_attrs(
i
.
second
->
layout
(),
i
.
second
->
comp_node
(),
i
.
second
->
layout
(),
i
.
second
->
comp_node
(),
value
->
proxy_to_default_cpu
()};
value
->
proxy_to_default_cpu
()};
}
}
bool
validated
=
true
;
for
(
size_t
i
=
0
;
i
<
exprs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
exprs
.
size
();
++
i
)
{
auto
&&
expr
=
exprs
[
i
];
auto
&&
[
expr_op
,
expr_inps
,
expr_oups
]
=
exprs
[
i
];
SmallVector
<
TensorAttr
>
input
s
;
SmallVector
<
TensorAttr
>
expr_input_desc
s
;
for
(
auto
&&
in
:
std
::
get
<
1
>
(
expr
)
)
{
for
(
auto
&&
in
p
:
expr_inps
)
{
inputs
.
push_back
(
node2attr
.
at
(
in
));
expr_input_descs
.
push_back
(
node2attr
.
at
(
inp
));
}
}
auto
outputs
=
OpDef
::
infer_output_attrs_fallible
(
*
std
::
get
<
0
>
(
expr
),
inputs
);
auto
[
expr_output_descs
,
expr_validated
]
=
OpDef
::
infer_output_attrs_fallible
(
auto
output_nodes
=
std
::
get
<
2
>
(
expr
);
*
expr_op
,
expr_input_descs
);
mgb_assert
(
outputs
.
size
()
==
output_nodes
.
size
());
validated
=
validated
&&
expr_validated
;
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
node2attr
[
output_nodes
[
i
]]
=
outputs
[
i
];
mgb_assert
(
expr_output_descs
.
size
()
==
expr_oups
.
size
());
for
(
size_t
i
=
0
;
i
<
expr_output_descs
.
size
();
++
i
)
{
node2attr
[
expr_oups
[
i
]]
=
expr_output_descs
[
i
];
}
}
}
}
SmallVector
<
TensorAttr
>
ret
;
SmallVector
<
TensorAttr
>
ret
;
for
(
auto
&&
i
:
outputs
)
{
for
(
auto
&&
i
:
output
_node
s
)
{
ret
.
push_back
(
node2attr
.
at
(
i
));
ret
.
push_back
(
node2attr
.
at
(
i
));
}
}
return
ret
;
return
{
ret
,
validated
}
;
}
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
BackwardGraph
);
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
BackwardGraph
);
...
@@ -72,7 +76,7 @@ SmallVector<TensorPtr> backward_impl(
...
@@ -72,7 +76,7 @@ SmallVector<TensorPtr> backward_impl(
.
graph
().
apply
(
tensors
);
.
graph
().
apply
(
tensors
);
}
}
SmallVector
<
LogicalTensorDesc
>
infer_tensor_attrs
(
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_tensor_attrs
(
const
OpDef
&
backward_graph
,
const
OpDef
&
backward_graph
,
const
SmallVector
<
LogicalTensorDesc
>
inputs
)
{
const
SmallVector
<
LogicalTensorDesc
>
inputs
)
{
return
backward_graph
.
cast_final_safe
<
BackwardGraph
>
()
return
backward_graph
.
cast_final_safe
<
BackwardGraph
>
()
...
...
imperative/src/impl/ops/batch_norm.cpp
浏览文件 @
634de590
...
@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node(
...
@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node(
}
}
}
}
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs_fallible
(
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
&&
op_def
=
def
.
cast_final_safe
<
BatchNorm
>
();
auto
&&
op_def
=
def
.
cast_final_safe
<
BatchNorm
>
();
...
@@ -66,7 +66,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
...
@@ -66,7 +66,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
out_shapes
[
i
]
=
{
i1
.
layout
,
i1
.
comp_node
};
out_shapes
[
i
]
=
{
i1
.
layout
,
i1
.
comp_node
};
}
}
out_shapes
[
nr_out
-
1
]
=
{
i0
.
layout
,
i0
.
comp_node
};
out_shapes
[
nr_out
-
1
]
=
{
i0
.
layout
,
i0
.
comp_node
};
return
out_shapes
;
return
{
out_shapes
,
true
}
;
}
}
OP_TRAIT_REG
(
BatchNorm
,
BatchNorm
,
opr
::
BatchNorm
)
OP_TRAIT_REG
(
BatchNorm
,
BatchNorm
,
opr
::
BatchNorm
)
...
...
imperative/src/impl/ops/broadcast.cpp
浏览文件 @
634de590
...
@@ -47,7 +47,7 @@ bool valid_broadcast(const TensorShape& src_shape,
...
@@ -47,7 +47,7 @@ bool valid_broadcast(const TensorShape& src_shape,
return
true
;
return
true
;
}
}
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs_fallible
(
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
def
.
cast_final_safe
<
Broadcast
>
();
def
.
cast_final_safe
<
Broadcast
>
();
...
@@ -59,7 +59,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
...
@@ -59,7 +59,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
TensorLayout
out_layout
=
src
.
layout
;
TensorLayout
out_layout
=
src
.
layout
;
if
(
tshp
.
layout
.
ndim
==
0
||
tshp
.
value
.
empty
())
{
if
(
tshp
.
layout
.
ndim
==
0
||
tshp
.
value
.
empty
())
{
out_layout
.
ndim
=
0
;
out_layout
.
ndim
=
0
;
return
{{
out_layout
,
src
.
comp_node
}
};
return
{{
{
out_layout
,
src
.
comp_node
}},
true
};
}
}
mgb_assert
(
mgb_assert
(
tshp
.
layout
.
ndim
==
1
,
tshp
.
layout
.
ndim
==
1
,
...
@@ -77,7 +77,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
...
@@ -77,7 +77,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
src
.
layout
.
TensorShape
::
to_string
().
c_str
(),
src
.
layout
.
TensorShape
::
to_string
().
c_str
(),
out_layout
.
TensorShape
::
to_string
().
c_str
());
out_layout
.
TensorShape
::
to_string
().
c_str
());
return
{{
out_layout
,
src
.
comp_node
}
};
return
{{
{
out_layout
,
src
.
comp_node
}},
true
};
}
}
OP_TRAIT_REG
(
Broadcast
,
Broadcast
,
opr
::
Broadcast
)
OP_TRAIT_REG
(
Broadcast
,
Broadcast
,
opr
::
Broadcast
)
...
...
imperative/src/impl/ops/cond_take.cpp
浏览文件 @
634de590
...
@@ -110,14 +110,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
...
@@ -110,14 +110,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return
out
;
return
out
;
}
}
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs_fallible
(
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
cn
=
inputs
[
0
].
comp_node
;
auto
cn
=
inputs
[
0
].
comp_node
;
return
{
return
{
{
{
TensorLayout
(
inputs
[
0
].
layout
.
dtype
),
cn
},
{
TensorLayout
(
inputs
[
0
].
layout
.
dtype
),
cn
},
{
TensorLayout
(
dtype
::
Int32
()),
cn
}
{
TensorLayout
(
dtype
::
Int32
()),
cn
}
};
}
,
true
}
;
}
}
OP_TRAIT_REG
(
CondTake
,
CondTake
,
opr
::
CondTake
)
OP_TRAIT_REG
(
CondTake
,
CondTake
,
opr
::
CondTake
)
...
...
imperative/src/impl/ops/elemwise.cpp
浏览文件 @
634de590
...
@@ -29,7 +29,7 @@ cg::OperatorNodeBase* apply_on_var_node(
...
@@ -29,7 +29,7 @@ cg::OperatorNodeBase* apply_on_var_node(
return
opr
::
Elemwise
::
make
(
inputs
,
elemwise_opr
.
mode
).
node
()
->
owner_opr
();
return
opr
::
Elemwise
::
make
(
inputs
,
elemwise_opr
.
mode
).
node
()
->
owner_opr
();
}
}
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs_fallible
(
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
&&
op_def
=
def
.
cast_final_safe
<
Elemwise
>
();
auto
&&
op_def
=
def
.
cast_final_safe
<
Elemwise
>
();
...
@@ -55,12 +55,12 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
...
@@ -55,12 +55,12 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
TensorLayout
out_layout
;
TensorLayout
out_layout
;
out_layout
.
ndim
=
0
;
out_layout
.
ndim
=
0
;
out_layout
.
dtype
=
out_dt
;
out_layout
.
dtype
=
out_dt
;
return
{{
out_layout
,
out_cn
}
};
return
{{
{
out_layout
,
out_cn
}},
true
};
}
}
}
}
auto
&&
out_shape
=
opr
::
Elemwise
::
get_output_var_shape
(
op_def
.
mode
,
inp_shapes
);
auto
&&
out_shape
=
opr
::
Elemwise
::
get_output_var_shape
(
op_def
.
mode
,
inp_shapes
);
return
{{
TensorLayout
(
out_shape
,
out_dt
,
inputs
[
0
].
layout
.
format
),
out_cn
}
};
return
{{
{
TensorLayout
(
out_shape
,
out_dt
,
inputs
[
0
].
layout
.
format
),
out_cn
}},
true
};
}
}
OP_TRAIT_REG
(
Elemwise
,
Elemwise
,
opr
::
Elemwise
)
OP_TRAIT_REG
(
Elemwise
,
Elemwise
,
opr
::
Elemwise
)
...
...
imperative/src/impl/ops/tensor_manip.cpp
浏览文件 @
634de590
...
@@ -40,21 +40,21 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
...
@@ -40,21 +40,21 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return
{
Tensor
::
make
(
std
::
move
(
hv
))};
return
{
Tensor
::
make
(
std
::
move
(
hv
))};
}
}
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs_fallible
(
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
def
.
cast_final_safe
<
GetVarShape
>
();
def
.
cast_final_safe
<
GetVarShape
>
();
mgb_assert
(
inputs
.
size
()
==
1
,
"GetVarShape take 1 input, got %lu"
,
inputs
.
size
());
mgb_assert
(
inputs
.
size
()
==
1
,
"GetVarShape take 1 input, got %lu"
,
inputs
.
size
());
auto
&&
desc
=
inputs
[
0
];
auto
&&
desc
=
inputs
[
0
];
if
(
!
desc
.
layout
.
ndim
)
{
if
(
!
desc
.
layout
.
ndim
)
{
return
{{
TensorLayout
(
dtype
::
Int32
()),
desc
.
comp_node
}
};
return
{{
{
TensorLayout
(
dtype
::
Int32
()),
desc
.
comp_node
}},
true
};
}
}
DeviceTensorND
value
(
CompNode
::
default_cpu
(),
{
desc
.
layout
.
ndim
},
dtype
::
Int32
());
DeviceTensorND
value
(
CompNode
::
default_cpu
(),
{
desc
.
layout
.
ndim
},
dtype
::
Int32
());
auto
*
ptr
=
value
.
ptr
<
dt_int32
>
();
auto
*
ptr
=
value
.
ptr
<
dt_int32
>
();
for
(
size_t
i
=
0
;
i
<
desc
.
layout
.
ndim
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
desc
.
layout
.
ndim
;
++
i
)
{
ptr
[
i
]
=
desc
.
layout
[
i
];
ptr
[
i
]
=
desc
.
layout
[
i
];
}
}
return
{{
value
.
layout
(),
desc
.
comp_node
,
std
::
move
(
value
)}
};
return
{{
{
value
.
layout
(),
desc
.
comp_node
,
std
::
move
(
value
)}},
true
};
}
}
std
::
shared_ptr
<
OpDef
>
make_from_op_node
(
cg
::
OperatorNodeBase
*
node_
)
{
std
::
shared_ptr
<
OpDef
>
make_from_op_node
(
cg
::
OperatorNodeBase
*
node_
)
{
...
...
imperative/src/impl/profiler.cpp
浏览文件 @
634de590
...
@@ -28,12 +28,13 @@ namespace {
...
@@ -28,12 +28,13 @@ namespace {
CompNode
::
UnorderedSet
collect_comp_nodes
(
CompNode
::
UnorderedSet
collect_comp_nodes
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
)
{
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
)
{
CompNode
::
UnorderedSet
comp_nodes
;
CompNode
::
UnorderedSet
comp_nodes
;
SmallVector
<
LogicalTensorDesc
>
descs
;
SmallVector
<
LogicalTensorDesc
>
inp_
descs
;
for
(
auto
&&
i
:
inputs
)
{
for
(
auto
&&
i
:
inputs
)
{
comp_nodes
.
insert
(
i
->
comp_node
());
comp_nodes
.
insert
(
i
->
comp_node
());
descs
.
push_back
({
i
->
layout
(),
i
->
comp_node
(),
{}});
inp_
descs
.
push_back
({
i
->
layout
(),
i
->
comp_node
(),
{}});
}
}
for
(
auto
&&
output_attr
:
def
.
infer_output_attrs_fallible
(
def
,
descs
))
{
SmallVector
<
LogicalTensorDesc
>
oup_descs
=
std
::
get
<
0
>
(
def
.
infer_output_attrs_fallible
(
def
,
inp_descs
));
for
(
auto
&&
output_attr
:
oup_descs
)
{
comp_nodes
.
insert
(
output_attr
.
comp_node
);
comp_nodes
.
insert
(
output_attr
.
comp_node
);
}
}
return
comp_nodes
;
return
comp_nodes
;
...
...
imperative/src/impl/proxy_graph.cpp
浏览文件 @
634de590
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "megbrain/graph/static_infer.h"
#include "megbrain/graph/static_infer.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/backward_graph.h"
...
@@ -590,10 +591,9 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
...
@@ -590,10 +591,9 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
vinputs
[
i
]
=
InputPlaceholder
::
make
(
*
m_graph
,
*
inputs
[
i
]).
node
();
vinputs
[
i
]
=
InputPlaceholder
::
make
(
*
m_graph
,
*
inputs
[
i
]).
node
();
}
}
auto
opr
=
OpDef
::
apply_on_var_node
(
opdef
,
vinputs
);
auto
opr
=
OpDef
::
apply_on_var_node
(
opdef
,
vinputs
);
mgb_assert
(
opr
->
dyn_typeinfo
()
!=
InputPlaceholder
::
typeinfo
());
mgb_assert
(
!
opr
->
same_type
<
InputPlaceholder
>
());
for
(
auto
&&
i
:
opr
->
input
())
{
for
(
auto
&&
i
:
opr
->
input
())
{
mgb_assert
(
i
->
owner_opr
()
->
dyn_typeinfo
()
==
mgb_assert
(
i
->
owner_opr
()
->
same_type
<
InputPlaceholder
>
());
InputPlaceholder
::
typeinfo
());
}
}
return
opr
;
return
opr
;
}
}
...
@@ -605,17 +605,18 @@ size_t ProxyGraph::get_opr_output_size(const OpDef& opdef,
...
@@ -605,17 +605,18 @@ size_t ProxyGraph::get_opr_output_size(const OpDef& opdef,
return
get_proxy_opr
(
opdef
,
inputs
)
->
usable_output
().
size
();
return
get_proxy_opr
(
opdef
,
inputs
)
->
usable_output
().
size
();
}
}
SmallVector
<
LogicalTensorDesc
>
ProxyGraph
::
infer_output_attrs_fallible
(
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
ProxyGraph
::
infer_output_attrs_fallible
(
const
OpDef
&
opdef
,
const
OpDef
&
opdef
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
opr
=
get_proxy_opr
(
opdef
,
inputs
);
auto
opr
=
get_proxy_opr
(
opdef
,
inputs
);
CUR_OPR_GUARD
(
opr
);
CUR_OPR_GUARD
(
opr
);
do_shape_infer
(
false
)
;
SmallVector
<
LogicalTensorDesc
>
outputs
;
SmallVector
<
LogicalTensorDesc
>
ret
;
bool
validated
=
do_shape_infer
(
false
)
;
for
(
auto
&&
i
:
opr
->
usable_output
())
{
for
(
auto
&&
i
:
opr
->
usable_output
())
{
ret
.
push_back
({{
i
->
shape
(),
i
->
dtype
()},
i
->
comp_node
()});
outputs
.
push_back
({{
i
->
shape
(),
i
->
dtype
()},
i
->
comp_node
()});
}
}
return
ret
;
bool
need_check
=
opr
->
same_type
<
opr
::
Reshape
>
();
return
{
outputs
,
validated
&&
!
need_check
};
}
}
struct
ProxyGraph
::
GradGraph
{
struct
ProxyGraph
::
GradGraph
{
...
@@ -811,16 +812,20 @@ VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTenso
...
@@ -811,16 +812,20 @@ VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTenso
/*********************** Common Impl ***********************/
/*********************** Common Impl ***********************/
void
ProxyGraph
::
do_shape_infer
(
bool
sync_value
)
{
bool
ProxyGraph
::
do_shape_infer
(
bool
sync_value
)
{
m_static_infer_manager
->
update
();
m_static_infer_manager
->
update
();
bool
validated
=
true
;
for
(
auto
*
var
:
m_cur_opr
->
output
())
{
for
(
auto
*
var
:
m_cur_opr
->
output
())
{
if
(
sync_value
)
{
if
(
sync_value
)
{
var
->
shape
(
m_static_infer_manager
->
infer_shape
(
var
));
var
->
shape
(
m_static_infer_manager
->
infer_shape
(
var
));
}
else
if
(
auto
*
shape
=
m_static_infer_manager
->
infer_shape_fallible
(
var
))
{
}
else
if
(
auto
*
shape
=
m_static_infer_manager
->
infer_shape_fallible
(
var
))
{
var
->
shape
(
*
shape
);
var
->
shape
(
*
shape
);
}
else
{
validated
=
false
;
}
}
}
}
return
validated
;
}
}
TensorPtr
ProxyGraph
::
as_tensor
(
cg
::
OperatorNodeBase
*
opr
,
bool
share
)
{
TensorPtr
ProxyGraph
::
as_tensor
(
cg
::
OperatorNodeBase
*
opr
,
bool
share
)
{
...
...
imperative/src/impl/proxy_graph.h
浏览文件 @
634de590
...
@@ -48,7 +48,7 @@ public:
...
@@ -48,7 +48,7 @@ public:
const
OpDef
&
opdef
,
const
OpDef
&
opdef
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
);
const
SmallVector
<
LogicalTensorDesc
>&
inputs
);
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs_fallible
(
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
opdef
,
const
OpDef
&
opdef
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
);
const
SmallVector
<
LogicalTensorDesc
>&
inputs
);
...
@@ -88,7 +88,7 @@ private:
...
@@ -88,7 +88,7 @@ private:
/********************** Common Helper **********************/
/********************** Common Helper **********************/
void
do_shape_infer
(
bool
sync_value
);
bool
do_shape_infer
(
bool
sync_value
);
TensorPtr
as_tensor
(
cg
::
OperatorNodeBase
*
opr
,
bool
share
=
true
);
TensorPtr
as_tensor
(
cg
::
OperatorNodeBase
*
opr
,
bool
share
=
true
);
...
...
imperative/src/impl/proxy_graph_detail.cpp
浏览文件 @
634de590
...
@@ -80,8 +80,7 @@ apply_on_physical_tensor(const OpDef& def,
...
@@ -80,8 +80,7 @@ apply_on_physical_tensor(const OpDef& def,
return
outputs
;
return
outputs
;
}
}
SmallVector
<
LogicalTensorDesc
>
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
&&
graph
=
ProxyGraph
::
get_default_graph
();
auto
&&
graph
=
ProxyGraph
::
get_default_graph
();
return
graph
->
infer_output_attrs_fallible
(
def
,
inputs
);
return
graph
->
infer_output_attrs_fallible
(
def
,
inputs
);
...
...
imperative/src/impl/proxy_graph_detail.h
浏览文件 @
634de590
...
@@ -21,8 +21,7 @@ SmallVector<TensorPtr>
...
@@ -21,8 +21,7 @@ SmallVector<TensorPtr>
apply_on_physical_tensor
(
const
OpDef
&
def
,
apply_on_physical_tensor
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
);
const
SmallVector
<
TensorPtr
>&
inputs
);
SmallVector
<
LogicalTensorDesc
>
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
);
const
SmallVector
<
LogicalTensorDesc
>&
inputs
);
BackwardGraphResult
BackwardGraphResult
...
...
imperative/src/include/megbrain/imperative/op_def.h
浏览文件 @
634de590
...
@@ -44,7 +44,7 @@ public:
...
@@ -44,7 +44,7 @@ public:
const
OpDef
&
def
,
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
);
const
VarNodeArray
&
inputs
);
static
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs_fallible
(
static
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
);
const
SmallVector
<
LogicalTensorDesc
>&
inputs
);
...
...
imperative/src/include/megbrain/imperative/ops/backward_graph.h
浏览文件 @
634de590
...
@@ -38,8 +38,8 @@ public:
...
@@ -38,8 +38,8 @@ public:
SmallVector
<
TensorPtr
>
SmallVector
<
TensorPtr
>
apply
(
const
SmallVector
<
TensorPtr
>&
inputs
)
const
;
apply
(
const
SmallVector
<
TensorPtr
>&
inputs
)
const
;
SmallVector
<
LogicalTensorDesc
>
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_attrs
(
infer_attrs
(
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
const
;
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
const
;
template
<
typename
T
,
typename
F
,
typename
C
>
template
<
typename
T
,
typename
F
,
typename
C
>
SmallVector
<
T
>
interpret
(
F
&&
f
,
C
&&
c
,
const
SmallVector
<
T
>&
inputs
)
const
{
SmallVector
<
T
>
interpret
(
F
&&
f
,
C
&&
c
,
const
SmallVector
<
T
>&
inputs
)
const
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录