diff --git a/imperative/src/impl/blob_manager_impl.cpp b/imperative/src/impl/blob_manager_impl.cpp
index 8682c81fe0f847710a39696cbdcbd01a84ad69ae..697e716a1c442f0a89f54df1203bc6cb008b2acc 100644
--- a/imperative/src/impl/blob_manager_impl.cpp
+++ b/imperative/src/impl/blob_manager_impl.cpp
@@ -67,6 +67,29 @@ void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) {
     blob->m_storage = storage.raw_storage();
 }
 
+DeviceTensorND BlobManagerImpl::alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) {
+    DeviceTensorND dev_tensor;
+    if (!m_enable) {
+        dev_tensor = alloc_workspace(cn, layout);
+    } else {
+        MGB_TRY{ dev_tensor = alloc_workspace(cn, layout); }
+        MGB_CATCH(MemAllocError&, {
+            mgb_log_warn("memory allocation failed for workspace; try defragmenting");
+            defrag(cn);
+            dev_tensor = alloc_workspace(cn, layout);
+        });
+    }
+    return dev_tensor;
+};
+
+DeviceTensorND BlobManagerImpl::alloc_workspace(CompNode cn, TensorLayout layout) {
+    DeviceTensorStorage storage(cn);
+    storage.ensure_size(layout.dtype.size(layout.total_nr_elems()));
+    DeviceTensorND dev_tensor;
+    dev_tensor.reset(storage, layout);
+    return dev_tensor;
+}
+
 void BlobManagerImpl::defrag(const CompNode& cn) {
     BlobSetWithMux* blobs_set_ptr;
     {
@@ -136,6 +159,9 @@ struct BlobManagerStub : BlobManager {
     void alloc_with_defrag(Blob* blob, size_t size) {
         mgb_assert(0, "prohibited after global variable destruction");
     };
+    DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) {
+        mgb_assert(0, "prohibited after global variable destruction");
+    };
     void register_blob(Blob* blob) {
         mgb_assert(0, "prohibited after global variable destruction");
     };
diff --git a/imperative/src/impl/blob_manager_impl.h b/imperative/src/impl/blob_manager_impl.h
index e451ae4e3d60f2d3e7b2a09584b2efe50c6a97f0..09684ee08a1c2d3d68e0a8f895fc07fd2c68af87 100644
--- a/imperative/src/impl/blob_manager_impl.h
+++ b/imperative/src/impl/blob_manager_impl.h
@@ -45,11 +45,15 @@ class BlobManagerImpl final: public BlobManager {
 
     void alloc_direct(Blob* blob, size_t size);
 
+    DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout);
+
 public:
     static BlobManager* inst();
 
     void alloc_with_defrag(Blob* blob, size_t size) override;
 
+    DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) override;
+
     void register_blob(Blob* blob) override;
 
     void unregister_blob(Blob* blob) override;
diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp
index b8a50c434c9db1a8ee96652a1d5e17fbb3bb020a..1ee91e2a2fa94ac40f75b81ff2dea88284837045 100644
--- a/imperative/src/impl/ops/elemwise.cpp
+++ b/imperative/src/impl/ops/elemwise.cpp
@@ -16,6 +16,7 @@
 
 #include "../op_trait.h"
 #include "../dnn_op_helper.h"
+#include "../blob_manager_impl.h"
 
 namespace mgb {
 namespace imperative {
@@ -102,11 +103,16 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def,
         const SmallVector<TensorPtr>& inputs) {
+    auto&& op_def = def.cast_final_safe<Elemwise>();
     SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
+    TensorShapeArray inp_shapes(inputs.size());
     for (unsigned i = 0; i < inputs.size(); ++i){
         inp_tensornds[i] = inputs[i]->dev_tensor();
+        inp_shapes[i] = inputs[i]->layout();
     }
-    SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}};
+    TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
+    DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(inp_tensornds[0].comp_node(), {shape, inp_tensornds[0].layout().dtype});
+    SmallVector<DeviceTensorND> oup_tensornds = {out};
 
     apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
     return {Tensor::make(oup_tensornds[0])};
 }
diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp
index c23ecf48593fdae67b2aa86e1eb882488e37f0d1..4ce64ac99054dc78a2cf7c07cd52b4cc3bbacf76 100644
--- a/imperative/src/impl/proxy_graph.cpp
+++ b/imperative/src/impl/proxy_graph.cpp
@@ -555,10 +555,7 @@ void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs) {
         if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
             // alloc workspace
             TensorLayout layout{var->shape(), var->dtype(), var->format()};
-            DeviceTensorStorage storage;
-            storage.comp_node(var->comp_node())
-                    .ensure_size(layout.dtype.size(layout.total_nr_elems()));
-            var->m_dev_tensor.reset(storage, layout);
+            var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(var->comp_node(), layout);
         } else {
             mgb_assert(j < outputs.size());
             auto &&tensor = outputs[j];
diff --git a/imperative/src/include/megbrain/imperative/blob_manager.h b/imperative/src/include/megbrain/imperative/blob_manager.h
index 809a2d67958ba1259d498ffe5b77ade5d9da3704..258643eec260aeb96e8ca3e0048c7fdd25c27b76 100644
--- a/imperative/src/include/megbrain/imperative/blob_manager.h
+++ b/imperative/src/include/megbrain/imperative/blob_manager.h
@@ -24,6 +24,8 @@ public:
 
     virtual void alloc_with_defrag(Blob* blob, size_t size) = 0;
 
+    virtual DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) = 0;
+
    virtual void register_blob(Blob* blob) = 0;
 
    virtual void unregister_blob(Blob* blob) = 0;