Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
0a6f4a88
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
0a6f4a88
编写于
3月 15, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/dtr): fix dtr problem
GitOrigin-RevId: 2a703f9ee4ebf8667ac889a73f67c688dfebd9bc
上级
529b394f
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
67 addition
and
24 deletion
+67
-24
imperative/src/impl/blob_manager_impl.cpp
imperative/src/impl/blob_manager_impl.cpp
+18
-0
imperative/src/impl/blob_manager_impl.h
imperative/src/impl/blob_manager_impl.h
+4
-0
imperative/src/impl/interpreter/commands.h
imperative/src/impl/interpreter/commands.h
+1
-0
imperative/src/impl/interpreter/interpreter_impl.cpp
imperative/src/impl/interpreter/interpreter_impl.cpp
+33
-20
imperative/src/impl/interpreter/interpreter_impl.h
imperative/src/impl/interpreter/interpreter_impl.h
+1
-1
imperative/src/impl/interpreter/tensor_info.h
imperative/src/impl/interpreter/tensor_info.h
+4
-1
imperative/src/impl/physical_tensor.cpp
imperative/src/impl/physical_tensor.cpp
+1
-1
imperative/src/include/megbrain/imperative/blob_manager.h
imperative/src/include/megbrain/imperative/blob_manager.h
+4
-0
imperative/src/include/megbrain/imperative/physical_tensor.h
imperative/src/include/megbrain/imperative/physical_tensor.h
+1
-1
未找到文件。
imperative/src/impl/blob_manager_impl.cpp
浏览文件 @
0a6f4a88
...
@@ -41,6 +41,10 @@ void BlobManagerImpl::unregister_blob(Blob* blob) {
...
@@ -41,6 +41,10 @@ void BlobManagerImpl::unregister_blob(Blob* blob) {
}
}
void
BlobManagerImpl
::
alloc_with_defrag
(
Blob
*
blob
,
size_t
size
)
{
void
BlobManagerImpl
::
alloc_with_defrag
(
Blob
*
blob
,
size_t
size
)
{
if
(
custom_allocator
)
{
blob
->
m_storage
=
custom_allocator
(
blob
->
m_comp_node
,
size
);
return
;
}
// try alloc
// try alloc
MGB_TRY
{
alloc_direct
(
blob
,
size
);
}
MGB_TRY
{
alloc_direct
(
blob
,
size
);
}
// if fail, try defrag, alloc again
// if fail, try defrag, alloc again
...
@@ -61,6 +65,13 @@ void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) {
...
@@ -61,6 +65,13 @@ void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) {
DeviceTensorND
BlobManagerImpl
::
alloc_workspace_with_defrag
(
DeviceTensorND
BlobManagerImpl
::
alloc_workspace_with_defrag
(
CompNode
cn
,
TensorLayout
&
layout
)
{
CompNode
cn
,
TensorLayout
&
layout
)
{
DeviceTensorND
dev_tensor
;
DeviceTensorND
dev_tensor
;
if
(
custom_allocator
)
{
DeviceTensorStorage
storage
(
cn
);
size_t
sz
=
layout
.
dtype
.
size
(
layout
.
total_nr_elems
());
storage
.
reset
(
cn
,
sz
,
custom_allocator
(
cn
,
sz
));
dev_tensor
.
reset
(
storage
,
layout
);
return
dev_tensor
;
}
MGB_TRY
{
return
alloc_workspace
(
cn
,
layout
);
}
MGB_TRY
{
return
alloc_workspace
(
cn
,
layout
);
}
MGB_CATCH
(
MemAllocError
&
,
{
MGB_CATCH
(
MemAllocError
&
,
{
mgb_log_warn
(
"memory allocation failed for workspace; try defragmenting"
);
mgb_log_warn
(
"memory allocation failed for workspace; try defragmenting"
);
...
@@ -78,6 +89,10 @@ DeviceTensorND BlobManagerImpl::alloc_workspace(CompNode cn, TensorLayout layout
...
@@ -78,6 +89,10 @@ DeviceTensorND BlobManagerImpl::alloc_workspace(CompNode cn, TensorLayout layout
return
dev_tensor
;
return
dev_tensor
;
}
}
void
BlobManagerImpl
::
set_allocator
(
allocator_t
allocator
)
{
custom_allocator
=
allocator
;
}
void
BlobManagerImpl
::
defrag
(
const
CompNode
&
cn
)
{
void
BlobManagerImpl
::
defrag
(
const
CompNode
&
cn
)
{
BlobSetWithMux
*
blobs_set_ptr
;
BlobSetWithMux
*
blobs_set_ptr
;
{
{
...
@@ -159,6 +174,9 @@ struct BlobManagerStub : BlobManager {
...
@@ -159,6 +174,9 @@ struct BlobManagerStub : BlobManager {
void
defrag
(
const
CompNode
&
cn
)
{
void
defrag
(
const
CompNode
&
cn
)
{
mgb_assert
(
0
,
"prohibited after global variable destruction"
);
mgb_assert
(
0
,
"prohibited after global variable destruction"
);
};
};
virtual
void
set_allocator
(
allocator_t
allocator
)
{
mgb_assert
(
0
,
"prohibited after global variable destruction"
);
};
};
};
BlobManager
*
BlobManager
::
inst
()
{
BlobManager
*
BlobManager
::
inst
()
{
...
...
imperative/src/impl/blob_manager_impl.h
浏览文件 @
0a6f4a88
...
@@ -45,6 +45,8 @@ class BlobManagerImpl final : public BlobManager {
...
@@ -45,6 +45,8 @@ class BlobManagerImpl final : public BlobManager {
DeviceTensorND
alloc_workspace
(
CompNode
cn
,
TensorLayout
layout
);
DeviceTensorND
alloc_workspace
(
CompNode
cn
,
TensorLayout
layout
);
BlobManager
::
allocator_t
custom_allocator
;
public:
public:
static
BlobManager
*
inst
();
static
BlobManager
*
inst
();
...
@@ -56,6 +58,8 @@ public:
...
@@ -56,6 +58,8 @@ public:
void
register_blob
(
Blob
*
blob
)
override
;
void
register_blob
(
Blob
*
blob
)
override
;
void
unregister_blob
(
Blob
*
blob
)
override
;
void
unregister_blob
(
Blob
*
blob
)
override
;
void
set_allocator
(
allocator_t
allocator
)
override
;
};
};
}
// namespace imperative
}
// namespace imperative
...
...
imperative/src/impl/interpreter/commands.h
浏览文件 @
0a6f4a88
...
@@ -49,6 +49,7 @@ struct ApplyOp {
...
@@ -49,6 +49,7 @@ struct ApplyOp {
std
::
shared_ptr
<
OpDef
>
op
;
std
::
shared_ptr
<
OpDef
>
op
;
SmallVector
<
TensorInfo
*>
inputs
;
SmallVector
<
TensorInfo
*>
inputs
;
SmallVector
<
TensorInfo
*>
outputs
;
SmallVector
<
TensorInfo
*>
outputs
;
SmallVector
<
LogicalTensorDesc
>
outputs_descs
;
bool
validated
=
false
;
bool
validated
=
false
;
template
<
typename
TFunctor
>
template
<
typename
TFunctor
>
...
...
imperative/src/impl/interpreter/interpreter_impl.cpp
浏览文件 @
0a6f4a88
...
@@ -114,11 +114,13 @@ ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
...
@@ -114,11 +114,13 @@ ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
void
ChannelImpl
::
WorkQueue
::
on_async_queue_worker_thread_start
()
{
void
ChannelImpl
::
WorkQueue
::
on_async_queue_worker_thread_start
()
{
sys
::
set_thread_name
(
"worker"
);
sys
::
set_thread_name
(
"worker"
);
m_owner
->
m_worker_state
.
tid
=
std
::
this_thread
::
get_id
();
m_owner
->
m_worker_state
.
tid
=
std
::
this_thread
::
get_id
();
OpDef
::
set_allocator
(
[
&
](
CompNode
device
,
size_t
size
)
{
auto
custom_allocator
=
[
&
](
CompNode
device
,
size_t
size
)
{
auto
blob
=
Blob
::
make
(
device
,
size
);
auto
blob
=
Blob
::
make
(
device
,
size
);
m_owner
->
alloc_tensor_with_evict
(
blob
.
get
());
m_owner
->
alloc_tensor_with_evict
(
blob
.
get
());
return
blob
->
storage
();
return
blob
->
storage
();
});
};
OpDef
::
set_allocator
(
custom_allocator
);
BlobManager
::
inst
()
->
set_allocator
(
custom_allocator
);
}
}
// Do not use m_xxx_state directly
// Do not use m_xxx_state directly
...
@@ -353,7 +355,7 @@ void ChannelImpl::dispatch_kernel(
...
@@ -353,7 +355,7 @@ void ChannelImpl::dispatch_kernel(
for
(
int
i
=
0
;
i
<
output_descs
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
output_descs
.
size
();
++
i
)
{
auto
&&
desc
=
output_descs
[
i
];
auto
&&
desc
=
output_descs
[
i
];
auto
info
=
alloc
();
auto
info
=
alloc
();
init
(
info
,
std
::
move
(
desc
)
);
init
(
info
,
desc
);
// make sure desc's value is consistent with h_value
// make sure desc's value is consistent with h_value
if
(
!
info
->
desc
.
value
.
empty
())
{
if
(
!
info
->
desc
.
value
.
empty
())
{
info
->
h_value
=
HostTensorND
::
make_proxy
(
desc
.
value
)
info
->
h_value
=
HostTensorND
::
make_proxy
(
desc
.
value
)
...
@@ -362,9 +364,9 @@ void ChannelImpl::dispatch_kernel(
...
@@ -362,9 +364,9 @@ void ChannelImpl::dispatch_kernel(
output_infos
.
push_back
(
info
);
output_infos
.
push_back
(
info
);
outputs
->
push_back
(
reinterpret_cast
<
Handle
>
(
info
));
outputs
->
push_back
(
reinterpret_cast
<
Handle
>
(
info
));
}
}
ApplyOp
cmd
{
ApplyOp
cmd
{
Profiler
::
next_id
(),
std
::
move
(
op
),
Profiler
::
next_id
(),
std
::
move
(
op
),
std
::
move
(
in
put_infos
),
std
::
move
(
input_infos
),
std
::
move
(
out
put_infos
),
std
::
move
(
output_info
s
),
validated
};
std
::
move
(
output_desc
s
),
validated
};
if
(
Profiler
::
is_profiling
())
{
if
(
Profiler
::
is_profiling
())
{
auto
op_info_getter
=
[
op
=
cmd
.
op
]
{
auto
op_info_getter
=
[
op
=
cmd
.
op
]
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
op_info
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
op_info
;
...
@@ -594,7 +596,7 @@ TensorInfo* ChannelImpl::alloc() {
...
@@ -594,7 +596,7 @@ TensorInfo* ChannelImpl::alloc() {
return
info
;
return
info
;
}
}
void
ChannelImpl
::
init
(
TensorInfo
*
info
,
LogicalTensorDesc
&&
desc
)
{
void
ChannelImpl
::
init
(
TensorInfo
*
info
,
LogicalTensorDesc
desc
)
{
m_valid_handle
.
insert
(
reinterpret_cast
<
Handle
>
(
info
));
m_valid_handle
.
insert
(
reinterpret_cast
<
Handle
>
(
info
));
MGB_RECORD_EVENT
(
TensorDeclareEvent
,
info
->
id
,
info
->
name
);
MGB_RECORD_EVENT
(
TensorDeclareEvent
,
info
->
id
,
info
->
name
);
info
->
status
=
TensorInfo
::
Allocated
;
info
->
status
=
TensorInfo
::
Allocated
;
...
@@ -692,6 +694,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
...
@@ -692,6 +694,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
"shape infer error, %s vs %s"
,
dest
->
desc
.
layout
.
to_string
().
c_str
(),
"shape infer error, %s vs %s"
,
dest
->
desc
.
layout
.
to_string
().
c_str
(),
ptr
->
layout
().
to_string
().
c_str
());
ptr
->
layout
().
to_string
().
c_str
());
}
}
// in order to avoid performance impact,
// memory forwarding is disabled when DTR is enabled
if
(
state
.
options
.
enable_dtr_auto_drop
)
{
ptr
->
to_contiguous_inplace
();
}
dest
->
desc
.
layout
=
ptr
->
layout
();
dest
->
desc
.
layout
=
ptr
->
layout
();
dest
->
desc
.
comp_node
=
ptr
->
comp_node
();
dest
->
desc
.
comp_node
=
ptr
->
comp_node
();
dest
->
memory
=
ptr
->
blob
()
->
size
();
dest
->
memory
=
ptr
->
blob
()
->
size
();
...
@@ -719,8 +726,9 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
...
@@ -719,8 +726,9 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
if
(
dest
->
evict_type
==
EvictType
::
DROP
)
{
if
(
dest
->
evict_type
==
EvictType
::
DROP
)
{
auto
&&
path
=
dest
->
producer
;
auto
&&
path
=
dest
->
producer
;
m_apply_stack
.
push
(
m_apply_stack
.
push
(
{
ApplyOp
{
path
->
id
,
path
->
op
,
path
->
inputs
,
path
->
outputs
},
0
,
dest
,
{
ApplyOp
{
path
->
id
,
path
->
op
,
path
->
inputs
,
path
->
outputs
,
"dtr"
});
path
->
outputs_descs
},
0
,
dest
,
"dtr"
});
if
(
!
m_applying
)
if
(
!
m_applying
)
flush_apply_stack
();
flush_apply_stack
();
}
}
...
@@ -812,8 +820,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
...
@@ -812,8 +820,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
}
}
// Apply op
// Apply op
SmallVector
<
LogicalTensorDesc
>
output_descs
;
SmallVector
<
LogicalTensorDesc
>
output_descs
;
for
(
auto
i
:
cmd
.
outputs
)
{
for
(
auto
i
:
cmd
.
outputs
_descs
)
{
output_descs
.
push_back
(
i
->
desc
);
output_descs
.
push_back
(
i
);
}
}
// Here std::move is REQUIRED for removing duplicated references.
// Here std::move is REQUIRED for removing duplicated references.
auto
outputs
=
apply_on_physical_tensor
(
auto
outputs
=
apply_on_physical_tensor
(
...
@@ -1031,6 +1039,7 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
...
@@ -1031,6 +1039,7 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
}
}
void
ChannelImpl
::
alloc_tensor_with_evict
(
Blob
*
x
)
{
void
ChannelImpl
::
alloc_tensor_with_evict
(
Blob
*
x
)
{
bool
in_worker
=
(
get_worker_tid
()
==
std
::
this_thread
::
get_id
());
auto
reserve_size
=
[
&
](
size_t
size
)
{
auto
reserve_size
=
[
&
](
size_t
size
)
{
if
(
!
m_dtr
.
comp_node
.
valid
())
{
if
(
!
m_dtr
.
comp_node
.
valid
())
{
return
false
;
return
false
;
...
@@ -1043,10 +1052,13 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
...
@@ -1043,10 +1052,13 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
return
true
;
return
true
;
};
};
auto
pre_level
=
set_log_level
(
LogLevel
::
NO_LOG
);
auto
pre_level
=
set_log_level
(
LogLevel
::
NO_LOG
);
if
(
in_worker
)
{
reserve_size
(
x
->
size
());
reserve_size
(
x
->
size
());
}
MGB_TRY
{
BlobManager
::
inst
()
->
alloc_direct
(
x
,
x
->
size
());
}
MGB_TRY
{
BlobManager
::
inst
()
->
alloc_direct
(
x
,
x
->
size
());
}
MGB_CATCH
(
MemAllocError
&
,
{
MGB_CATCH
(
MemAllocError
&
,
{
bool
suc
=
false
;
bool
suc
=
false
;
if
(
in_worker
)
{
while
(
!
suc
)
{
while
(
!
suc
)
{
if
(
!
auto_evict
(
1
))
{
if
(
!
auto_evict
(
1
))
{
break
;
break
;
...
@@ -1055,6 +1067,7 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
...
@@ -1055,6 +1067,7 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
MGB_CATCH
(
MemAllocError
&
,
{
continue
;
});
MGB_CATCH
(
MemAllocError
&
,
{
continue
;
});
suc
=
true
;
suc
=
true
;
}
}
}
if
(
!
suc
)
{
if
(
!
suc
)
{
set_log_level
(
pre_level
);
set_log_level
(
pre_level
);
mgb_log_warn
(
mgb_log_warn
(
...
@@ -1143,10 +1156,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
...
@@ -1143,10 +1156,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
if
(
!
inplace
&&
!
cross_cn
&&
!
m_dtr
.
is_bad_op
(
get_name
(
*
cmd
.
op
)))
{
if
(
!
inplace
&&
!
cross_cn
&&
!
m_dtr
.
is_bad_op
(
get_name
(
*
cmd
.
op
)))
{
TensorInfo
::
ComputePath
::
make
(
TensorInfo
::
ComputePath
::
make
(
cmd
.
id
,
cmd
.
op
,
cmd
.
inputs
,
cmd
.
outputs
);
cmd
.
id
,
cmd
.
op
,
cmd
.
inputs
,
cmd
.
outputs
,
cmd
.
outputs_descs
);
size_t
detach_cnt
=
0
;
size_t
detach_cnt
=
0
;
if
(
!
strcmp
(
get_name
(
*
cmd
.
op
),
"BatchNorm"
)
&&
if
(
!
strcmp
(
get_name
(
*
cmd
.
op
),
"BatchNorm"
)
&&
cmd
.
outputs
.
size
()
==
5
)
{
cmd
.
outputs
.
size
()
==
6
)
{
cmd
.
outputs
[
0
]
->
detach_producer
();
// detach running_mean
cmd
.
outputs
[
0
]
->
detach_producer
();
// detach running_mean
cmd
.
outputs
[
1
]
->
detach_producer
();
// detach running_var
cmd
.
outputs
[
1
]
->
detach_producer
();
// detach running_var
for
(
auto
input
:
cmd
.
inputs
)
{
for
(
auto
input
:
cmd
.
inputs
)
{
...
...
imperative/src/impl/interpreter/interpreter_impl.h
浏览文件 @
0a6f4a88
...
@@ -77,7 +77,7 @@ private:
...
@@ -77,7 +77,7 @@ private:
struct
State
;
struct
State
;
TensorInfo
*
alloc
();
TensorInfo
*
alloc
();
void
init
(
TensorInfo
*
,
LogicalTensorDesc
&&
desc
);
void
init
(
TensorInfo
*
,
LogicalTensorDesc
desc
);
void
free
(
TensorInfo
*
);
void
free
(
TensorInfo
*
);
void
real_free
(
TensorInfo
*
);
void
real_free
(
TensorInfo
*
);
void
recursive_free
(
TensorInfo
*
);
void
recursive_free
(
TensorInfo
*
);
...
...
imperative/src/impl/interpreter/tensor_info.h
浏览文件 @
0a6f4a88
...
@@ -91,6 +91,7 @@ struct TensorInfo {
...
@@ -91,6 +91,7 @@ struct TensorInfo {
SmallVector
<
TensorInfo
*>
inputs
;
SmallVector
<
TensorInfo
*>
inputs
;
SmallVector
<
TensorInfo
*>
unique_inputs
;
SmallVector
<
TensorInfo
*>
unique_inputs
;
SmallVector
<
TensorInfo
*>
outputs
;
SmallVector
<
TensorInfo
*>
outputs
;
SmallVector
<
LogicalTensorDesc
>
outputs_descs
;
size_t
ref_cnt
()
{
size_t
ref_cnt
()
{
return
outputs
.
size
()
-
std
::
count
(
outputs
.
begin
(),
outputs
.
end
(),
nullptr
);
return
outputs
.
size
()
-
std
::
count
(
outputs
.
begin
(),
outputs
.
end
(),
nullptr
);
...
@@ -98,12 +99,14 @@ struct TensorInfo {
...
@@ -98,12 +99,14 @@ struct TensorInfo {
static
ComputePath
*
make
(
static
ComputePath
*
make
(
uint64_t
id
,
std
::
shared_ptr
<
OpDef
>
op
,
SmallVector
<
TensorInfo
*>
inputs
,
uint64_t
id
,
std
::
shared_ptr
<
OpDef
>
op
,
SmallVector
<
TensorInfo
*>
inputs
,
SmallVector
<
TensorInfo
*>
outputs
)
{
SmallVector
<
TensorInfo
*>
outputs
,
SmallVector
<
LogicalTensorDesc
>
outputs_descs
)
{
auto
*
path
=
new
TensorInfo
::
ComputePath
();
auto
*
path
=
new
TensorInfo
::
ComputePath
();
path
->
id
=
id
;
path
->
id
=
id
;
path
->
op
=
op
;
path
->
op
=
op
;
path
->
inputs
=
inputs
;
path
->
inputs
=
inputs
;
path
->
outputs
=
outputs
;
path
->
outputs
=
outputs
;
path
->
outputs_descs
=
outputs_descs
;
// dedup
// dedup
SmallVector
<
TensorInfo
*>
unique_inputs
=
inputs
;
SmallVector
<
TensorInfo
*>
unique_inputs
=
inputs
;
std
::
sort
(
unique_inputs
.
begin
(),
unique_inputs
.
end
());
std
::
sort
(
unique_inputs
.
begin
(),
unique_inputs
.
end
());
...
...
imperative/src/impl/physical_tensor.cpp
浏览文件 @
0a6f4a88
...
@@ -87,7 +87,7 @@ Blob::~Blob() {
...
@@ -87,7 +87,7 @@ Blob::~Blob() {
}
}
const
Blob
::
RawStorage
&
Blob
::
storage
()
{
const
Blob
::
RawStorage
&
Blob
::
storage
()
{
if
(
!
m_storage
)
{
if
(
!
m_storage
&&
m_size
)
{
BlobManager
::
inst
()
->
alloc_with_defrag
(
this
,
m_size
);
BlobManager
::
inst
()
->
alloc_with_defrag
(
this
,
m_size
);
}
}
return
m_storage
;
return
m_storage
;
...
...
imperative/src/include/megbrain/imperative/blob_manager.h
浏览文件 @
0a6f4a88
...
@@ -18,6 +18,8 @@ namespace imperative {
...
@@ -18,6 +18,8 @@ namespace imperative {
class
BlobManager
:
public
NonCopyableObj
{
class
BlobManager
:
public
NonCopyableObj
{
public:
public:
using
allocator_t
=
std
::
function
<
DeviceTensorStorage
::
RawStorage
(
CompNode
,
size_t
)
>
;
virtual
~
BlobManager
()
=
default
;
virtual
~
BlobManager
()
=
default
;
static
BlobManager
*
inst
();
static
BlobManager
*
inst
();
...
@@ -26,6 +28,8 @@ public:
...
@@ -26,6 +28,8 @@ public:
virtual
void
alloc_with_defrag
(
Blob
*
blob
,
size_t
size
)
=
0
;
virtual
void
alloc_with_defrag
(
Blob
*
blob
,
size_t
size
)
=
0
;
virtual
void
set_allocator
(
allocator_t
allocator
)
=
0
;
virtual
DeviceTensorND
alloc_workspace_with_defrag
(
virtual
DeviceTensorND
alloc_workspace_with_defrag
(
CompNode
cn
,
TensorLayout
&
layout
)
=
0
;
CompNode
cn
,
TensorLayout
&
layout
)
=
0
;
...
...
imperative/src/include/megbrain/imperative/physical_tensor.h
浏览文件 @
0a6f4a88
...
@@ -119,7 +119,7 @@ public:
...
@@ -119,7 +119,7 @@ public:
return
make_scalar
(
value
,
m_blob
->
comp_node
());
return
make_scalar
(
value
,
m_blob
->
comp_node
());
}
}
BlobPtr
blob
()
{
return
m_blob
;
}
BlobPtr
&
blob
()
{
return
m_blob
;
}
void
fetch_value
();
void
fetch_value
();
bool
value_fetched
();
bool
value_fetched
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录