From a5a606792ef325e4ce1b44e6b6e43a7e68523642 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Tue, 22 Dec 2020 16:36:22 +0800
Subject: [PATCH] feat(imperative/interpreter): add more dispatch modes in apply_op

GitOrigin-RevId: 2663504470e6cf83a4ce5d84131f0cbd2f39716e
---
 imperative/src/impl/interpreter_impl.cpp      | 185 ++++++++++++------
 imperative/src/impl/interpreter_impl.h        |  11 ++
 imperative/src/impl/op_def.cpp                |  14 ++
 imperative/src/impl/op_trait.cpp              |  15 +-
 imperative/src/impl/op_trait.h                |   8 +
 imperative/src/impl/ops/elemwise.cpp          |  41 +++-
 imperative/src/impl/ops/tensor_manip.cpp      |  63 +++++-
 .../src/include/megbrain/imperative/op_def.h  |  26 +++
 8 files changed, 287 insertions(+), 76 deletions(-)

diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp
index 24e47afd4..5a012813e 100644
--- a/imperative/src/impl/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter_impl.cpp
@@ -29,7 +29,7 @@ Interpreter& Interpreter::inst() {
     return inst_;
 }
 
-void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
+Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     auto info = alloc();
     info->desc.layout = value.layout();
     info->desc.comp_node = value.comp_node();
@@ -39,7 +39,7 @@
     return info;
 }
 
-void* ChannelImpl::put(const DeviceTensorND& data) {
+Handle ChannelImpl::put(const DeviceTensorND& data) {
     auto info = alloc();
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
@@ -48,12 +48,12 @@
     return info;
 }
 
-void ChannelImpl::del(void* handle) {
+void ChannelImpl::del(Handle handle) {
     mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
     m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
 }
 
-void ChannelImpl::swap_in(void* handle) {
+void ChannelImpl::swap_in(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -61,7 +61,7 @@
     }
 }
 
-void ChannelImpl::swap_out(void* handle) {
+void ChannelImpl::swap_out(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -69,7 +69,7 @@
     }
 }
 
-void ChannelImpl::drop(void* handle) {
+void ChannelImpl::drop(Handle handle) {
     if (m_enable_evict & DROP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -77,45 +77,91 @@
     }
 }
 
-SmallVector<Handle> ChannelImpl::apply_op(
+void ChannelImpl::dispatch_default_cpu(
         std::shared_ptr<OpDef> op,
-        const SmallVector<Handle>& inputs) {
-    for (auto i : inputs) {
-        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
-                "invalid handle: %p", i);
-    }
-    SmallVector<TensorInfo*> input_infos;
-    input_infos.reserve(inputs.size());
-    SmallVector<LogicalTensorDesc> input_descs;
-    input_descs.reserve(inputs.size());
+        const SmallVector<TensorInfo*>& input_infos,
+        const SmallVector<LogicalTensorDesc>& input_descs,
+        SmallVector<Handle>* outputs) {
+    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
+    SmallVector<DeviceTensorND> input_tensornds;
+    input_tensornds.reserve(input_descs.size());
+    CompNode output_cn;
     {
         MGB_LOCK_GUARD(m_mutex);
-        for (auto i : inputs) {
-            auto info = reinterpret_cast<TensorInfo*>(i);
-            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
-            input_infos.push_back(info);
-            input_descs.push_back(info->desc);
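+        // a DEFAULT_CPU op is computed eagerly, so every input must already
+        // carry a host value; proxy it to the default CPU comp node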
+        for (auto&& info : input_infos) {
+            mgb_assert(info->ptr, "invalid tensor ptr!");
+            if (!output_cn.valid()) {
+                output_cn = info->ptr->comp_node();
+            } else {
+                mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node");
+            }
+            mgb_assert(info->ptr->try_get_value(), "no valid host value");
+            input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
         }
     }
+
+    outputs->reserve(output_descs.size());
+    SmallVector<DeviceTensorND> output_tensornds;
+    output_tensornds.reserve(output_descs.size());
+    for (auto&& desc : output_descs) {
+        // TODO: may conflict with condtake, which needs to alloc inside
+        mgb_assert(!desc.layout.is_empty());
+        // use HostTensorND alloc_host for cuda pinned memory
+        output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
+    }
+
+    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
+
+    SmallVector<TensorInfo*> output_infos;
+    output_infos.reserve(output_descs.size());
+    for (auto&& tensornd : output_tensornds) {
+        // tensornd -> host_tensornd
+        HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
+            .proxy_to_comp_node(output_cn);
+        // tensornd -> desc
+        LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd};
+        // tensornd -> tensor
+        auto info = alloc();
+        info->desc = desc;
+        m_valid_handle.insert(info);
+        output_infos.push_back(info);
+        info->ptr = Tensor::make(host_tensornd, true); // host_only=true
+        info->value_fetched = true;
+        outputs->push_back(info);
+    }
+
+    if (m_enable_evict & DROP) {
+        for (auto out : output_infos) {
+            out->path.op = op;
+            for (auto out_ : output_infos) {
+                out->path.outputs.push_back(m_st.at(out_));
+            }
+            for (auto inp : input_infos) {
+                out->path.inputs.push_back(m_st.at(inp));
+                inp->path.dep_outputs.push_back(m_st.at(out));
+            }
+        }
+    }
+}
 
+void ChannelImpl::dispatch_kernel(
+        std::shared_ptr<OpDef> op,
+        const SmallVector<TensorInfo*>& input_infos,
+        const SmallVector<LogicalTensorDesc>& input_descs,
+        SmallVector<Handle>* outputs) {
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
+
     ApplyOp cmd{std::move(op)};
     cmd.inputs = std::move(input_infos);
     cmd.outputs.reserve(output_descs.size());
-    SmallVector<Handle> outputs;
-    // FIXME: remove this check when op check is correct
-    bool validated_bkp = true;
-    for (size_t i = 0;i < output_descs.size();i ++) {
-        auto&& desc = output_descs[i];
-        if (desc.layout.ndim == 0) {
-            validated_bkp = false;
-        }
+    outputs->reserve(output_descs.size());
+    for (auto&& desc : output_descs) {
         auto info = alloc();
         info->desc = desc;
         m_valid_handle.insert(info);
         cmd.outputs.push_back(info);
-        outputs.push_back(info);
+        outputs->push_back(info);
     }
     if (m_enable_evict & DROP) {
         for (auto out : cmd.outputs) {
@@ -130,20 +176,55 @@
             }
         }
     }
     m_buffer.enqueue(std::move(cmd));
-    if (!(validated && validated_bkp) && m_async_level == 1) {
+    if (!validated && m_async_level == 1) {
         sync();
     } else if (m_async_level == 0) {
         sync();
         // check device error
-        for (auto&& oup : outputs) {
+        for (auto&& oup : *outputs) {
             auto info = reinterpret_cast<TensorInfo*>(oup);
             info->ptr->comp_node().sync();
         }
     }
+}
+
+SmallVector<Handle> ChannelImpl::apply_op(
+        std::shared_ptr<OpDef> op,
+        const SmallVector<Handle>& inputs) {
+    for (auto i : inputs) {
+        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
+                "invalid handle: %p", i);
+    }
+    SmallVector<TensorInfo*> input_infos;
+    input_infos.reserve(inputs.size());
+    SmallVector<LogicalTensorDesc> input_descs;
+    input_descs.reserve(inputs.size());
+    {
+        MGB_LOCK_GUARD(m_mutex);
+        for (auto i : inputs) {
+            auto info = reinterpret_cast<TensorInfo*>(i);
+            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
+            input_infos.push_back(info);
+            input_descs.push_back(info->desc);
+        }
+    }
+
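+    // small host-computable ops run synchronously (DEFAULT_CPU); everything
+    // else is enqueued to the worker thread as a kernel (KERNEL)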
+    SmallVector<Handle> outputs;
+    switch (OpDef::decide_dispatch_mode(*op, input_descs)) {
+        case DEFAULT_CPU: {
+            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
+            break;
+        }
+        case KERNEL: {
+            dispatch_kernel(op, input_infos, input_descs, &outputs);
+            break;
+        }
+    }
+    mgb_assert(outputs.size() > 0, "Invalid dispatch mode!");
     return outputs;
 }
 
-HostTensorND ChannelImpl::get_value(void* handle) {
+HostTensorND ChannelImpl::get_value(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -163,7 +244,7 @@
     return info->ptr->get_value();
 }
 
-TensorShape ChannelImpl::get_shape(void* handle) {
+TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -184,7 +265,7 @@
     return ret;
 }
 
-DType ChannelImpl::get_dtype(void* handle) {
+DType ChannelImpl::get_dtype(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -193,7 +274,7 @@
     return ret;
 }
 
-CompNode ChannelImpl::get_device(void* handle) {
+CompNode ChannelImpl::get_device(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -202,7 +283,7 @@
     return ret;
 }
 
-DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
+DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -262,25 +343,15 @@ ChannelImpl::~ChannelImpl() {
 }
 
 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) {
-    if (notice) {
-        MGB_LOCK_GUARD(m_mutex);
-        dest->value_fetched = ptr->value_fetched();
-        // update tensor desc for static infer
-        // if (dest->desc.layout.ndim) {
-        //     mgb_assert(dest->desc.layout.eq_shape(ptr->layout()));
-        // }
-        dest->desc.layout = ptr->layout();
-        dest->desc.comp_node = ptr->comp_node();
-        dest->ptr = std::move(ptr);
-        if (m_waitee == dest) {
-            m_cv.notify_all();
-        }
-    } else {
-        dest->value_fetched = ptr->value_fetched();
-        // update tensor desc for static infer
-        dest->desc.layout = ptr->layout();
-        dest->desc.comp_node = ptr->comp_node();
-        dest->ptr = std::move(ptr);
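+    // take m_mutex only when a waiter may need to be notified; a
+    // default-constructed unique_lock owns no mutex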
+    auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
+                       : std::unique_lock<std::mutex>();
+    dest->value_fetched = ptr->value_fetched();
+    // update tensor desc for static infer
+    dest->desc.layout = ptr->layout();
+    dest->desc.comp_node = ptr->comp_node();
+    dest->ptr = std::move(ptr);
+    if (notice && m_waitee == dest) {
+        m_cv.notify_all();
     }
 }
 
@@ -295,7 +366,7 @@ void ChannelImpl::do_swap_out(TensorInfo* dest) {
     dest->evict_type = SWAP;
     dest->value_fetched = false;
     // TODO: swap in parallel
-    dest->h_value.copy_from(dest->ptr->dev_tensor()).sync();
+    dest->h_value = dest->ptr->get_value();
     dest->ptr.reset();
 }
 
diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h
index 4a5e3fe9e..328fe3ebb 100644
--- a/imperative/src/impl/interpreter_impl.h
+++ b/imperative/src/impl/interpreter_impl.h
@@ -198,6 +198,17 @@ private:
     void do_drop(TensorInfo* dest);
     void regenerate(TensorInfo* dest, bool must_drop);
 
+    void dispatch_default_cpu(
+            std::shared_ptr<OpDef> op,
+            const SmallVector<TensorInfo*>& input_infos,
+            const SmallVector<LogicalTensorDesc>& input_descs,
+            SmallVector<Handle>* outputs);
+    void dispatch_kernel(
+            std::shared_ptr<OpDef> op,
+            const SmallVector<TensorInfo*>& input_infos,
+            const SmallVector<LogicalTensorDesc>& input_descs,
+            SmallVector<Handle>* outputs);
+
     std::mutex m_mutex;
     std::condition_variable m_cv;
     MemPool<TensorInfo> m_pool;
diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp
index 995f43c5f..2ec350d44 100644
--- a/imperative/src/impl/op_def.cpp
+++ b/imperative/src/impl/op_def.cpp
@@ -30,12 +30,26 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node(
     return trait->make_from_op_node(node);
 }
 
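+// forwarded to the op's registered trait; ops that do not register their own
+// decide_dispatch_mode get the KERNEL fallback (see op_trait.cpp)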
+DispatchMode OpDef::decide_dispatch_mode(
+    const OpDef& def,
+    const SmallVector<LogicalTensorDesc>& inputs) {
+    return def.trait()->decide_dispatch_mode(def, inputs);
+}
+
 SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
     const OpDef& def,
     SmallVector<TensorPtr> inputs) {
     return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
 }
 
+void OpDef::apply_on_device_tensornd(
+    const OpDef& def,
+    const SmallVector<DeviceTensorND>& inputs,
+    SmallVector<DeviceTensorND>* outputs) {
+    def.trait()->apply_on_device_tensornd(def, inputs, outputs);
+    return;
+}
+
 VarNodeArray OpDef::apply_on_var_node(
     const OpDef& def,
     const VarNodeArray& inputs) {
diff --git a/imperative/src/impl/op_trait.cpp b/imperative/src/impl/op_trait.cpp
index e2e3d19a1..6f665bd69 100644
--- a/imperative/src/impl/op_trait.cpp
+++ b/imperative/src/impl/op_trait.cpp
@@ -9,12 +9,16 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
+#include <exception>
 #include <sstream>
+#include <stdexcept>
 
+#include "megbrain/imperative/op_def.h"
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/imperative/proxy_graph_detail.h"
+#include "megbrain/tensor.h"
 
 #include "./op_trait.h"
-#include "megbrain/imperative/proxy_graph_detail.h"
 
 namespace mgb {
 namespace imperative {
@@ -62,6 +66,12 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
     }
 }
 
+DispatchMode fallback_decide_dispatch_mode(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    return KERNEL;
+}
+
 OpTraitRegistry& OpTraitRegistry::fallback() {
     if (trait->apply_on_var_node) {
         // fallback to proxy graph impl
@@ -78,6 +88,9 @@
             proxy_graph_detail::make_backward_graph;
         }
     }
+    if (!trait->decide_dispatch_mode) {
+        trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
+    }
     return *this;
 }
 
diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h
index 8f1348331..c8b5c1629 100644
--- a/imperative/src/impl/op_trait.h
+++ b/imperative/src/impl/op_trait.h
@@ -60,8 +60,12 @@ struct ToVarNodeArray: std::true_type {
 
 using OpDefMaker = detail::OpMeth<
         decltype(OpDef::make_from_op_node)>;
+using DecideDispatchMode = detail::OpMeth<
+        decltype(OpDef::decide_dispatch_mode)>;
 using ApplyOnPhysicalTensor = detail::OpMeth<
        decltype(OpDef::apply_on_physical_tensor)>;
+using ApplyOnDeviceTensorND = detail::OpMeth<
+        decltype(OpDef::apply_on_device_tensornd)>;
 using ApplyOnVarNode = detail::OpMeth<
         decltype(OpDef::apply_on_var_node)>;
 using InferOutputAttrsFallible = detail::OpMeth<
@@ -74,7 +78,9 @@ using IsSame = detail::OpMeth;
 struct OpTrait {
     const char* name;
     OpDefMaker make_from_op_node;
+    DecideDispatchMode decide_dispatch_mode;
     ApplyOnPhysicalTensor apply_on_physical_tensor;
+    ApplyOnDeviceTensorND apply_on_device_tensornd;
     ApplyOnVarNode apply_on_var_node;
     InferOutputAttrsFallible infer_output_attrs_fallible;
     GradMaker make_backward_graph;
@@ -88,7 +94,9 @@
 
 #define FOR_EACH_OP_METH(cb) \
     cb(make_from_op_node) \
+    cb(decide_dispatch_mode) \
     cb(apply_on_physical_tensor) \
+    cb(apply_on_device_tensornd) \
     cb(apply_on_var_node) \
     cb(infer_output_attrs_fallible) \
     cb(make_backward_graph) \
diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp
index 2a30f544b..bdd8b5b56 100644
--- a/imperative/src/impl/ops/elemwise.cpp
+++ b/imperative/src/impl/ops/elemwise.cpp
@@ -68,23 +68,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
 }
 
-SmallVector<TensorPtr> apply_on_physical_tensor(
+DispatchMode decide_dispatch_mode(
         const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    bool host_computable = true;
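+    // heuristic: treat an elemwise op as host-computable only when every
+    // input value is known and small (at most TensorShape::MAX_NDIM
+    // elements, i.e. shape-like data)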
+    constexpr int size_threshold = TensorShape::MAX_NDIM;
+    for (auto&& inp : inputs) {
+        if (inp.value.empty() || inp.value.layout().ndim == 0
+                || inp.value.layout().total_nr_elems() > size_threshold) {
+            host_computable = false;
+            break;
+        }
+    }
+    return host_computable ? DEFAULT_CPU : KERNEL;
+}
+
+
+void apply_on_device_tensornd(
+        const OpDef& def,
+        const SmallVector<DeviceTensorND>& inputs,
+        SmallVector<DeviceTensorND>* outputs) {
     auto&& op_def = def.cast_final_safe<Elemwise>();
     auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
     mgb_assert(inputs.size() == trait.arity,
               "%s expects %u inputs; got %zu actually",
               trait.name, trait.arity, inputs.size());
+    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0].comp_node());
+    opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
-    DeviceTensorND out;
-    SmallVector<DeviceTensorND> dt_inputs(inputs.size());
+    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
     for (unsigned i = 0; i < inputs.size(); ++i){
-        dt_inputs[i] = inputs[i]->dev_tensor();
+        inp_tensornds[i] = inputs[i]->dev_tensor();
     }
-    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0]->comp_node());
-    opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr);
-    return {Tensor::make(out)};
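+    // note: the single output starts empty (only comp node and dtype set);
+    // opr::Elemwise::perform infers the layout and allocates the storage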
+    SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}};
+    apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
+    return {Tensor::make(oup_tensornds[0])};
 }
 
 MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT) //{
@@ -214,8 +237,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
 
 OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
     .make_from_op_node(make_from_op_node)
+    .decide_dispatch_mode(decide_dispatch_mode)
     .apply_on_var_node(apply_on_var_node)
     .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .apply_on_device_tensornd(apply_on_device_tensornd)
     .apply_on_physical_tensor(apply_on_physical_tensor)
     .fallback();
 
diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp
index a52189617..ae5b4c68f 100644
--- a/imperative/src/impl/ops/tensor_manip.cpp
+++ b/imperative/src/impl/ops/tensor_manip.cpp
@@ -15,8 +15,8 @@
 #include "../op_trait.h"
 
 namespace mgb::imperative {
-namespace {
 
+namespace get_var_shape {
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
@@ -24,17 +24,38 @@
     return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
 }
 
-SmallVector<TensorPtr> apply_on_physical_tensor(
+DispatchMode decide_dispatch_mode(
         const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    bool host_computable = true;
+    for (auto&& inp : inputs) {
+        // FIXME(czh): remove value check after proxy graph's
+        // apply_on_device_tensornd is supported and output Tensor
+        // is made before add_task.
+        // then if layout is valid, ptr->layout must be ready
+        if (inp.value.empty() || inp.value.layout().ndim == 0) {
+            host_computable = false;
+            break;
+        }
+    }
+    return host_computable ? DEFAULT_CPU : KERNEL;
+}
+
+void apply_on_device_tensornd(
+        const OpDef& def,
+        const SmallVector<DeviceTensorND>& inputs,
+        SmallVector<DeviceTensorND>* outputs) {
     auto&& op_def = def.cast_final_safe<GetVarShape>();
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& inp = inputs[0];
-    auto&& shp = inp->layout();
+    auto&& shp = inp.layout();
     mgb_assert(shp.ndim != 0, "input shape invalid");
+    mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(),
+        "GetVarShape's apply_on_device_tensornd should receive default_cpu outputs.");
+
     HostTensorND hv;
-    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
-        hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32());
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
+        hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
         auto* ptr = hv.ptr<dt_int32>();
         for (size_t i = 0; i < shp.ndim; ++i) {
             ptr[i] = shp.shape[i];
@@ -45,11 +66,29 @@
             axis += shp.ndim;
         }
         mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
-        hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32());
+        hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
         auto* ptr = hv.ptr<dt_int32>();
         ptr[0] = shp.shape[axis];
     }
-    return {Tensor::make(std::move(hv))};
+    (*outputs)[0] = DeviceTensorND::make_proxy(hv);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    SmallVector<DeviceTensorND> input_tensornds;
+    input_tensornds.reserve(inputs.size());
+    for (auto&& inp : inputs) {
+        input_tensornds.push_back(inp->dev_tensor());
+    }
+    SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}};
+
+    apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
+
+    // restore to input comp_node
+    HostTensorND host_tensornd = HostTensorND::make_proxy(output_tensornds[0])
+        .proxy_to_comp_node(inputs[0]->comp_node());
+    return {Tensor::make(std::move(host_tensornd))};
 }
 
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -62,7 +101,7 @@
         return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
     }
     DeviceTensorND value;
-    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
         value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
         auto* ptr = value.ptr<dt_int32>();
         for (size_t i = 0; i < desc.layout.ndim; ++i) {
@@ -88,11 +127,15 @@
 
 OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
     .make_from_op_node(make_from_op_node)
+    .decide_dispatch_mode(decide_dispatch_mode)
     .infer_output_attrs_fallible(infer_output_attrs_fallible)
     .apply_on_var_node(apply_on_var_node)
+    .apply_on_device_tensornd(apply_on_device_tensornd)
    .apply_on_physical_tensor(apply_on_physical_tensor)
     .fallback();
+} // get_var_shape
 
+namespace param_pack {
 TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
     TensorShapeArray ret;
     for (auto&& i:shapes) {
@@ -156,6 +199,6 @@
 OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
     .apply_on_var_node(param_pack_concat_apply_on_var_node)
     .fallback();
-} // namespace
+} // param_pack
 
 } // namespace mgb::imperative
diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h
index 92238e775..5f250ac42 100644
--- a/imperative/src/include/megbrain/imperative/op_def.h
+++ b/imperative/src/include/megbrain/imperative/op_def.h
@@ -20,6 +20,11 @@ namespace imperative {
 class OpDef;
 struct OpTrait;
 
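+// how apply_op executes an op: DEFAULT_CPU computes it synchronously on
+// host memory inside apply_op, KERNEL enqueues a device kernel to the
+// interpreter worker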
+enum DispatchMode {
+    DEFAULT_CPU = 0,
+    KERNEL = 1
+};
+
 struct BackwardGraphResult {
     std::shared_ptr<OpDef> backward;
     std::vector<bool> save_for_backward;
@@ -36,10 +41,31 @@ public:
     static std::shared_ptr<OpDef> make_from_op_node(
             cg::OperatorNodeBase* node);
 
+    /*!
+     * \brief Decide which dispatch mode to use according to the inputs'
+     * host values and sizes.
+     *
+     * \param def Specific :c:expr:`OpDef` to be executed.
+     * \param inputs Input tensor descriptions.
+     * \return Which DispatchMode to use, such as `KERNEL` or `DEFAULT_CPU`.
+     */
+    static DispatchMode decide_dispatch_mode(
+            const OpDef& def,
+            const SmallVector<LogicalTensorDesc>& inputs);
+
     static SmallVector<TensorPtr> apply_on_physical_tensor(
             const OpDef& def,
             SmallVector<TensorPtr> inputs);
 
+    /*!
+     * \brief Call the corresponding dnn op to calculate the results. The
+     * output tensors' device memory should be allocated outside.
+     */
+    static void apply_on_device_tensornd(
+            const OpDef& def,
+            const SmallVector<DeviceTensorND>& inputs,
+            SmallVector<DeviceTensorND>* outputs);
+
     static cg::VarNodeArray apply_on_var_node(
             const OpDef& def,
             const VarNodeArray& inputs);
-- 
GitLab