MegEngine 天元 / MegEngine
Commit a5af35c1
Authored Jan 05, 2022 by Megvii Engine Team
refactor(imperative): remove command buffer

GitOrigin-RevId: 83c8cb6d3bed9b44b0424965fc7c4938b0ae5841
Parent: bdb853ee
Changes: 10

Showing 10 changed files with 38 additions and 229 deletions (+38 -229)
imperative/python/megengine/dtr/dtr.py                            +0  -1
imperative/python/src/tensor.cpp                                  +0  -4
imperative/python/test/integration/test_converge_with_drop.py    +0  -3
imperative/python/test/unit/random/test_rng.py                    +4  -0
imperative/src/impl/interpreter/commands.h                        +0  -2
imperative/src/impl/interpreter/interpreter_impl.cpp              +31 -148
imperative/src/impl/interpreter/interpreter_impl.h                +0  -45
imperative/src/impl/interpreter/option_manager.h                  +0  -3
imperative/src/impl/proxy_graph.cpp                               +3  -20
imperative/src/impl/proxy_graph.h                                 +0  -3
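The pattern of the change, visible throughout interpreter_impl.cpp below, is that every call site that used to stage a command via m_buffer.enqueue(...) now builds a worker task directly with m_worker.add_task({Profiler::next_id(), ..., stack_manager.dump()}). The following is a minimal, self-contained sketch of that direct-dispatch pattern; Worker, Task and the command structs here are simplified stand-ins, not MegEngine's actual classes.

// Minimal sketch of the dispatch change: commands become worker tasks
// immediately instead of being staged in an intermediate command buffer.
// All types here are simplified stand-ins, not MegEngine's real classes.
#include <cstdint>
#include <cstdio>
#include <deque>
#include <string>
#include <variant>

struct Put { int tensor_id; };
struct ApplyOp { std::string op_name; };
struct Del { int tensor_id; };
using CommandData = std::variant<Put, ApplyOp, Del>;

struct Task {
    uint64_t id;           // analogous to Profiler::next_id()
    CommandData data;
    std::string backtrace;  // analogous to stack_manager.dump()
};

struct Worker {
    std::deque<Task> queue;
    void add_task(Task task) { queue.push_back(std::move(task)); }
};

int main() {
    Worker worker;
    uint64_t next_id = 0;
    // After the refactor, dispatch order equals execution-queue order:
    worker.add_task({next_id++, Put{0}, "scope:a"});
    worker.add_task({next_id++, ApplyOp{"add"}, "scope:a"});
    worker.add_task({next_id++, Del{0}, "scope:a"});
    std::printf("queued %zu tasks\n", worker.queue.size());
}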
imperative/python/megengine/dtr/dtr.py
View file @ a5af35c1

@@ -120,7 +120,6 @@ def enable():
     r"""Enable to record computing path of tensors and to perform DTR policy."""
     _set_option("enable_dtr_auto_drop", 1)
     _set_option("enable_drop", 1)
-    _set_option("buffer_length", 0)
     _set_option("record_computing_path", 1)
imperative/python/src/tensor.cpp
View file @ a5af35c1

@@ -702,10 +702,6 @@ void init_tensor(py::module m) {
     });
     m.def("get_option",
           [channel](std::string name) { return channel->get_option(name); });
-    m.def("set_buffer_length", [channel](int length) {
-        mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
-        channel->set_option("buffer_length", length);
-    });
     m.def("push_scope", [channel](std::string name) {
         Transformation::push_scope(name);
         channel->push_scope(name);
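For context on what the remaining bindings forward to: get_option/set_option form a name-to-value interface on the interpreter channel. Below is a small illustrative sketch of such an interface, with simplified stand-in types rather than the real ChannelImpl/OptionManager; the option names used are ones that appear elsewhere in this diff.

// Sketch of the name -> value option interface that the Python bindings above
// forward to (channel->get_option / channel->set_option). Stand-in types only.
#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>

class OptionStore {
public:
    size_t get_option(const std::string& name) const {
        auto it = m_options.find(name);
        if (it == m_options.end())
            throw std::invalid_argument("unknown option: " + name);
        return it->second;
    }
    void set_option(const std::string& name, size_t value) { m_options[name] = value; }

private:
    std::unordered_map<std::string, size_t> m_options{
            {"async_level", 2}, {"enable_drop", 0}};
};

int main() {
    OptionStore channel;
    channel.set_option("async_level", 0);  // e.g. force synchronous execution
    return channel.get_option("async_level") == 0 ? 0 : 1;
}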
imperative/python/test/integration/test_converge_with_drop.py
View file @ a5af35c1

@@ -76,8 +76,6 @@ class XORNet(Module):
 def test_training_converge_with_drop():
     set_option("enable_drop", 1)
-    old_buffer_length = get_option("buffer_length")
-    set_option("buffer_length", 0)
     net = XORNet()
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())

@@ -119,4 +117,3 @@ def test_training_converge_with_drop():
     )
     set_option("enable_drop", 0)
-    set_option("buffer_length", old_buffer_length)
imperative/python/test/unit/random/test_rng.py
View file @ a5af35c1

@@ -9,6 +9,7 @@
 import numpy as np
+import pytest

 import megengine as mge
 import megengine.functional as F
 from megengine import Tensor, jit, random
 from megengine.core._imperative_rt import CompNode

@@ -209,9 +210,12 @@ def test_permutation_op():
         assert str(output.device) == str(cn)
         assert output.dtype == dtype

+    # FIXME: remove this sync
+    mge.core.set_option("async_level", 0)
     test_permutation_op_dtype(np.float32)
     test_permutation_op_dtype(np.int32)
     test_permutation_op_dtype(np.int16)
+    mge.core.set_option("async_level", 2)


 @pytest.mark.skipif(
imperative/src/impl/interpreter/commands.h
View file @ a5af35c1

@@ -49,14 +49,12 @@ struct ApplyOp {
     std::shared_ptr<OpDef> op;
     SmallVector<TensorInfo*> inputs;
     SmallVector<TensorInfo*> outputs;
-    SmallVector<TensorInfo*> dels;

     template <typename TFunctor>
     void get_props(TFunctor&& functor) const {
         functor("op", op);
         functor("inputs", inputs);
         functor("outputs", outputs);
-        functor("dels", dels);
     }

     const char* get_name() const { return "ApplyOp"; }
imperative/src/impl/interpreter/interpreter_impl.cpp
View file @ a5af35c1

@@ -156,7 +156,9 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
         info->desc.value = value.proxy_to_default_cpu();
     }
     info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
-    m_buffer.enqueue(Put{info, value, no_cache});
+    m_worker.add_task(
+            {Profiler::next_id(), Put{info, value, no_cache},
+             get_channel_state().stack_manager.dump()});
     if (m_async_level == 0) {
         sync_impl();
         info->desc.comp_node.sync();

@@ -200,7 +202,8 @@ void ChannelImpl::del_impl(Handle handle) {
     mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
     auto* info = reinterpret_cast<TensorInfo*>(handle);
     m_valid_handle.erase(handle);
-    m_buffer.enqueue(Del{info});
+    m_worker.add_task(
+            {Profiler::next_id(), Del{info}, get_channel_state().stack_manager.dump()});
 }

 void ChannelImpl::drop(Handle handle) {

@@ -212,7 +215,9 @@ void ChannelImpl::drop(Handle handle) {
                 m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
-        m_buffer.enqueue(Drop{info});
+        m_worker.add_task(
+                {Profiler::next_id(), Drop{info},
+                 get_channel_state().stack_manager.dump()});
     }
 }

@@ -333,7 +338,9 @@ void ChannelImpl::dispatch_kernel(
     MGB_RECORD_EVENT(
             OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
             tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
-    m_buffer.enqueue(std::move(cmd));
+    m_worker.add_task(
+            {Profiler::next_id(), std::move(cmd),
+             get_channel_state().stack_manager.dump()});
     if (!validated && options.async_level == 1) {
         sync_impl();
     } else if (options.async_level == 0) {

@@ -466,7 +473,6 @@ void ChannelImpl::sync() {
 }

 void ChannelImpl::sync_impl() {
-    m_buffer.flush();
     m_worker.wait_all_task_finish();
     MGB_LOCK_GUARD(m_mutex);
     check_worker_exc_unsafe();

@@ -499,7 +505,9 @@ void ChannelImpl::set_option(std::string name, size_t value) {
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     state.options.set_option(name, value);
-    m_buffer.enqueue(SetOption{name, value});
+    m_worker.add_task(
+            {Profiler::next_id(), SetOption{name, value},
+             get_channel_state().stack_manager.dump()});
 }

 void ChannelImpl::clear_candidates() {

@@ -604,7 +612,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
     m_pool.free(ptr);
 }

-ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this) {}
+ChannelImpl::ChannelImpl() : m_worker(this) {}

 ChannelImpl::~ChannelImpl() {
     close();

@@ -645,7 +653,7 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
     if (dest->evict_type == EvictType::DROP) {
         auto&& path = dest->producer;
         m_apply_stack.push(
-                {ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest,
+                {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
                  "dtr"});
         if (!m_applying)
             flush_apply_stack();

@@ -748,19 +756,6 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
         MGB_RECORD_EVENT(TensorUsageEvent, input_id);
         MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
     }
-    // Fused by command buffer. @see: CommandBuffer::fuse_del
-    // Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by
-    // tensor_inputs after Del. Note for exprs like 'y = x op x', inplace is unsupported
-    // yet but Del would be also fused.
-    for (auto* del : cmd.dels) {
-        // refcnt --, owners: [tensor_inputs]
-        // if it's decreased to 1, would be detected at @see:
-        // proxy_graph_detail::apply_on_physical_tensor
-        uint64_t del_id = del->id;
-        MGB_RECORD_EVENT(TensorCommandEvent, del_id, TensorCommandKind::Del);
-        free(del);
-        MGB_RECORD_EVENT(TensorCommandFinishEvent, del_id, TensorCommandKind::Del);
-    }
     // Before wait
     // TODO: split operator wait and execute so that OpWait could be corrected recorded.
     // Before execute

@@ -931,7 +926,6 @@ bool ChannelImpl::check_available() {
 }

 TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
-    m_buffer.flush();
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee, "duplicate waitee");
     m_waitee = info;

@@ -943,8 +937,9 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     if (require_host && !host_available()) {
         // avoid dead lock
         lock.unlock();
-        m_buffer.enqueue(GetValue{info});
-        m_buffer.flush();
+        m_worker.add_task(
+                {Profiler::next_id(), GetValue{info},
+                 get_channel_state().stack_manager.dump()});
         lock.lock();
         wait_host = true;
     }

@@ -1266,141 +1261,25 @@ void ChannelImpl::check_worker_exc_unsafe() {
     }
 }

-void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
-    auto& state = m_owner->get_channel_state();
-    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
-        return;
-    }
-    m_commands.push_back(
-            {Profiler::next_id(), std::move(cmd), state.stack_manager.dump()});
-    auto flush_pos = flush_pos_for(m_commands.back());
-    flush(flush_pos);
-}
-
-void ChannelImpl::CommandBuffer::flush() {
-    flush(m_commands.end());
-}
-
-void ChannelImpl::CommandBuffer::flush(Handle pos) {
-    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
-        if (Profiler::is_profiling()) {
-            mgb_log_debug("%s Flushed", to_string(*iter).c_str());
-        }
-        m_owner->m_worker.add_task(std::move(*iter));
-    }
-    m_commands.erase(m_commands.begin(), pos);
-}
-
-auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
-    auto& state = m_owner->get_channel_state();
-    return std::visit(
-            [this, &state](const auto& cmd) {
-                using T = std::decay_t<decltype(cmd)>;
-                if constexpr (std::is_same_v<T, ApplyOp>) {
-                    auto* op_type = cmd.op->dyn_typeinfo();
-                    if (op_type == RemoteRecv::typeinfo() ||
-                        op_type == RemoteSend::typeinfo() ||
-                        op_type == CollectiveComm::typeinfo() ||
-                        op_type == opr::InputCallback::typeinfo() ||
-                        op_type == opr::OutputCallback::typeinfo()) {
-                        return m_commands.end();
-                    }
-                } else if constexpr (std::is_same_v<T, GetValue>) {
-                    return m_commands.end();
-                }
-                size_t buffer_length = state.options.buffer_length;
-                if (m_commands.size() > buffer_length) {
-                    return m_commands.begin() + (m_commands.size() - buffer_length);
-                }
-                return m_commands.begin();
-            },
-            cmd.data);
-}
-
-/**
- * 1. Find ApplyOp(dest) in buffered commands
- * 2. Check if there are other usages between ApplyOp and Del, return false if not
- * 3. Fuse Del into ApplyOp, return true
- */
-bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
-    auto* dest = cmd.dest;
-    // TODO: eliminate Puts
-    auto begin = m_commands.begin(), end = m_commands.end();
-    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd) {
-        if (auto* apply = std::get_if<ApplyOp>(&cmd.data)) {
-            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
-        }
-        return false;
-    });
-    if (apply_iter == end || find_last_usage(dest, {apply_iter + 1, end}) != end) {
-        return false;
-    }
-    std::get<ApplyOp>(apply_iter->data).dels.push_back(dest);
-    return true;
-}
-
-auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
-        -> Handle {
-    auto found = range[1];
-    for (auto iter = range[0]; iter != range[1]; ++iter) {
-        std::visit(
-                [&](const auto& cmd) {
-                    using T = std::decay_t<decltype(cmd)>;
-                    if constexpr (std::is_same_v<T, ApplyOp>) {
-                        if (std::count(cmd.inputs.begin(), cmd.inputs.end(), dest) > 0) {
-                            found = iter;
-                        }
-                    } else if constexpr (std::is_same_v<T, GetValue>) {
-                        if (cmd.dest == dest) {
-                            found = iter;
-                        }
-                    } else if constexpr (std::is_same_v<T, Drop>) {
-                        // TODO: ignore swap-like commands, just remove them from buffer
-                        if (cmd.dest == dest) {
-                            found = iter;
-                        }
-                    }
-                },
-                iter->data);
-    };
-    return found;
-}
-
-auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
-        -> Handle {
-    return std::find_if(range[0], range[1], [dest](auto& cmd) {
-        return std::visit(
-                [dest](const auto& cmd) {
-                    using T = std::decay_t<decltype(cmd)>;
-                    if constexpr (std::is_same_v<T, ApplyOp>) {
-                        return std::count(
-                                       cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
-                    } else if constexpr (std::is_same_v<T, Put>) {
-                        return cmd.dest == dest;
-                    }
-                    return false;
-                },
-                cmd.data);
-    });
-}
-
 void ChannelImpl::start_profile() {
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto capture_tensors = collect_valid_tensors();
     if (capture_tensors.size() > 0) {
-        m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
+        m_worker.add_task(
+                {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
+                 get_channel_state().stack_manager.dump()});
     }
 }

 void ChannelImpl::stop_profile() {
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
-    m_buffer.flush();
     auto escape_tensors = collect_valid_tensors();
     if (escape_tensors.size() > 0) {
-        m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
+        m_worker.add_task(
+                {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
+                 get_channel_state().stack_manager.dump()});
     }
 }

@@ -1410,7 +1289,9 @@ void ChannelImpl::push_scope(std::string name) {
     auto& state = get_channel_state();
     state.stack_manager.enter(name);
     MGB_RECORD_EVENT(ScopeEvent, name);
-    m_buffer.enqueue(PushScope{name});
+    m_worker.add_task(
+            {Profiler::next_id(), PushScope{name},
+             get_channel_state().stack_manager.dump()});
 }

 void ChannelImpl::pop_scope(std::string name) {

@@ -1419,7 +1300,9 @@ void ChannelImpl::pop_scope(std::string name) {
     auto& state = get_channel_state();
     state.stack_manager.exit(name);
     MGB_RECORD_EVENT(ScopeFinishEvent, name);
-    m_buffer.enqueue(PopScope{name});
+    m_worker.add_task(
+            {Profiler::next_id(), PopScope{name},
+             get_channel_state().stack_manager.dump()});
 }

 void ChannelImpl::assert_in_channel() {
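The removed CommandBuffer::flush_pos_for decided how much of the staged window to hand to the worker: GetValue and communication ops (RemoteSend/RemoteRecv/CollectiveComm and the I/O callbacks) forced a full flush, otherwise at most options.buffer_length commands stayed staged. A simplified sketch of that policy, with stand-in command types rather than the real ones:

// Simplified sketch of the flush policy removed in this commit: keep at most
// `buffer_length` commands staged, and flush everything when the newest
// command must not be delayed (here GetValue; the real code also flushed for
// RemoteSend, RemoteRecv, CollectiveComm and the I/O callbacks).
#include <cstddef>
#include <deque>
#include <string>
#include <variant>

struct ApplyOp { std::string op_name; };
struct GetValue { int tensor_id; };
using Command = std::variant<ApplyOp, GetValue>;

using Handle = std::deque<Command>::iterator;

Handle flush_pos_for(std::deque<Command>& commands, size_t buffer_length) {
    // Commands whose results are awaited flush the whole window.
    if (std::holds_alternative<GetValue>(commands.back())) {
        return commands.end();
    }
    // Otherwise keep the newest `buffer_length` commands staged.
    if (commands.size() > buffer_length) {
        return commands.begin() + (commands.size() - buffer_length);
    }
    return commands.begin();
}

int main() {
    std::deque<Command> commands{ApplyOp{"add"}, ApplyOp{"mul"}, ApplyOp{"relu"}};
    size_t buffer_length = 2;
    Handle pos = flush_pos_for(commands, buffer_length);
    // Everything in [begin, pos) would have been handed to the worker.
    return static_cast<int>(pos - commands.begin());  // 1 in this example
}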
imperative/src/impl/interpreter/interpreter_impl.h
View file @ a5af35c1

@@ -126,11 +126,6 @@ private:
     void assert_in_worker();
     std::thread::id get_worker_tid();

-    // template <typename TCommand>
-    // void enqueue_command(TCommand&& cmd) {
-    //     m_buffer.enqueue(Command{std::forward<TCommand>(cmd)});
-    // }
-
     void sample_on_device(CompNode device, bool force);

     // valid => status != Deleted

@@ -178,46 +173,6 @@ private:
         ChannelImpl* m_owner;
     } m_worker;

-    /**
-     * Buf a command window for following fuse
-     * example:
-     * ---------------------------------------------------------------------
-     * | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1}   |
-     * ---------------------------------------------------------------------
-     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1}  |
-     * ---------------------------------------------------------------------
-     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ...        |
-     * ---------------------------------------------------------------------
-     * Then the fused Apply may be invoked inplace. see:
-     * ChannelImpl::process_one_task
-     */
-    struct CommandBuffer {
-        CommandBuffer(ChannelImpl* owner) : m_owner(owner) {}
-        void enqueue(CommandData cmd);
-        bool empty() const { return m_commands.empty(); }
-        void flush();
-
-    private:
-        ChannelImpl* m_owner;
-        std::deque<Command> m_commands;
-
-        using Handle = decltype(m_commands)::iterator;
-        // [begin, end)
-        using Range = std::array<Handle, 2>;
-
-        // Launch commands in range [m_commands.begin(), pos)
-        void flush(Handle pos);
-        // Select flush position for incoming cmd
-        Handle flush_pos_for(const Command& cmd);
-        // Fuse del command into suitable ApplyOp
-        bool fuse_del(const Del& cmd);
-        // Returns the last handle that dest is used within range. If dest is not used,
-        // returns range[1]
-        Handle find_last_usage(TensorInfo* dest, Range range);
-        // Returns the produce position of dest. If not found, returns range[1]
-        Handle find_produce(TensorInfo* dest, Range range);
-    } m_buffer;
-
     //! config whether raise error exactly when invoking op.
     //! level 2: both device and user side errors are async;
     //! level 1: user side errors are sync;
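The removed doc comment above describes the fusion the buffer performed: a trailing Del{t} could be folded into the dels list of a staged ApplyOp that reads t, so the op might later execute in place. A simplified sketch of that rule follows; the types are stand-ins, and the real fuse_del additionally checked that nothing between the ApplyOp and the Del still used the tensor.

// Simplified sketch of the Del-fusion rule described by the removed
// CommandBuffer comment: a Del{t} arriving behind an ApplyOp that reads t is
// recorded in that ApplyOp's `dels` list instead of being queued separately.
#include <algorithm>
#include <cstdio>
#include <deque>
#include <variant>
#include <vector>

struct ApplyOp {
    std::vector<int> inputs, outputs, dels;
};
struct Del {
    int dest;
};
using Command = std::variant<ApplyOp, Del>;

bool fuse_del(std::deque<Command>& commands, const Del& cmd) {
    // Find the last staged ApplyOp that consumes cmd.dest.
    for (auto it = commands.rbegin(); it != commands.rend(); ++it) {
        if (auto* apply = std::get_if<ApplyOp>(&*it)) {
            auto& in = apply->inputs;
            if (std::count(in.begin(), in.end(), cmd.dest) > 0) {
                apply->dels.push_back(cmd.dest);  // fused: no separate Del command
                return true;
            }
        }
    }
    return false;  // no consumer in the window, enqueue the Del normally
}

int main() {
    std::deque<Command> window{ApplyOp{{0, 1}, {2}, {}}};
    bool fused = fuse_del(window, Del{0});
    std::printf("fused: %d, dels: %zu\n", fused,
                std::get<ApplyOp>(window.front()).dels.size());
}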
imperative/src/impl/interpreter/option_manager.h
View file @ a5af35c1

@@ -40,9 +40,6 @@ public:
     DEF_OPTION(
             catch_worker_execption, "MEGENGINE_CATCH_WORKER_EXEC", 1,
             "catch worker exception if enabled, close it when debugging");
-    DEF_OPTION(
-            buffer_length, "MEGENGINE_COMMAND_BUFFER_LENGTH", 3,
-            "set command buffer length.");
     DEF_OPTION(
             enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1,
             "enable host compute, thus computation may be done in host event if it's "
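The removed entry defined buffer_length with default 3, overridable through the MEGENGINE_COMMAND_BUFFER_LENGTH environment variable. As an illustration only (this is not MegEngine's actual DEF_OPTION macro), an environment-overridable default typically looks like this:

// Illustrative sketch of an option whose default can be overridden by an
// environment variable, the pattern the removed "buffer_length" entry followed.
#include <cstdio>
#include <cstdlib>
#include <string>

static size_t env_option(const char* env_name, size_t default_value) {
    if (const char* value = std::getenv(env_name)) {
        return static_cast<size_t>(std::stoul(value));
    }
    return default_value;
}

int main() {
    // Before this commit: buffer_length defaulted to 3 unless the env var was set.
    size_t buffer_length = env_option("MEGENGINE_COMMAND_BUFFER_LENGTH", 3);
    std::printf("buffer_length = %zu\n", buffer_length);
}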
imperative/src/impl/proxy_graph.cpp
View file @ a5af35c1

@@ -626,23 +626,12 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
 /*********************** Logical Tensor Impl ***********************/

-size_t ProxyGraph::get_opr_output_size(
-        const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    return get_proxy_opr(opdef, inputs)->usable_output().size();
-}
-
 std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_fallible(
         const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    auto opr = get_proxy_opr(opdef, inputs);
-    CUR_OPR_GUARD(opr);
-    SmallVector<LogicalTensorDesc> outputs;
-    bool validated = do_shape_infer(false);
-    for (auto&& i : opr->usable_output()) {
-        outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
-    }
-    bool need_check = opr->same_type<opr::Reshape>();
-    return {outputs, validated && !need_check};
+    // this function is just a placeholder
+    // it will be overrided by ProxyGraphTypeI::infer_output_attrs_fallible in minigraph
+    mgb_assert(0);
 }

 std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::

@@ -823,12 +812,6 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
     return result;
 }

-cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
-        const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    mgb_assert(!m_cur_opr);
-    auto vinputs = make_input_place_holders(inputs);
-    return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
-}
-
 VarNodeArray ProxyGraph::make_input_place_holders(
         const SmallVector<LogicalTensorDesc>& inputs) {
imperative/src/impl/proxy_graph.h
View file @ a5af35c1

@@ -85,9 +85,6 @@ private:
     /********************** Logical Tensor Helper **********************/
-    cg::OperatorNodeBase* get_proxy_opr(
-            const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs);
-
     cg::VarNodeArray make_input_place_holders(
             const SmallVector<LogicalTensorDesc>& inputs);