diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp
index 6ae239efa69e3c2874737344794ffea38bc6710a..6a1d7b6ed35e18538117a64dba8c9dfaf587e603 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter/interpreter_impl.cpp
@@ -124,6 +124,7 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
 TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
     auto info = alloc();
     init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
+    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
     info->h_value = value;
     m_buffer.enqueue(Put{info, value, no_cache});
     if (m_async_level == 0) {
@@ -141,6 +142,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
     auto info = alloc();
     RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
     init(info, {data.layout(), data.comp_node()});
+    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
     info->ptr = Tensor::make(data);
     RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
     info->status = TensorInfo::Produced;
@@ -487,6 +489,9 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
     RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
     info->status = TensorInfo::Allocated;
     info->desc = std::move(desc);
+    info->mem_desc.layout = info->desc.layout;
+    info->mem_desc.cn = info->desc.comp_node;
+    info->mem_desc.offset = 0;
 }
 
@@ -605,6 +610,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
     bool profiling_device = Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
     uint64_t apply_id = cmd.id;
     SmallVector<TensorPtr> tensor_inputs;
+    SmallVector<MemoryDesc> input_memory_desc;
     if (state.options.enable_dtr_auto_drop) {
         m_dtr.pin(cmd.inputs);
     }
@@ -618,8 +624,27 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
     // refcnt == 1, owners: [TensorInfo::ptr]
     for (auto i : cmd.inputs) {
         mgb_assert(i->ptr, "Invalid input tensor ptr!");
+        mgb_assert(i->mem_desc.id, "Invalid input tensor mem desc!");
         // refcnt ++, owners: [i->ptr, tensor_inputs]
         tensor_inputs.push_back(i->ptr);
+        input_memory_desc.push_back(i->mem_desc);
     }
+    // SmallVector<MemoryDesc> outputs_mem_desc;
+    // SmallVector<TensorPtr> tensor_outputs, workspaces;
+    auto [outputs_mem_desc, tensor_outputs, workspaces] = init_output_and_workspace(*cmd.op, tensor_inputs, input_memory_desc);
+    if (outputs_mem_desc.size()) {
+        for (size_t i = 0; i < outputs_mem_desc.size(); i++) {
+            if (cmd.outputs[i]) {
+                cmd.outputs[i]->mem_desc = outputs_mem_desc[i];
+            }
+        }
+    } else {
+        // failed to infer the mem plan
+        for (auto&& out : cmd.outputs) {
+            if (out) {
+                out->mem_desc.id = StorageIdentifier::make();
+            }
+        }
+    }
     RECORD_EVENT(OpExecuteEvent, apply_id);
     // Begin profiling operator
@@ -662,8 +687,13 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
     }
     // Apply op
     // Here std::move is REQUIRED for removing duplicated references.
-    auto tensor_outputs = OpDef::apply_on_physical_tensor(
-            *cmd.op, std::move(tensor_inputs));
+    if (outputs_mem_desc.size()) {
+        OpDef::execute(
+                *cmd.op, std::move(tensor_inputs), tensor_outputs, std::move(workspaces));
+    } else {
+        tensor_outputs = OpDef::apply_on_physical_tensor(
+                *cmd.op, std::move(tensor_inputs));
+    }
     // After execute
     for (auto&& [device, kernel_id]: kernels) {
         RECORD_EVENT(KernelExecuteFinishEvent, apply_id, kernel_id, Timer::record_event(device));
@@ -705,7 +735,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
     RECORD_EVENT(OpExecuteFinishEvent, apply_id);
     // End profiling operator
 }
-
+
 void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
     auto& state = get_worker_state();
     do_apply_op(ApplyOp{path->id, path->op, path->inputs, path->outputs, {}});
@@ -829,6 +859,47 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
     return valid_tensors;
 }
 
+std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> ChannelImpl::init_output_and_workspace(
+        const OpDef& def,
+        SmallVector<TensorPtr> inputs,
+        SmallVector<MemoryDesc> inputs_mem_desc) {
+
+    auto [outputs_desc, workspaces_desc] = OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
+    if (!outputs_desc.size()) {
+        // failed to infer memplan
+        return {{}, {}, {}};
+    }
+    // refine storage id to make it unique
+    for (auto&& desc : outputs_desc) {
+        if (desc.id->is_sys_alloc()) {
+            // TODO: there may be some outputs sharing the same storage id
+            desc.id->id = ++m_storage_id;
+        }
+    }
+    auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
+        SmallVector<TensorPtr> tensors;
+        for (size_t i = 0; i < desc.size(); i++) {
+            if (desc[i].id->is_sys_alloc()) {
+                tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
+            } else if (desc[i].id->is_from_other()) {
+                for (size_t j = 0; j < inputs_mem_desc.size(); j++) {
+                    if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
+                        tensors.push_back(inputs[j]->sub(desc[i].offset, desc[i].layout));
+                        break;
+                    }
+                }
+            } else if (desc[i].id->is_device_ptr()) {
+                tensors.push_back(desc[i].id->ptr);
+            } else {
+                mgb_assert(0, "not implemented");
+            }
+        }
+        return tensors;
+    };
+
+    return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
+}
+
 void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
     using namespace ranges;
     using namespace ranges::views;
diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h
index 14b3b2de37fd841df31f79f8a5e53440d0c10fe4..bbe993a942287185ced0092de07c16e9cc2648c7 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.h
+++ b/imperative/src/impl/interpreter/interpreter_impl.h
@@ -101,6 +101,11 @@ private:
     void regenerate(TensorInfo* dest);
     void recompute(TensorInfo::ComputePath* path);
     void do_apply_op(const ApplyOp& cmd);
+
+    std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> init_output_and_workspace(
+            const OpDef& def,
+            SmallVector<TensorPtr> inputs,
+            SmallVector<MemoryDesc> inputs_mem_desc);
 
     void dispatch_default_cpu(
         std::shared_ptr<OpDef> op,
@@ -139,6 +144,7 @@ private:
     uint64_t m_waitee_id = 0;
     std::exception_ptr m_worker_exc;
     std::function<void(std::string, std::string)> m_profile_dump_callback;
+    size_t m_storage_id = 0;
 
     bool m_closed = false;
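The interpreter changes above introduce a two-phase protocol in ChannelImpl::do_apply_op: plan memory first, then run the kernel. A condensed sketch of the new control flow (all names come from this patch; profiling, DTR pinning, and event recording are elided):

    // Phase 1: ask the op for a memory plan; init_output_and_workspace()
    // returns one MemoryDesc per output plus preallocated output/workspace tensors.
    auto [outputs_mem_desc, tensor_outputs, workspaces] =
            init_output_and_workspace(*cmd.op, tensor_inputs, input_memory_desc);
    if (outputs_mem_desc.size()) {
        // Phase 2a: storage was allocated (or forwarded from an input) up front;
        // the kernel writes into the preallocated buffers.
        OpDef::execute(*cmd.op, std::move(tensor_inputs), tensor_outputs,
                       std::move(workspaces));
    } else {
        // Phase 2b: the op registered no mem-desc inference; fall back to the
        // old path, where the op allocates its own outputs.
        tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, std::move(tensor_inputs));
    }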
diff --git a/imperative/src/impl/interpreter/tensor_info.h b/imperative/src/impl/interpreter/tensor_info.h
index 5a3067ca056e55a86423d0b959ee5e9147850101..9102d55a12734ea42d47d30ce63850816c290c79 100644
--- a/imperative/src/impl/interpreter/tensor_info.h
+++ b/imperative/src/impl/interpreter/tensor_info.h
@@ -58,6 +58,7 @@ struct TensorInfo {
     // Lock interpreter when visiting `ptr`.
     TensorPtr ptr;
     LogicalTensorDesc desc;
+    MemoryDesc mem_desc;
 
     double compute_time;
     size_t memory;
diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp
index 7603299a4d832cf96125cb9107db5d0d0676d383..ede6c6bca62397ec3d41b2fb7f5fc21863d1b3d2 100644
--- a/imperative/src/impl/op_def.cpp
+++ b/imperative/src/impl/op_def.cpp
@@ -45,6 +45,21 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
     return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
 }
 
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> OpDef::infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    return def.trait()->infer_output_mem_desc(def, inputs_tensors, inputs_mems);
+}
+
+void OpDef::execute(
+        const OpDef& def,
+        SmallVector<TensorPtr> inputs,
+        SmallVector<TensorPtr> outputs,
+        SmallVector<TensorPtr> workspace) {
+    def.trait()->execute(def, std::move(inputs), outputs, std::move(workspace));
+}
+
 void OpDef::apply_on_device_tensornd(
     const OpDef& def,
     const SmallVector<DeviceTensorND>& inputs,
diff --git a/imperative/src/impl/op_trait.cpp b/imperative/src/impl/op_trait.cpp
index ba84111d487e56ce67008908f80dfe1f2582bcac..640583a671ece68d07e6da9e6d053a1dc8d956c6 100644
--- a/imperative/src/impl/op_trait.cpp
+++ b/imperative/src/impl/op_trait.cpp
@@ -79,6 +79,13 @@ OpTraitRegistry& OpTraitRegistry::fallback() {
         trait->apply_on_physical_tensor =
                 proxy_graph_detail::apply_on_physical_tensor;
     }
+    if (!trait->execute) {
+        trait->execute = proxy_graph_detail::execute;
+    }
+    if (!trait->infer_output_mem_desc) {
+        trait->infer_output_mem_desc =
+                proxy_graph_detail::infer_output_mem_desc;
+    }
     if (!trait->infer_output_attrs_fallible) {
         trait->infer_output_attrs_fallible =
                 proxy_graph_detail::infer_output_attrs_fallible;
diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h
index 8abd638d4a7771d40e615f58f7e89cec1eca6e5b..dff441479305635ebe15482ec6ebfde7707970f7 100644
--- a/imperative/src/impl/op_trait.h
+++ b/imperative/src/impl/op_trait.h
@@ -64,6 +64,10 @@ using DecideDispatchMode = detail::OpMeth<
     decltype(OpDef::decide_dispatch_mode)>;
 using ApplyOnPhysicalTensor = detail::OpMeth<
     decltype(OpDef::apply_on_physical_tensor)>;
+using InferOutputMemDesc = detail::OpMeth<
+    decltype(OpDef::infer_output_mem_desc)>;
+using Execute = detail::OpMeth<
+    decltype(OpDef::execute)>;
 using ApplyOnDeviceTensorND = detail::OpMeth<
     decltype(OpDef::apply_on_device_tensornd)>;
 using ApplyOnVarNode = detail::OpMeth<
@@ -82,6 +86,8 @@ struct OpTrait {
     OpDefMaker make_from_op_node;
     DecideDispatchMode decide_dispatch_mode;
     ApplyOnPhysicalTensor apply_on_physical_tensor;
+    InferOutputMemDesc infer_output_mem_desc;
+    Execute execute;
     ApplyOnDeviceTensorND apply_on_device_tensornd;
     ApplyOnVarNode apply_on_var_node;
     InferOutputAttrsFallible infer_output_attrs_fallible;
@@ -100,6 +106,8 @@ struct OpTrait {
     cb(make_from_op_node) \
     cb(decide_dispatch_mode) \
     cb(apply_on_physical_tensor) \
+    cb(infer_output_mem_desc) \
+    cb(execute) \
     cb(apply_on_device_tensornd) \
     cb(apply_on_var_node) \
     cb(infer_output_attrs_fallible) \
diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp
index bdab57db95599fef09b90087adbc35a9fa1c6082..51b3ee0427ebe9b69f0e760ab4f434a44dd2c4fd 100644
--- a/imperative/src/impl/ops/cond_take.cpp
+++ b/imperative/src/impl/ops/cond_take.cpp
@@ -79,10 +79,24 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     }, false};
 }
 
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    return {{}, {}};
+}
+
+void execute(const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
+    mgb_assert(0);
+}
+
 OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
     .apply_on_var_node(apply_on_var_node)
     .apply_on_physical_tensor(apply_on_physical_tensor)
     .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .infer_output_mem_desc(infer_output_mem_desc)
+    .execute(execute)
     .fallback();
 } // namespace
diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp
index 40c69d8c05c0b2bdd9d5ee15577d540e91003a94..ab6196587a6b2f85ead89bf8c714036b3d79ec2d 100644
--- a/imperative/src/impl/ops/elemwise.cpp
+++ b/imperative/src/impl/ops/elemwise.cpp
@@ -118,6 +118,35 @@ void apply_on_device_tensornd(
     opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
 }
 
+void execute(
+        const OpDef& def,
+        SmallVector<TensorPtr> inputs,
+        SmallVector<TensorPtr> outputs,
+        SmallVector<TensorPtr> workspace) {
+    mgb_assert(outputs.size() == 1);
+    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        inp_tensornds[i] = inputs[i]->dev_tensor();
+    }
+    SmallVector<DeviceTensorND> out_tensornds = {outputs[0]->dev_tensor()};
+    apply_on_device_tensornd(def, inp_tensornds, &out_tensornds);
+}
+
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    auto&& op_def = def.cast_final_safe<Elemwise>();
+    TensorShapeArray inp_shapes(inputs_tensors.size());
+    for (size_t i = 0; i < inputs_tensors.size(); ++i) {
+        inp_shapes[i] = inputs_tensors[i]->layout();
+    }
+    TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
+    SmallVector<MemoryDesc> outputs = {{{shape, inputs_tensors[0]->dtype()}, 0, inputs_tensors[0]->comp_node(), StorageIdentifier::make(1)}};
+    return {outputs, {}};
+}
+
+
 SmallVector<TensorPtr> apply_on_physical_tensor(
     const OpDef& def,
     const SmallVector<TensorPtr>& inputs) {
@@ -224,7 +253,7 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
 SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
     const OpDef& def,
     const SmallVector<TensorPtr>& inputs){
-    mgb_assert(inputs[0]->blob().unique() && inputs[0]->blob()->storage().unique(),
+    mgb_assert(inputs[0]->blob().use_count() == 2 && inputs[0]->blob()->storage().unique(),
         "This inplace modification may change the elements of other tensors. "
" "Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs correctly."); auto dest = inputs[0], delta = inputs[1], @@ -238,6 +267,23 @@ SmallVector apply_inplace_add_on_physical_tensor( return { std::make_shared(dest->blob(), dest->offset(), dest->layout()) }; } +void execute_inplace( + const OpDef& def, + SmallVector inputs, + SmallVector outputs, + SmallVector workspace) { + apply_inplace_add_on_physical_tensor(def, inputs); +} + +std::tuple, SmallVector> infer_inplace_output_mem_desc( + const OpDef& def, + const SmallVector& inputs_tensors, + const SmallVector& inputs_mems) { + auto dest = inputs_tensors[0]; + SmallVector outputs = {{dest->layout(), 0, dest->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}; + return {outputs, {}}; +} + std::tuple, bool> infer_inplace_add_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { @@ -271,12 +317,16 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) .infer_output_attrs_fallible(infer_output_attrs_fallible) .apply_on_device_tensornd(apply_on_device_tensornd) .apply_on_physical_tensor(apply_on_physical_tensor) + .infer_output_mem_desc(infer_output_mem_desc) + .execute(execute) .fallback(); OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate) .apply_on_var_node(apply_inplace_add_on_var_node) .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor) .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible) + .infer_output_mem_desc(infer_inplace_output_mem_desc) + .execute(execute_inplace) .fallback(); } // anonymous namespace diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index 1e97a89ddb57ff00f01b54a1dd8bd419ef31d70d..b913650e04436eb5f216640364b158e7d7374480 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -331,6 +331,7 @@ struct _RNGOprInvoker { } \ }; + #define _INST_RNG_MAKER(MGB_NR_INPUTS) \ template<> \ struct _RNGOprMaker { \ @@ -366,7 +367,7 @@ _INST_RNG_MAKER(2) template void exec(const OpDef& op, const SmallVector& inputs, - const SmallVector& outputs) { + const SmallVector& outputs, const SmallVector& workspace) { auto&& rng = op.cast_final_safe(); auto dest = outputs[0]; @@ -418,6 +419,18 @@ SmallVector infer_output_attrs( return {dest}; } +template +std::tuple, SmallVector> infer_output_mem_desc( + const OpDef& def, + const SmallVector& inputs_tensors, + const SmallVector& inputs_mems) { + auto &&dest = infer_output_attrs(def, inputs_tensors); + SmallVector outputs = {{dest[0].layout, 0, dest[0].comp_node, StorageIdentifier::make(1)}}; + + return {outputs, {}}; +} + + template SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs) { @@ -427,10 +440,19 @@ SmallVector apply_on_physical_tensor( for (auto&& i : desc) { outputs.push_back(Tensor::make(i.layout, i.comp_node)); } - exec(def, inputs, outputs); + exec(def, inputs, outputs, {}); return outputs; } +template +void execute( + const OpDef& def, + SmallVector inputs, + SmallVector outputs, + SmallVector workspace) { + exec(def, inputs, outputs, {}); +} + template SymbolVar apply_on_var_node( const OpDef& def, @@ -492,6 +514,8 @@ OP_TRAIT_REG(NAME, NAME, OpMeth::OpNode) \ .apply_on_var_node(apply_on_var_node) \ .apply_on_physical_tensor(apply_on_physical_tensor) \ .infer_output_attrs_fallible(infer_output_attrs_fallible) \ + .infer_output_mem_desc(infer_output_mem_desc) \ + .execute(execute) \ .fallback(); \ } \ diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 
diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp
index 2bc29bbe217a12bb6a5d48a2d29daf8b14d00b8e..10862e0b7b81faa7df0b2e3bc456b998f13ace6a 100644
--- a/imperative/src/impl/ops/tensor_manip.cpp
+++ b/imperative/src/impl/ops/tensor_manip.cpp
@@ -86,22 +86,22 @@ void apply_on_device_tensornd(
     (*outputs)[0] = DeviceTensorND::make_proxy(hv);
 }
 
-SmallVector<TensorPtr> apply_on_physical_tensor(
-        const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
+HostTensorND get_var_shape_host_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     SmallVector<DeviceTensorND> input_tensornds;
     input_tensornds.reserve(inputs.size());
     for (auto&& inp : inputs) {
         input_tensornds.push_back(inp->dev_tensor());
     }
     SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}};
     apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
     // restore to input comp_node
-    HostTensorND host_tensornd = HostTensorND::make_proxy(output_tensornds[0])
-        .proxy_to_comp_node(inputs[0]->comp_node());
-    return {Tensor::make(std::move(host_tensornd))};
+    return HostTensorND::make_proxy(output_tensornds[0]).proxy_to_comp_node(inputs[0]->comp_node());
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))};
 }
 
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -142,6 +142,33 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
 }
 
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
+    SmallVector<MemoryDesc> ret;
+    auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
+    if (blob) {
+        ret.push_back({tensor.layout(), 0, inputs[0]->comp_node(),
+                StorageIdentifier::make(Tensor::make(std::forward<decltype(blob)>(blob), tensor.layout(), tensor))});
+    } else {
+        ret.push_back({tensor.layout(), 0, inputs[0]->comp_node(), StorageIdentifier::make(1)});
+    }
+    return {ret, {}};
+}
+
+void execute(const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
+    HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
+    SmallVector<MemoryDesc> ret;
+    auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
+    if (!blob || blob->storage() != outputs[0]->blob()->storage()) {
+        outputs[0]->dev_tensor().copy_from_fixlayout(tensor);
+        AsyncReleaser::inst()->add(tensor);
+    }
+}
+
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     auto* node = &node_->cast_final_safe<opr::GetVarShape>();
     return GetVarShape::make(node->param());
@@ -154,6 +181,8 @@ OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
     .apply_on_var_node(apply_on_var_node)
     .apply_on_device_tensornd(apply_on_device_tensornd)
     .apply_on_physical_tensor(apply_on_physical_tensor)
+    .infer_output_mem_desc(infer_output_mem_desc)
+    .execute(execute)
     .fallback();
 } // get_var_shape
@@ -181,6 +210,31 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
     return opr;
 }
 
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> param_pack_split_infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    auto&& param = def.cast_final_safe<ParamPackSplit>();
+    mgb_assert(inputs.size() == 1, "ParamPackSplit takes 1 input, got %lu", inputs.size());
+    auto&& inp = inputs[0];
+    auto&& shp = inp->layout();
+    mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
+    mgb_assert(param.shapes.size() * 2 == param.offsets.size());
+    SmallVector<MemoryDesc> ret;
+    auto&& shapes = get_shapes(param.shapes);
+    size_t dtype_size = inputs[0]->layout().dtype.size();
+    for (size_t i = 0; i < shapes.size(); ++i) {
+        // memory forward
+        ret.push_back({{shapes[i], inputs[0]->dtype()}, param.offsets[i * 2] * dtype_size, inp->comp_node(), StorageIdentifier::make(&inputs_mems[0])});
+    }
+    return {ret, {}};
+}
+
+void param_pack_split_execute(const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
+    // do nothing
+}
+
 SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
         const OpDef& def,
         const SmallVector<TensorPtr>& inputs) {
@@ -203,6 +257,8 @@ SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
 
 OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
     .apply_on_var_node(param_pack_split_apply_on_var_node)
+    .infer_output_mem_desc(param_pack_split_infer_output_mem_desc)
+    .execute(param_pack_split_execute)
     .apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
     .fallback();
 
@@ -219,6 +275,64 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
     return opr;
 }
 
+
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> param_pack_concat_infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    def.cast_final_safe<ParamPackConcat>();
+    mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
+    auto comp_node = inputs.front()->comp_node();
+    auto dtype = inputs.front()->dtype();
+    size_t nr_inputs = inputs.size() - 1;
+    size_t nr_elems = 0;
+    for (size_t i = 0; i < nr_inputs; ++i) {
+        auto& input = inputs[i];
+        mgb_assert(comp_node == input->comp_node(), "inputs for param_pack_concat must be on the same comp_node");
+        mgb_assert(dtype == input->dtype(), "inputs for param_pack_concat must have the same dtype");
+        nr_elems += input->layout().total_nr_elems();
+    }
+    auto dest_layout = TensorLayout({nr_elems}, dtype);
+    auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
+    size_t ws_size;
+    {
+        TensorShapeArray src_shapes;
+        for (size_t i = 0; i < nr_inputs; ++i) {
+            src_shapes.push_back(inputs[i]->shape());
+        }
+        ws_size = caller.op->get_workspace_in_bytes(src_shapes, inputs.back()->shape(), TensorShape{});
+    }
+
+    SmallVector<MemoryDesc> outputs = {{dest_layout, 0, comp_node, StorageIdentifier::make(1)}};
+    MemoryDesc workspace = {{{ws_size}, dtype::Byte()}, 0, comp_node, StorageIdentifier::make(2)};
+
+    return {outputs, {workspace}};
+}
+
+void param_pack_concat_execute(const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
+    def.cast_final_safe<ParamPackConcat>();
+    mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
+    auto comp_node = inputs.front()->comp_node();
+    size_t nr_inputs = inputs.size() - 1;
+    auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
+    size_t srcs_size = sizeof(void*) * nr_inputs;
+    void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
+    std::shared_ptr<dt_byte> srcs_ptr = {(dt_byte*)srcs_raw_ptr, [comp_node](dt_byte* ptr){
+        comp_node.free_host(ptr);
+    }};
+    TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
+    for (size_t i = 0; i < nr_inputs; ++i) {
+        srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr;
+    }
+    HostTensorStorage srcs_storage;
+    srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
+    megdnn::Workspace dnn_wk(workspace[0]->blob()->storage().get(), workspace[0]->blob()->size());
+    caller.op->exec({srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(), outputs[0]->dev_tensor().as_megdnn(),
+            dnn_wk);
+    AsyncReleaser::inst()->add(HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
+}
+
 SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
         const OpDef& def,
         const SmallVector<TensorPtr>& inputs) {
@@ -264,6 +378,8 @@ SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
 
 OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
     .apply_on_var_node(param_pack_concat_apply_on_var_node)
+    .infer_output_mem_desc(param_pack_concat_infer_output_mem_desc)
+    .execute(param_pack_concat_execute)
     .apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
     .fallback();
 } // param_pack
diff --git a/imperative/src/impl/physical_tensor.cpp b/imperative/src/impl/physical_tensor.cpp
index ba0b5ae26e4315247a14b21c93b30c396535d7a2..5f675347ab7f4c6b827204ab3046c7bab71ebb04 100644
--- a/imperative/src/impl/physical_tensor.cpp
+++ b/imperative/src/impl/physical_tensor.cpp
@@ -77,149 +77,6 @@ public:
 bool CompNodeSyncManager::is_into_atexit = false;
 #endif
 
-// Cache for small blobs
-// 1. A blob has to be seen twice (within a window) to be eligible for cache
-// 2. Cache eviction occurs when cache size reaches a threshold, in least frequently used order
-class ConstTensorCache {
-public:
-    struct Entry {
-        size_t hitcnt = 0;
-        std::unique_ptr<dt_byte[]> data;
-        size_t size;
-        BlobPtr blob;
-
-        Entry() = default;
-        Entry(const dt_byte* ptr, size_t size_, BlobPtr blob_)
-                : data(new dt_byte[size_]), size(size_), blob(blob_) {
-            memcpy(data.get(), ptr, size);
-        }
-
-        // does not check input
-        bool match(const HostTensorND& hv) {
-            return 0 == memcmp(data.get(), hv.raw_ptr(), hv.layout().span().high_byte);
-        }
-    };
-
-    using KV = std::pair<uint64_t, Entry>;
-
-    bool check(const HostTensorND& hv) {
-        auto&& layout = hv.layout();
-        auto&& span = layout.span();
-        return hv.format().is_default() && !hv.empty() &&
-               layout.is_contiguous() && span.low_byte == 0 &&
-               span.high_byte <= max_bytes;
-    }
-
-    // hash storage; does not check input
-    static uint64_t hash(const HostTensorND& hv) {
-        auto&& span = hv.layout().span();
-        return XXHash{}
-                .update(hv.raw_ptr(), span.high_byte)
-                .digest();
-    }
-
-    BlobPtr lookup(const HostTensorND& hv) {
-        if (!check(hv)) {
-            return {};
-        }
-        auto h = hash(hv);
-        MGB_LOCK_GUARD(mtx);
-        // lookup in g1
-        auto it = g1.find(h);
-        if (it != g1.end()) {
-            if (!it->second.match(hv)) {
-                mgb_log_warn("hash collision in const tensor cache");
-                return {};
-            }
-            it->second.hitcnt += 1;
-            return it->second.blob;
-        }
-        // lookup in g0
-        if (!g0.extract(h) && !g0b.extract(h)) {
-            maybe_collect_g0();
-            g0.emplace(h);
-            return {};
-        }
-        // add new entry to g1
-        maybe_collect_g1();
-        Entry entry(hv.raw_ptr(), hv.layout().span().high_byte, Tensor(hv).blob());
-        it = g1.emplace_hint(it, h, std::move(entry));
-        it->second.hitcnt += 1;
-        return it->second.blob;
-    }
-
-    void clear() {
-        MGB_LOCK_GUARD(mtx);
-        g0.clear();
-        g0b.clear();
-        g1.clear();
-    }
-
-    std::mutex mtx;
-    const size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536;
-
-private:
-    void maybe_collect_g0() {
-        if (g0.size() > window) {
-            std::swap(g0, g0b);
-            g0.clear();
-        }
-    }
-    void maybe_collect_g1() {
-        if (g1.size() < hwm) return;
-
-        tmp.clear();
-        for (auto&& kv : g1) {
-            tmp.emplace_back(kv.first, std::move(kv.second));
-        }
-        std::nth_element(tmp.begin(), tmp.begin() + lwm, tmp.end(), [](const KV& lhs, const KV& rhs) {
-            return lhs.second.hitcnt > rhs.second.hitcnt;
-        });
-        tmp.resize(lwm);
-        g1.clear();
-        for (auto&& kv : tmp) {
-            kv.second.hitcnt = 0;
-            g1.emplace(std::move(kv));
-        }
-    }
-
-    // g0: records blobs which have been seen at least once (within a window)
-    // g0b: backup of g0
-    // g1: records the most frequently used blobs which have been seen at least
-    //     twice. When `g1.size() == hwm`, it will be refreshed and only the top
-    //     `lwm` frequently used blobs will be kept.
-    std::unordered_set<uint64_t> g0, g0b;
-    std::unordered_map<uint64_t, Entry> g1;
-    std::vector<KV> tmp;
-
-public:
-    ConstTensorCache() {
-        g0.reserve(window), g0b.reserve(window);
-        g1.reserve(hwm), tmp.reserve(hwm);
-    }
-};
-
-struct MultiCNConstTensorCache : CompNodeDepedentObject {
-    std::mutex mtx;
-    CompNode::UnorderedMap<ConstTensorCache> cn2cache;
-
-    std::shared_ptr<void> on_comp_node_finalize() {
-        MGB_LOCK_GUARD(mtx);
-        cn2cache.clear();
-        return {};
-    }
-
-    BlobPtr lookup(const HostTensorND& hv) {
-        MGB_LOCK_GUARD(mtx);
-        return cn2cache[hv.comp_node()].lookup(hv);
-    }
-
-    static MultiCNConstTensorCache& inst() {
-        static MultiCNConstTensorCache sl_inst;
-        return sl_inst;
-    }
-};
-
 } // namespace
 
 void EventDeleter::operator()(CompNode::Event* event) {
diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp
index 7b5a02fdc33feb4874b057531e60712b07acc1ba..77f6f5bafb3132e5763e44c17f23ae9a36280648 100644
--- a/imperative/src/impl/proxy_graph.cpp
+++ b/imperative/src/impl/proxy_graph.cpp
@@ -522,9 +522,10 @@ SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
 
 void ProxyGraph::invoke_op(const OpDef& opdef,
         const SmallVector<Tensor*>& inputs,
-        const SmallVector<Tensor*>& outputs) {
+        const SmallVector<Tensor*>& outputs,
+        const SmallVector<Tensor*>& workspaces) {
     CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
-    init_output_tensor(outputs);
+    init_output_tensor(outputs, workspaces);
     for (auto oup : m_cur_opr->output()) {
         m_graph->add_used_comp_node(oup->comp_node());
     }
@@ -544,19 +545,30 @@ void ProxyGraph::cleanup() {
     m_cur_opr = nullptr;
 }
 
-void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs) {
+void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs, const SmallVector<Tensor*>& workspaces) {
     // get proxy opr
     auto proxy = m_cur_opr;
 
     do_shape_infer(true);
 
     size_t j = 0;
+    size_t k = 0;
     for (auto&& var : proxy->output()) {
         auto&& chk = var->m_mem_plan.reset_from_owner_var().chunk();
         if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
-            // alloc workspace
-            TensorLayout layout{var->shape(), var->dtype(), var->format()};
-            var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(var->comp_node(), layout);
+            // workspace
+            if (workspaces.size()) {
+                mgb_assert(k < workspaces.size());
+                auto&& layout = workspaces[k]->layout();
+                mgb_assert(var->comp_node() == workspaces[k]->comp_node() &&
+                           var->shape().eq_shape(layout) &&
+                           var->dtype() == layout.dtype);
+                var->m_dev_tensor = workspaces[k]->dev_tensor();
+                ++k;
+            } else {
+                TensorLayout layout{var->shape(), var->dtype(), var->format()};
+                var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(var->comp_node(), layout);
+            }
         } else {
             mgb_assert(j < outputs.size());
             auto&& tensor = outputs[j];
@@ -570,6 +582,7 @@ void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs, const SmallVector<Tensor*>& workspaces) {
         chk.mem_alloc_status.set_from_owner_var();
     }
     mgb_assert(j == outputs.size());
+    mgb_assert(k == workspaces.size());
 
     // Memory forwarding was bypassed in megbrain with graph option
     // imperative_proxy_graph on, here we call mem_plan_fwd_in2out_readonly
@@ -623,6 +636,26 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_fallible(
     return {outputs, validated && !need_check};
 }
 
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<Tensor*>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    auto opr = get_proxy_opr(def, inputs_tensors);
+    CUR_OPR_GUARD(opr);
+    do_shape_infer(true);
+    SmallVector<MemoryDesc> outputs;
+    SmallVector<MemoryDesc> workspaces;
+    size_t cur_id = 0;
+    for (auto&& i : opr->output()) {
+        if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
+            workspaces.push_back({{i->shape(), i->dtype(), i->format()}, 0, i->comp_node(), StorageIdentifier::make(++cur_id)});
+        } else {
+            outputs.push_back({{i->shape(), i->dtype()}, 0, i->comp_node(), StorageIdentifier::make(++cur_id)});
+        }
+    }
+    return {outputs, workspaces};
+}
+
 struct ProxyGraph::GradGraph {
     cg::VarNodeArray inputs;
     cg::VarNodeArray outputs;
diff --git a/imperative/src/impl/proxy_graph.h b/imperative/src/impl/proxy_graph.h
index 779949d60357f4dac55e5e37f4968f9d6dd1c6b5..787f5cb15c5c4230dac10946b16065c239c03892 100644
--- a/imperative/src/impl/proxy_graph.h
+++ b/imperative/src/impl/proxy_graph.h
@@ -37,7 +37,8 @@ public:
     void invoke_op(
         const OpDef& opdef,
         const SmallVector<Tensor*>& inputs,
-        const SmallVector<Tensor*>& outputs);
+        const SmallVector<Tensor*>& outputs,
+        const SmallVector<Tensor*>& workspace);
 
     BackwardGraphResult make_backward_graph(
         const OpDef& opdef,
@@ -45,6 +46,11 @@ public:
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad);
 
+    std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<Tensor*>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems);
+
     /********************** Logical Tensor API **********************/
 
     size_t get_opr_output_size(
@@ -74,7 +80,8 @@ private:
     void cleanup();
 
     void init_output_tensor(
-        const SmallVector<Tensor*>& outputs);
+        const SmallVector<Tensor*>& outputs,
+        const SmallVector<Tensor*>& workspace);
 
     cg::OperatorNodeBase* get_proxy_opr(
         const OpDef& opdef,
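With the two ProxyGraph entry points above, the fallback path now plans workspaces outside the graph. The round trip, condensed (all names come from this patch; error handling elided):

    // 1. Planning: the proxy graph runs shape inference and emits one MemoryDesc
    //    per output var; VOLATILE_CONTENT vars are reported as workspaces.
    auto [outputs_desc, workspaces_desc] =
            graph->infer_output_mem_desc(def, raw_inputs, inputs_mems);
    // 2. Allocation: the interpreter materializes both lists as tensors
    //    (ChannelImpl::init_output_and_workspace in interpreter_impl.cpp).
    // 3. Execution: invoke_op() hands the preallocated workspaces to
    //    init_output_tensor(), which assigns them to the matching vars instead
    //    of calling BlobManager::alloc_workspace_with_defrag.
    graph->invoke_op(def, raw_inputs, raw_outputs, raw_workspaces);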
diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp
index d5d0365c54b58b511a419d4716d2482c10122f0a..93d75b3daf4d567cd0502c22af3d24c29b249ad2 100644
--- a/imperative/src/impl/proxy_graph_detail.cpp
+++ b/imperative/src/impl/proxy_graph_detail.cpp
@@ -43,10 +43,12 @@ infer_output_attrs(const OpDef& def,
 
 void exec(const OpDef& def,
         const SmallVector<TensorPtr>& inputs,
-        const SmallVector<TensorPtr>& outputs) {
+        const SmallVector<TensorPtr>& outputs,
+        const SmallVector<TensorPtr>& workspaces) {
     auto&& graph = ProxyGraph::get_default_graph();
     auto raw_inputs = to_raw_ptr_array(inputs),
-         raw_outputs = to_raw_ptr_array(outputs);
+         raw_outputs = to_raw_ptr_array(outputs),
+         raw_workspaces = to_raw_ptr_array(workspaces);
     CompNode::UnorderedSet used_cns;
     for (auto&& out: raw_outputs) {
         auto cn = out->comp_node();
@@ -59,7 +61,7 @@ void exec(const OpDef& def,
             }
         }
     }
-    graph->invoke_op(def, raw_inputs, raw_outputs);
+    graph->invoke_op(def, raw_inputs, raw_outputs, raw_workspaces);
     for (auto&& cn: used_cns) {
         for (auto&& in: inputs) {
             if (in->comp_node() != cn) {
@@ -77,7 +79,7 @@ apply_on_physical_tensor(const OpDef& def,
     for (size_t i = 0; i < outputs.size(); i++) {
         outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
     }
-    exec(def, inputs, outputs);
+    exec(def, inputs, outputs, {});
     auto async_error = ProxyGraph::get_async_error();
     if (async_error) {
         throw *async_error;
@@ -85,6 +87,26 @@ apply_on_physical_tensor(const OpDef& def,
     return outputs;
 }
 
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    auto&& graph = ProxyGraph::get_default_graph();
+    return graph->infer_output_mem_desc(def, to_raw_ptr_array(inputs_tensors), inputs_mems);
+}
+
+void execute(const OpDef& def,
+        SmallVector<TensorPtr> inputs,
+        SmallVector<TensorPtr> outputs,
+        SmallVector<TensorPtr> workspace) {
+    exec(def, inputs, outputs, workspace);
+    auto async_error = ProxyGraph::get_async_error();
+    if (async_error) {
+        throw *async_error;
+    }
+    return;
+}
+
 // std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
 //         const SmallVector<LogicalTensorDesc>& inputs) {
 //     auto&& graph = ProxyGraph::get_default_graph();
diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h
index 59bb96df6a7f3a0ee747bd4a85ed30f553c1dddf..8b8641a51b66a3f189b12d4b517ed058054883f4 100644
--- a/imperative/src/include/megbrain/imperative/op_def.h
+++ b/imperative/src/include/megbrain/imperative/op_def.h
@@ -108,6 +108,13 @@ public:
     static SmallVector<TensorPtr> apply_on_physical_tensor(
             const OpDef& def,
             SmallVector<TensorPtr> inputs);
+
+    static void execute(
+            const OpDef& def,
+            SmallVector<TensorPtr> inputs,
+            SmallVector<TensorPtr> outputs,
+            SmallVector<TensorPtr> workspace);
+
     /*!
      * \brief Call the corresponding dnn op to calculate results. Output
@@ -126,6 +133,11 @@ public:
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs);
 
+    static std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems);
+
     static BackwardGraphResult make_backward_graph(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
diff --git a/imperative/src/include/megbrain/imperative/physical_tensor.h b/imperative/src/include/megbrain/imperative/physical_tensor.h
index 114c60052498794155eb025aa8935123039eae7a..e8d250f7e6c6707809dab98e63ae7ab0175212ab 100644
--- a/imperative/src/include/megbrain/imperative/physical_tensor.h
+++ b/imperative/src/include/megbrain/imperative/physical_tensor.h
@@ -150,12 +150,192 @@ private:
     EventPtr m_value_ready = nullptr;
 };
 
+// Cache for small blobs
+// 1. A blob has to be seen twice (within a window) to be eligible for cache
+// 2. Cache eviction occurs when cache size reaches a threshold, in least frequently used order
+class ConstTensorCache {
+public:
+    struct Entry {
+        size_t hitcnt = 0;
+        std::unique_ptr<dt_byte[]> data;
+        size_t size;
+        BlobPtr blob;
+
+        Entry() = default;
+        Entry(const dt_byte* ptr, size_t size_, BlobPtr blob_)
+                : data(new dt_byte[size_]), size(size_), blob(blob_) {
+            memcpy(data.get(), ptr, size);
+        }
+
+        // does not check input
+        bool match(const HostTensorND& hv) {
+            return 0 == memcmp(data.get(), hv.raw_ptr(), hv.layout().span().high_byte);
+        }
+    };
+
+    using KV = std::pair<uint64_t, Entry>;
+
+    bool check(const HostTensorND& hv) {
+        auto&& layout = hv.layout();
+        auto&& span = layout.span();
+        return hv.format().is_default() && !hv.empty() &&
+               layout.is_contiguous() && span.low_byte == 0 &&
+               span.high_byte <= max_bytes;
+    }
+
+    // hash storage; does not check input
+    static uint64_t hash(const HostTensorND& hv) {
+        auto&& span = hv.layout().span();
+        return XXHash{}
+                .update(hv.raw_ptr(), span.high_byte)
+                .digest();
+    }
+
+    BlobPtr lookup(const HostTensorND& hv) {
+        if (!check(hv)) {
+            return {};
+        }
+        auto h = hash(hv);
+        MGB_LOCK_GUARD(mtx);
+        // lookup in g1
+        auto it = g1.find(h);
+        if (it != g1.end()) {
+            if (!it->second.match(hv)) {
+                mgb_log_warn("hash collision in const tensor cache");
+                return {};
+            }
+            it->second.hitcnt += 1;
+            return it->second.blob;
+        }
+        // lookup in g0
+        if (!g0.extract(h) && !g0b.extract(h)) {
+            maybe_collect_g0();
+            g0.emplace(h);
+            return {};
+        }
+        // add new entry to g1
+        maybe_collect_g1();
+        Entry entry(hv.raw_ptr(), hv.layout().span().high_byte, Tensor(hv).blob());
+        it = g1.emplace_hint(it, h, std::move(entry));
+        it->second.hitcnt += 1;
+        return it->second.blob;
+    }
+
+    void clear() {
+        MGB_LOCK_GUARD(mtx);
+        g0.clear();
+        g0b.clear();
+        g1.clear();
+    }
+
+    std::mutex mtx;
+    const size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536;
+
+private:
+    void maybe_collect_g0() {
+        if (g0.size() > window) {
+            std::swap(g0, g0b);
+            g0.clear();
+        }
+    }
+    void maybe_collect_g1() {
+        if (g1.size() < hwm) return;
+
+        tmp.clear();
+        for (auto&& kv : g1) {
+            tmp.emplace_back(kv.first, std::move(kv.second));
+        }
+        std::nth_element(tmp.begin(), tmp.begin() + lwm, tmp.end(), [](const KV& lhs, const KV& rhs) {
+            return lhs.second.hitcnt > rhs.second.hitcnt;
+        });
+        tmp.resize(lwm);
+        g1.clear();
+        for (auto&& kv : tmp) {
+            kv.second.hitcnt = 0;
+            g1.emplace(std::move(kv));
+        }
+    }
+
+    // g0: records blobs which have been seen at least once (within a window)
+    // g0b: backup of g0
+    // g1: records the most frequently used blobs which have been seen at least
+    //     twice. When `g1.size() == hwm`, it will be refreshed and only the top
+    //     `lwm` frequently used blobs will be kept.
+    std::unordered_set<uint64_t> g0, g0b;
+    std::unordered_map<uint64_t, Entry> g1;
+    std::vector<KV> tmp;
+
+public:
+    ConstTensorCache() {
+        g0.reserve(window), g0b.reserve(window);
+        g1.reserve(hwm), tmp.reserve(hwm);
+    }
+};
+
+struct MultiCNConstTensorCache : CompNodeDepedentObject {
+    std::mutex mtx;
+    CompNode::UnorderedMap<ConstTensorCache> cn2cache;
+
+    std::shared_ptr<void> on_comp_node_finalize() {
+        MGB_LOCK_GUARD(mtx);
+        cn2cache.clear();
+        return {};
+    }
+
+    BlobPtr lookup(const HostTensorND& hv) {
+        MGB_LOCK_GUARD(mtx);
+        return cn2cache[hv.comp_node()].lookup(hv);
+    }
+
+    static MultiCNConstTensorCache& inst() {
+        static MultiCNConstTensorCache sl_inst;
+        return sl_inst;
+    }
+};
+
 struct LogicalTensorDesc {
     TensorLayout layout;
     CompNode comp_node;
     DeviceTensorND value; // cpu:default
 };
 
+struct StorageIdentifier;
+struct MemoryDesc {
+    TensorLayout layout;
+    size_t offset;
+    CompNode cn;
+    std::shared_ptr<StorageIdentifier> id;
+};
+
+struct StorageIdentifier {
+    enum { INVALID, SYS_ALLOC, FROM_OTHER, DEVICE_PTR } tag;
+    union {
+        size_t id;
+        MemoryDesc* desc;
+    };
+    TensorPtr ptr;
+    StorageIdentifier() = default;
+    StorageIdentifier(size_t id): tag(SYS_ALLOC), id(id) {}
+    StorageIdentifier(const MemoryDesc* desc): tag(FROM_OTHER), desc(desc->id->desc) {}
+    StorageIdentifier(TensorPtr dev_ptr): tag(DEVICE_PTR), ptr(dev_ptr) {}
+
+    template<typename... Args>
+    static std::shared_ptr<StorageIdentifier> make(Args&&... args) {
+        return std::make_shared<StorageIdentifier>(std::forward<Args>(args)...);
+    }
+    bool is_sys_alloc() {
+        return tag == SYS_ALLOC;
+    }
+    bool is_from_other() {
+        return tag == FROM_OTHER;
+    }
+    bool is_device_ptr() {
+        return tag == DEVICE_PTR;
+    }
+    bool is_invalid() {
+        return tag == INVALID;
+    }
+};
 } // namespace imperative
 } // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/proxy_graph_detail.h b/imperative/src/include/megbrain/imperative/proxy_graph_detail.h
index 2e2544ef1daf7847a35968123ab5661a6a64bd8c..5bdaab730dd37ed537492daefced4aa0e3c8378f 100644
--- a/imperative/src/include/megbrain/imperative/proxy_graph_detail.h
+++ b/imperative/src/include/megbrain/imperative/proxy_graph_detail.h
@@ -21,9 +21,19 @@
 SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef& def,
         SmallVector<TensorPtr> inputs);
 
+void execute(const OpDef& def,
+        SmallVector<TensorPtr> inputs,
+        SmallVector<TensorPtr> outputs,
+        SmallVector<TensorPtr> workspace);
+
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs);
 
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems);
+
 void exec(const OpDef& def,
         const SmallVector<TensorPtr>& inputs,
         const SmallVector<TensorPtr>& outputs);
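The four StorageIdentifier tags correspond one-to-one to the branches of ChannelImpl::init_output_and_workspace. A compact restatement of how a descriptor is resolved to storage (resolve is a hypothetical helper; the logic mirrors that function):

    TensorPtr resolve(const MemoryDesc& desc,
                      const SmallVector<MemoryDesc>& inputs_mems,
                      const SmallVector<TensorPtr>& inputs) {
        if (desc.id->is_sys_alloc()) {       // SYS_ALLOC: fresh allocation
            return Tensor::make(desc.layout, desc.cn);
        }
        if (desc.id->is_from_other()) {      // FROM_OTHER: view into an input
            for (size_t j = 0; j < inputs_mems.size(); ++j) {
                // the union member `desc` serves as an identity token here
                if (inputs_mems[j].id->desc == desc.id->desc) {
                    return inputs[j]->sub(desc.offset, desc.layout);
                }
            }
        }
        if (desc.id->is_device_ptr()) {      // DEVICE_PTR: the op pinned storage
            return desc.id->ptr;             // itself, e.g. a const-tensor-cache hit
        }
        mgb_assert(0, "INVALID id: mem plan inference failed; use the fallback path");
        return {};  // unreachable
    }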