MegEngine 天元 / MegEngine
Commit fea46ea9

Authored May 11, 2021 by Megvii Engine Team

perf(imperative): add opr cache for apply_on_physical_tensor

GitOrigin-RevId: fc5d5fb34d2379905e1130d9b0572ba596fb9fe4

Parent: ea4e6ab9
Showing 5 changed files with 298 additions and 34 deletions (+298 −34):

- imperative/python/src/tensor.cpp (+3 −2)
- imperative/src/impl/proxy_graph/mini_graph.h (+276 −26)
- imperative/src/impl/proxy_graph/proxy_graph.cpp (+8 −0)
- imperative/src/impl/proxy_graph_detail.cpp (+10 −6)
- src/opr/impl/tensor_manip.cpp (+1 −0)
imperative/python/src/tensor.cpp

```diff
@@ -535,8 +535,9 @@ CompNode _get_device(PyObject* const* args, size_t nargs) {
                          ->m_node->comp_node();
         if (cn1 != cn) {
             throw py::value_error(ssprintf(
-                    "ambiguous device: %s vs %s",
-                    cn.to_string().c_str(), cn1.to_string().c_str()));
+                    "ambiguous device: %s (from %s) vs %s (from %s)",
+                    cn.to_string().c_str(), cn.to_string_logical().c_str(),
+                    cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
         }
     }
 }
```
imperative/src/impl/proxy_graph/mini_graph.h

```diff
@@ -11,8 +11,9 @@
 #include "megbrain/graph/operator_node.h"
 #include "megbrain/imperative/op_def.h"
-#include "megbrain/imperative/physical_tensor.h"
+#include "megbrain/imperative/ops/autogen.h"
 
+#include "../blob_manager_impl.h"
 #include "./common.h"
 #include "./proxy_graph_base.h"
```
```diff
@@ -80,6 +81,20 @@ TensorAdaptor(T&) -> TensorAdaptor<T, void>;
 template <typename T>
 TensorAdaptor(T*) -> TensorAdaptor<T, void>;
 
+SmallVector<Tensor*> to_raw_ptr_array(
+        const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
+    SmallVector<Tensor*> ret;
+    for (auto&& i : inputs) {
+        mgb_assert(i);
+        ret.push_back(i.get());
+        if (ensure_storage) {
+            // apply lazy allocation
+            i->blob()->storage();
+        }
+    }
+    return ret;
+}
+
 // single opr graph, for static inference and execution
 // contains static inference descs
 class ProxyGraph::MiniGraph {
```
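The new `to_raw_ptr_array` helper does two jobs: it collects raw `Tensor*` pointers, and via `i->blob()->storage()` it forces any lazily allocated storage to materialize before execution begins. A minimal self-contained sketch of that pattern, with `Blob` reduced to a standard-library stand-in rather than MegEngine's type:

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Stand-in for a blob whose device storage is allocated on first use.
struct Blob {
    std::shared_ptr<std::vector<float>> data;
    std::vector<float>& storage() {
        if (!data) data = std::make_shared<std::vector<float>>(16);  // lazy alloc
        return *data;
    }
};

// Collect raw pointers and force allocation up front, so nothing inside
// the cached opr's execution path triggers an allocation later.
std::vector<Blob*> to_raw_ptr_array(
        const std::vector<std::shared_ptr<Blob>>& inputs,
        bool ensure_storage = true) {
    std::vector<Blob*> ret;
    for (auto&& i : inputs) {
        assert(i);
        ret.push_back(i.get());
        if (ensure_storage) i->storage();  // apply lazy allocation now
    }
    return ret;
}
```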
```diff
@@ -146,6 +161,9 @@ protected:
         virtual const DeviceTensorND* infer_value_fallible(VarNode*) {
             mgb_assert(0);
         }
     };
 
+    size_t buf_size;
+    SmallVector<size_t> hash_buf;
+
     OperatorNodeBase* m_opr = nullptr;
     SmallVector<std::unique_ptr<OperatorNodeBase>> opr_ref_keeper;
```
```diff
@@ -194,6 +212,7 @@ protected:
                     return nullptr;
                 }
             }
             return &storage.value();
         } else {
             auto& value = tensor.value();
             return value.shape_valid() ? &value : nullptr;
```
```diff
@@ -203,8 +222,10 @@ protected:
 public:
     template <typename I, typename G>
-    MiniGraph(G& graph, const OpDef& opdef, const I& inputs)
-            : input_value_storage(inputs.size()) {
+    MiniGraph(
+            G& graph, const OpDef& opdef, const I& inputs, const size_t* hash_buf_,
+            const size_t buf_size_)
+            : buf_size(buf_size_), input_value_storage(inputs.size()) {
         mgb_assert(!m_opr);
         auto _ = graph.scoped_attach(this);
         cg::VarNodeArray vinputs(inputs.size());
```
```diff
@@ -222,7 +243,8 @@ public:
         }
         m_opr->init_output_static_infer_desc();
 
-        // fix permuted input
+        // fix permuted input: the order of m_opr->input() and vinputs may be
+        // different, input_remap keeps the index map of m_opr->input() and vinputs
         input_remap.reserve(m_opr->input().size());
         for (auto* v : m_opr->input()) {
             auto [found, i] = find_index(vinputs, v);
```
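The expanded comment deserves a concrete illustration: after graph-level deduplication, `m_opr->input()` can be a permutation of the `vinputs` the caller supplied, and `input_remap[k]` records which caller input feeds opr input `k`. A toy sketch of the same remap construction, with a hypothetical `find_index` and `int` standing in for `VarNode*`:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical find_index, mirroring the helper the diff relies on:
// returns {found, index of x in v}.
template <typename T>
std::pair<bool, std::size_t> find_index(const std::vector<T>& v, const T& x) {
    for (std::size_t i = 0; i < v.size(); ++i)
        if (v[i] == x) return {true, i};
    return {false, 0};
}

int main() {
    std::vector<int> vinputs = {10, 20, 30};
    std::vector<int> opr_inputs = {30, 10, 20};  // permuted by the graph
    std::vector<std::size_t> input_remap;
    input_remap.reserve(opr_inputs.size());
    for (int v : opr_inputs) {
        auto [found, i] = find_index(vinputs, v);
        if (found) input_remap.push_back(i);  // yields {2, 0, 1}
    }
}
```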
```diff
@@ -248,6 +270,23 @@ public:
             mgb_assert(found);
             output_remap.push_back(i);
         }
+
+        hash_buf.resize(buf_size);
+        for (size_t i = 0; i < buf_size; ++i) {
+            hash_buf[i] = hash_buf_[i];
+        }
+    }
+
+    bool is_same_buf(const size_t hash_buf_[], const size_t buf_size_) {
+        if (buf_size != buf_size_) {
+            return false;
+        }
+        for (size_t i = 0; i < buf_size; i++) {
+            if (hash_buf[i] != hash_buf_[i]) {
+                return false;
+            }
+        }
+        return true;
+    }
 
     // methods for containing graph
```
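`is_same_buf` exists because the key stored in the cache multimap is only a 64-bit digest: two different (opdef, inputs) combinations can collide on it, so each `MiniGraph` keeps the full word buffer it was built from and an exact comparison decides whether a bucket entry is a true hit. A compile-checkable sketch of the idea:

```cpp
#include <cstddef>
#include <vector>

// Each cached entry remembers the exact key words it was built from;
// equal digests alone are never trusted (mirrors MiniGraph::is_same_buf).
struct CacheEntry {
    std::vector<std::size_t> key_words;
    bool is_same_buf(const std::size_t* buf, std::size_t n) const {
        if (key_words.size() != n) return false;
        for (std::size_t i = 0; i < n; ++i)
            if (key_words[i] != buf[i]) return false;
        return true;
    }
};
```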
```diff
@@ -264,6 +303,87 @@ public:
         return m_opr;
     }
 
+    void init_input_tensor(const SmallVector<Tensor*>& inputs) {
+        auto&& opr_inputs = m_opr->input();
+        mgb_assert(opr_inputs.size() == inputs.size());
+        size_t idx = 0;
+        for (auto&& input : opr_inputs) {
+            mgb_assert(input->owner_opr()->same_type<InputPlaceholder>());
+            input->m_dev_tensor.storage({});
+            auto&& dev_tensor = inputs[input_remap[idx]]->dev_tensor();
+            auto&& layout = dev_tensor.layout();
+            input->shape(dev_tensor.shape());
+            auto&& chk = input->m_mem_plan.reset_from_owner_var().chunk();
+            input->m_dev_tensor.reset(dev_tensor.storage(), layout);
+            input->m_mem_plan.layout(layout);
+            chk.mem_alloc_status.set_from_owner_var();
+            mgb_assert(input->comp_node() == dev_tensor.comp_node());
+            mgb_assert(input->shape().eq_shape(layout));
+            mgb_assert(input->dtype() == layout.dtype);
+            idx++;
+        }
+    }
+
+    void init_output_tensor(const SmallVector<Tensor*>& outputs) {
+        size_t idx = 0;
+        mgb_assert(m_opr->usable_output().size() == outputs.size());
+        for (auto&& var : m_opr->output()) {
+            auto&& chk = var->m_mem_plan.reset_from_owner_var().chunk();
+            if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
+                // alloc workspace
+                TensorLayout layout{var->shape(), var->dtype(), var->format()};
+                var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(
+                        var->comp_node(), layout);
+            } else {
+                mgb_assert(idx < outputs.size());
+                auto&& tensor = outputs[idx];
+                auto&& layout = tensor->layout();
+                mgb_assert(var->comp_node() == tensor->comp_node());
+                mgb_assert(var->shape().eq_shape(layout));
+                mgb_assert(var->dtype() == layout.dtype);
+                if (!tensor->layout().is_empty()) {
+                    var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
+                } else {
+                    var->m_dev_tensor.storage({var->comp_node()});
+                }
+                ++idx;
+            }
+            chk.mem_alloc_status.set_from_owner_var();
+        }
+        mgb_assert(idx == outputs.size());
+
+        // Memory forwarding was bypassed in megbrain with graph option
+        // imperative_proxy_graph on, so here we call mem_plan_fwd_in2out_readonly
+        // to initialize some oprs' (e.g. Subtensor) internal state
+        // TODO: implement memory forwarding
+        m_opr->mem_plan_fwd_in2out_readonly();
+        {
+            // some oprs (e.g. Reduce) rely on on_mem_status_changed to set
+            // input/output tensors correctly; since we bypass var_node_mem_mgr,
+            // on_mem_status_changed should be called here
+            auto&& cb = m_opr->get_opr_event_callback().on_mem_status_changed;
+            if (cb.valid()) {
+                cb.val()();
+            }
+        }
+    }
+
+    void execute(
+            const SmallVector<Tensor*>& inputs, const SmallVector<Tensor*>& outputs,
+            cg::GraphExecutable::ExecEnv& env) {
+        init_input_tensor(inputs);
+        init_output_tensor(outputs);
+        m_opr->execute(env);
+        for (auto&& i : m_opr->input()) {
+            i->m_dev_tensor.storage({});
+        }
+        for (auto&& i : m_opr->output()) {
+            i->m_dev_tensor.storage({});
+        }
+    }
+
     void register_shape_infer(
             VarNode* varnode, const cg::static_infer::ShapeInferDesc& desc) {
         auto [found, i] = find_index(m_opr->output(), varnode);
```
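`init_output_tensor` walks every opr output but distinguishes two cases: `VOLATILE_CONTENT` vars are workspaces and get freshly allocated scratch memory, while all other outputs are bound, in order, to caller-provided tensors. A reduced model of that control flow, with illustrative types rather than MegEngine's:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Var {
    bool is_workspace;
    std::vector<float> dev;
};

void init_outputs(std::vector<Var>& opr_outputs,
                  std::vector<std::vector<float>*>& caller_outputs) {
    std::size_t idx = 0;
    for (auto& var : opr_outputs) {
        if (var.is_workspace) {
            // stand-in for BlobManager's alloc_workspace_with_defrag
            var.dev.assign(64, 0.f);
        } else {
            assert(idx < caller_outputs.size());
            // stand-in for assign_dev_tensor_from_tensor
            var.dev = *caller_outputs[idx++];
        }
    }
    // every caller tensor must have been consumed exactly once
    assert(idx == caller_outputs.size());
}
```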
```diff
@@ -278,15 +398,22 @@ public:
         output_data[i].value_infer.initialize(m_opr, desc.deps, desc.infer_func);
     }
 
-    const TensorShape& infer_shape(VarNode* var) { return m_sess->infer_shape(var); }
+    const TensorShape& infer_shape(VarNode* var) {
+        mgb_assert(m_sess);
+        return m_sess->infer_shape(var);
+    }
 
-    const DeviceTensorND& infer_value(VarNode* var) { return m_sess->infer_value(var); }
+    const DeviceTensorND& infer_value(VarNode* var) {
+        mgb_assert(m_sess);
+        return m_sess->infer_value(var);
+    }
 
     OperatorNodeBase* opr() { return m_opr; }
 
+    // inference routine template for type of input
     template <typename I>
     class InferSession : protected InferSessionBase {
     public:
         MiniGraph& owner;
         SmallVector<OutputData>& output_data;
         InputAdaptor<I> inputs;
```
```diff
@@ -355,7 +482,7 @@ public:
             auto [found, i] = find_index(owner.m_opr->input(), var);
             mgb_assert(found);
             i = owner.input_remap[i];
-            auto* value = inputs.value(i, false);
+            auto* value = inputs.value(i, true);
             mgb_assert(value);
             return *value;
         }
```
```diff
@@ -379,12 +506,18 @@ public:
         const TensorShape* infer_shape(size_t i, bool sync) {
             i = owner.output_remap[i];
-            return infer(output_data[i].shape_infer, sync);
+            auto* p = infer(output_data[i].shape_infer, sync);
+            if (sync)
+                mgb_assert(p, "failed to infer shape");
+            return p;
         }
 
         const DeviceTensorND* infer_value(size_t i, bool sync) {
             i = owner.output_remap[i];
-            return infer(output_data[i].shape_infer, sync);
+            auto* p = infer(output_data[i].value_infer, sync);
+            if (sync)
+                mgb_assert(p, "failed to infer value");
+            return p;
         }
     };
```
```diff
@@ -499,10 +632,12 @@ class ProxyGraphTypeI : public ProxyGraphBase {
 public:
     void register_shape_infer(
             VarNode* var, const cg::static_infer::ShapeInferDesc& desc) override {
+        mgb_assert(target);
         target->register_shape_infer(var, desc);
     };
     void register_value_infer(
             VarNode* var, const cg::static_infer::ValueInferDesc& desc) override {
+        mgb_assert(target);
         target->register_value_infer(var, desc);
     };
     cg::static_infer::InferType get_infer_type(VarNode*) override {
```
```diff
@@ -511,17 +646,22 @@ class ProxyGraphTypeI : public ProxyGraphBase {
     }
 
+    // some poorly written inference func would call infer_{shape,value}
     const TensorShape& infer_shape(VarNode* var) override {
+        mgb_assert(target);
         return target->infer_shape(var);
     }
     const DeviceTensorND& infer_value(VarNode* var) override {
+        mgb_assert(target);
         return target->infer_value(var);
     }
 };
 
 ProxyGraph::MiniGraph* target = nullptr;
 StaticInferManager m_static_infer_manager;
-std::unordered_map<size_t, ProxyGraph::MiniGraph> m_mini_graph_cache;
+std::unordered_multimap<size_t, ProxyGraph::MiniGraph> m_mini_graph_cache;
+std::mutex m_mini_graph_cache_mtx;
 size_t opr_count = 0;
 ExecEnvBase m_env;
+CompNode::UnorderedSet m_used_comp_node;
 
 static thread_local std::unique_ptr<ProxyGraphTypeI> sm_instance;
```
```diff
@@ -531,8 +671,12 @@ class ProxyGraphTypeI : public ProxyGraphBase {
     size_t next_node_id() override { return opr_count; }
 
+    void add_used_comp_node(CompNode cn) { m_used_comp_node.insert(cn); }
+
     std::shared_ptr<void> on_comp_node_finalize() override {
-        sm_instance.reset();
+        assert(!target);
+        MGB_LOCK_GUARD(m_mini_graph_cache_mtx);
+        m_mini_graph_cache.clear();
         return {};
     }
```
```diff
@@ -575,38 +719,62 @@ class ProxyGraphTypeI : public ProxyGraphBase {
     }
 
 public:
+    ~ProxyGraphTypeI() {
+        if (is_finalized()) {
+            return;
+        }
+        for (auto&& i : m_used_comp_node) {
+            if (i.device_type() == CompNode::DeviceType::CUDA)
+                continue;
+            i.sync();
+        }
+    }
+
     OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) override {
         mgb_assert(target);
         return target->insert_opr(std::move(opr_uniqp));
     }
 
     static ProxyGraphTypeI& inst() {
-        if (!sm_instance) {
+        if (!sm_instance || sm_instance->is_finalized()) {
             sm_instance.reset(new ProxyGraphTypeI);
         }
         return *sm_instance;
     }
 
-    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
-            const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    template <typename T>
+    ProxyGraph::MiniGraph& get_cached_minigraph(const OpDef& def, const T& inputs) {
+        mgb_assert(!is_finalized());
         size_t buf_size = 2 * inputs.size() + 1;
         size_t buf[buf_size];
         size_t pos = 0;
         buf[pos++] = def.hash();
-        for (auto&& desc : inputs) {
-            buf[pos++] = mgb::hash(desc.layout.dtype.handle());
-            buf[pos++] = mgb::hash(desc.comp_node);
+        for (auto&& inp : inputs) {
+            auto tensor = TensorAdaptor(inp);
+            buf[pos++] = mgb::hash(tensor.dtype().handle());
+            buf[pos++] = mgb::hash(tensor.comp_node());
         }
         mgb_assert(pos == buf_size);
         auto key = XXHash{}.update(buf, buf_size * sizeof(size_t)).digest();
-        auto it = m_mini_graph_cache.find(key);
-        if (it == m_mini_graph_cache.end()) {
-            auto&& result = m_mini_graph_cache.emplace(
-                    std::piecewise_construct, std::make_tuple(key),
-                    std::forward_as_tuple(*this, def, inputs));
-            mgb_assert(result.second);
-            it = result.first;
-        }
-        auto& minigraph = it->second;
+        auto its = m_mini_graph_cache.equal_range(key);
+        auto it = its.first;
+        for (; it != its.second; ++it) {
+            if (it->second.is_same_buf(buf, buf_size)) {
+                return it->second;
+            }
+            mgb_log_warn("hash collision occurs in minigraph cache with key: %lu", key);
+        }
+        auto&& result = m_mini_graph_cache.emplace(
+                std::piecewise_construct, std::make_tuple(key),
+                std::forward_as_tuple(*this, def, inputs, static_cast<size_t*>(buf), buf_size));
+        mgb_assert(result->first);
+        return result->second;
+    }
+
+    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+            const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+        auto& minigraph = get_cached_minigraph(def, inputs);
         auto _ = scoped_attach(&minigraph);
         auto sess = minigraph.infer_session(inputs);
         std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
```
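This is the core of the commit: `get_cached_minigraph` builds a flat key of `2 * inputs.size() + 1` words (the op hash, then a dtype hash and a comp-node hash per input), digests it, and scans the multimap bucket with the exact-buffer check before constructing a new `MiniGraph` on a miss. A self-contained sketch of the same lookup discipline, with XXHash and `MiniGraph` replaced by stand-ins (the FNV-style fold below is illustrative, not MegEngine's hash):

```cpp
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

struct MiniGraphStub {
    std::vector<std::size_t> key;
    bool is_same_buf(const std::vector<std::size_t>& buf) const { return key == buf; }
};

std::unordered_multimap<std::uint64_t, MiniGraphStub> cache;

MiniGraphStub& get_cached(std::size_t op_hash,
                          const std::vector<std::pair<std::size_t, std::size_t>>& inputs) {
    // flat key: op hash, then one dtype hash and one device hash per input
    std::vector<std::size_t> buf;
    buf.push_back(op_hash);
    for (auto&& [dtype_h, device_h] : inputs) {
        buf.push_back(dtype_h);
        buf.push_back(device_h);
    }
    // stand-in digest over the raw key words
    std::uint64_t digest = 1469598103934665603ull;
    for (std::size_t w : buf) digest = (digest ^ w) * 1099511628211ull;
    // the digest only selects a bucket; exact comparison decides a true hit
    auto range = cache.equal_range(digest);
    for (auto it = range.first; it != range.second; ++it)
        if (it->second.is_same_buf(buf)) return it->second;
    // miss (or collision): build a new entry and remember it
    return cache.emplace(digest, MiniGraphStub{std::move(buf)})->second;
}
```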
```diff
@@ -627,6 +795,88 @@ public:
         }
         return ret;
     }
 
+    SmallVector<LogicalTensorDesc> infer_output_attrs(
+            const OpDef& def, const SmallVector<Tensor*>& inputs) {
+        SmallVector<LogicalTensorDesc> descs;
+        auto& minigraph = get_cached_minigraph(def, inputs);
+        auto _ = scoped_attach(&minigraph);
+        auto sess = minigraph.infer_session(inputs);
+        // some output vars in minigraph.opr()->output() may not appear in
+        // minigraph.opr()->usable_output(), but execution may use the attrs
+        // for those output vars, so we infer attrs for all outputs while only
+        // returning LogicalTensorDesc for minigraph.opr()->usable_output()
+        for (size_t i = 0; i < minigraph.opr()->output().size(); ++i) {
+            auto* shape = sess.infer(sess.output_data[i].shape_infer, true);
+            mgb_assert(shape);
+            minigraph.opr()->output()[i]->shape(*shape);
+        }
+        descs.reserve(minigraph.output_size());
+        for (size_t i = 0; i < minigraph.output_size(); ++i) {
+            auto* ovar = minigraph.output_var(i);
+            descs.emplace_back();
+            auto& desc = descs.back();
+            desc.layout.dtype = ovar->dtype();
+            desc.comp_node = ovar->comp_node();
+            mgb_assert(ovar->dtype().valid() && ovar->comp_node().valid());
+            mgb_assert(
+                    ovar->shape().ndim ||
+                    ovar->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
+            desc.layout.init_contiguous_stride(ovar->shape());
+        }
+        return descs;
+    }
+
+    SmallVector<LogicalTensorDesc> infer_output_attrs(
+            const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+        return infer_output_attrs(def, to_raw_ptr_array(inputs));
+    }
+
+    void exec(
+            const OpDef& def, const SmallVector<TensorPtr>& inputs,
+            const SmallVector<TensorPtr>& outputs) {
+        auto raw_inputs = to_raw_ptr_array(inputs),
+             raw_outputs = to_raw_ptr_array(outputs);
+        CompNode::UnorderedSet used_cns;
+        for (auto&& out : raw_outputs) {
+            auto cn = out->comp_node();
+            add_used_comp_node(cn);
+            if (used_cns.insert(cn).second) {
+                for (auto&& in : inputs) {
+                    if (in->comp_node() != cn) {
+                        auto&& e = in->get_or_create_event();
+                        e->device_wait_by(cn);
+                    }
+                }
+            }
+        }
+        auto& minigraph = get_cached_minigraph(def, raw_inputs);
+        auto _ = scoped_attach(&minigraph);
+        // some oprs (e.g. Subtensor) may invoke infer_value during execution,
+        // so we need to create an inference session here
+        auto sess = minigraph.infer_session(raw_inputs);
+        minigraph.execute(raw_inputs, raw_outputs, m_env);
+        for (auto&& cn : used_cns) {
+            for (auto&& in : inputs) {
+                if (in->comp_node() != cn) {
+                    in->add_release_callback(cn);
+                }
+            }
+        }
+    }
+
+    SmallVector<TensorPtr> apply_on_physical_tensor(
+            const OpDef& def, SmallVector<TensorPtr> inputs) {
+        auto&& raw_inputs = to_raw_ptr_array(inputs);
+        auto output_descs = infer_output_attrs(def, raw_inputs);
+        SmallVector<TensorPtr> outputs(output_descs.size(), {});
+        for (size_t i = 0; i < outputs.size(); i++) {
+            outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
+        }
+        exec(def, inputs, outputs);
+        return outputs;
+    }
 };
 
 }  // namespace mgb::imperative::proxy_graph
```
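Taken together, the cached `apply_on_physical_tensor` is a three-step pipeline: infer output descriptors through the cached minigraph, allocate one output tensor per descriptor, then execute the opr into those tensors. A minimal stand-in version of that flow, with all types illustrative:

```cpp
#include <cstddef>
#include <vector>

struct Desc { std::size_t nelems; };
using TensorBuf = std::vector<float>;

// stand-ins for the cached minigraph's inference and execution
std::vector<Desc> infer_output_attrs_stub() { return {{16}, {4}}; }
void exec_stub(std::vector<TensorBuf>&) {}

std::vector<TensorBuf> apply_on_physical_tensor_stub() {
    auto descs = infer_output_attrs_stub();    // 1) infer output descriptors
    std::vector<TensorBuf> outputs(descs.size());
    for (std::size_t i = 0; i < outputs.size(); ++i)
        outputs[i].resize(descs[i].nelems);    // 2) allocate (cf. Tensor::make)
    exec_stub(outputs);                        // 3) run the cached opr
    return outputs;
}
```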
imperative/src/impl/proxy_graph/proxy_graph.cpp

```diff
@@ -23,6 +23,7 @@ thread_local std::unique_ptr<ProxyGraphTypeI> ProxyGraphTypeI::sm_instance = {};
 }  // namespace mgb::imperative::proxy_graph
 
 namespace mgb::imperative::proxy_graph_detail {
 
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
     auto ret = proxy_graph::ProxyGraphTypeI::inst().infer_output_attrs_fallible(
```
```diff
@@ -42,4 +43,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return ret;
 }
 
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, SmallVector<TensorPtr> inputs) {
+    auto ret = proxy_graph::ProxyGraphTypeI::inst().apply_on_physical_tensor(
+            def, inputs);
+    return ret;
+}
+
 }  // namespace mgb::imperative::proxy_graph_detail
```
imperative/src/impl/proxy_graph_detail.cpp

```diff
@@ -17,6 +17,9 @@ namespace mgb {
 namespace imperative {
 namespace proxy_graph_detail {
 
+// those functions are reimplemented with opr cache
+// in ./proxy_graph/mini_graph.h
+#if 0
 namespace {
 SmallVector<Tensor*> to_raw_ptr_array(
         const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
```
```diff
@@ -83,12 +86,13 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return outputs;
 }
 
-// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const
-// OpDef& def,
-//         const SmallVector<LogicalTensorDesc>& inputs) {
-//     auto&& graph = ProxyGraph::get_default_graph();
-//     return graph->infer_output_attrs_fallible(def, inputs);
-// }
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& graph = ProxyGraph::get_default_graph();
+    return graph->infer_output_attrs_fallible(def, inputs);
+}
+#endif
 
 EncodedSubgraph make_backward_graph(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
```
src/opr/impl/tensor_manip.cpp

```diff
@@ -1009,6 +1009,7 @@ void Split::init_output_static_infer_desc() {
 bool Split::infer_shape(
         size_t out_idx, TensorShape& dest, const cg::static_infer::InpVal& inp) {
+    mgb_assert(inp.run_id > 0, "run id should be a positive number");
     if (inp.run_id != m_output_shape_version) {
         std::vector<size_t> partition;
         auto ishp = inp.val.at(0).shape();
```
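The new assertion documents an invariant the surrounding caching depends on: `Split` skips recomputation when `inp.run_id` equals `m_output_shape_version`, and a freshly constructed opr presumably uses version 0 to mean "never computed", so a non-positive run id could make a stale, empty cache look valid. A sketch of that version-stamp idiom, with names illustrative rather than MegEngine's:

```cpp
#include <cassert>
#include <cstddef>

// Version-stamp caching: version == 0 means "never computed", so valid
// run ids must start at 1 (which the new mgb_assert guarantees).
struct ShapeCache {
    std::size_t version = 0;
    int shape = 0;  // stand-in for the cached TensorShape

    int get(std::size_t run_id) {
        assert(run_id > 0 && "run id should be a positive number");
        if (run_id != version) {  // stale: recompute and stamp
            shape = recompute();
            version = run_id;
        }
        return shape;
    }
    int recompute() { return 42; }  // placeholder for real shape inference
};
```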