Commit a3ab9f1d authored by Megvii Engine Team

feat(blob_manager): make all memory allocation go through blob manager

GitOrigin-RevId: f79155e5c337ad1a883bd9d9bef8c9f97710fb4a
Parent 3ea00eba
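The core of the change: workspace allocations that previously constructed a DeviceTensorStorage by hand now go through BlobManager::inst()->alloc_workspace_with_defrag(), which falls back to defragmentation and retries once when the initial allocation throws MemAllocError. A minimal before/after sketch of the call-site pattern, mirroring the ProxyGraph hunk below (var is the VarNode being set up):

    // Before: each call site sized and bound its own DeviceTensorStorage.
    TensorLayout layout{var->shape(), var->dtype(), var->format()};
    DeviceTensorStorage storage;
    storage.comp_node(var->comp_node())
            .ensure_size(layout.dtype.size(layout.total_nr_elems()));
    var->m_dev_tensor.reset(storage, layout);

    // After: the blob manager owns the allocation and can defragment and retry on OOM.
    var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(
            var->comp_node(), layout);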
@@ -67,6 +67,29 @@ void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) {
blob->m_storage = storage.raw_storage();
}
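// Allocate a workspace tensor; when the manager is enabled, an allocation failure triggers defragmentation and one retry.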
DeviceTensorND BlobManagerImpl::alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) {
DeviceTensorND dev_tensor;
if (!m_enable) {
dev_tensor = alloc_workspace(cn, layout);
} else {
MGB_TRY{ dev_tensor = alloc_workspace(cn, layout); }
MGB_CATCH(MemAllocError&, {
mgb_log_warn("memory allocation failed for workspace; try defragmenting");
defrag(cn);
dev_tensor = alloc_workspace(cn, layout);
});
}
return dev_tensor;
};
DeviceTensorND BlobManagerImpl::alloc_workspace(CompNode cn, TensorLayout layout) {
DeviceTensorStorage storage(cn);
storage.ensure_size(layout.dtype.size(layout.total_nr_elems()));
DeviceTensorND dev_tensor;
dev_tensor.reset(storage, layout);
return dev_tensor;
}
void BlobManagerImpl::defrag(const CompNode& cn) {
BlobSetWithMux* blobs_set_ptr;
{
@@ -136,6 +159,9 @@ struct BlobManagerStub : BlobManager {
void alloc_with_defrag(Blob* blob, size_t size) {
mgb_assert(0, "prohibited after global variable destruction");
};
DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) {
mgb_assert(0, "prohibited after global variable destruction");
};
void register_blob(Blob* blob) {
mgb_assert(0, "prohibited after global variable destruction");
};
@@ -45,11 +45,15 @@ class BlobManagerImpl final: public BlobManager {
void alloc_direct(Blob* blob, size_t size);
DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout);
public:
static BlobManager* inst();
void alloc_with_defrag(Blob* blob, size_t size) override;
DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) override;
void register_blob(Blob* blob) override;
void unregister_blob(Blob* blob) override;
@@ -16,6 +16,7 @@
#include "../op_trait.h"
#include "../dnn_op_helper.h"
#include "../blob_manager_impl.h"
namespace mgb {
namespace imperative {
@@ -102,11 +103,16 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
TensorShapeArray inp_shapes(inputs.size());
for (unsigned i = 0; i < inputs.size(); ++i){
inp_tensornds[i] = inputs[i]->dev_tensor();
inp_shapes[i] = inputs[i]->layout();
}
SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}};
TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(inp_tensornds[0].comp_node(), {shape, inp_tensornds[0].layout().dtype});
SmallVector<DeviceTensorND> oup_tensornds = {out};
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
return {Tensor::make(oup_tensornds[0])};
}
@@ -555,10 +555,7 @@ void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs) {
if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
// alloc workspace
TensorLayout layout{var->shape(), var->dtype(), var->format()};
DeviceTensorStorage storage;
storage.comp_node(var->comp_node())
.ensure_size(layout.dtype.size(layout.total_nr_elems()));
var->m_dev_tensor.reset(storage, layout);
var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(var->comp_node(), layout);
} else {
mgb_assert(j < outputs.size());
auto &&tensor = outputs[j];
@@ -24,6 +24,8 @@ public:
virtual void alloc_with_defrag(Blob* blob, size_t size) = 0;
virtual DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) = 0;
virtual void register_blob(Blob* blob) = 0;
virtual void unregister_blob(Blob* blob) = 0;