// Patch summary (review notes). These lines sit before the first "diff --git"
// header, so they are not part of any hunk.
//
// imperative/src/impl/ops/elemwise.cpp  (hunk -114,15 +114,44)
//   apply_on_physical_tensor stops gathering inputs with dev_tensor() and stops
//   delegating to apply_on_device_tensornd. The new body:
//     1. takes the comp node from inputs[0] and starts the output layout from
//        the dtype of inputs[0];
//     2. pushes every input as a megdnn::TensorND obtained from
//        Tensor::dnn_tensor() and records its layout in a TensorShapeArray,
//        setting is_empty when any input layout is empty;
//     3. fills the output layout with megdnn::Elemwise::deduce_shape followed
//        by init_contiguous_stride(), instead of reading output_descs[0].layout;
//     4. allocates the output with alloc_workspace_with_defrag(comp_node, layout)
//        and returns it immediately when is_empty is set;
//     5. otherwise creates a megdnn operator, assigns op_def.param() to it, and
//        calls opr::Elemwise::perform_dnn when the mode is FUSE_MUL_ADD3 or
//        FUSE_MUL_ADD4, or when the first input's dtype category is QUANTIZED;
//        every other case calls dnn_opr->exec(inp_tensornds, out.as_megdnn()).
//
// imperative/src/impl/physical_tensor.cpp and physical_tensor.h
//   Adds Tensor::dnn_tensor(): asserts m_blob is set ("uninitialized tensor.")
//   and returns a megdnn::TensorND built from m_layout, the blob's storage
//   pointer and m_offset.
//
// src/opr/impl/basic_arith.cpp and basic_arith.h
//   Adds static Elemwise::perform_dnn(cn, dest, inputs, opr), which forwards to
//   call_megdnn_opr_exec(cn, inputs, dest.as_megdnn(), opr.get(), nullptr).
//
// NOTE(review): inputs[0] is dereferenced at the top of the new
//   apply_on_physical_tensor, yet the later dispatch still guards with
//   inp_tensornds.size() -- confirm whether an empty input list can reach here.
// NOTE(review): dnn_tensor() reads m_blob without taking m_blob_mtx, which the
//   fetch_value() shown as hunk context does lock -- confirm this is intended.
// NOTE(review): this copy of the patch appears to have lost its line breaks and
//   its template argument lists (bare "SmallVector", "cast_final_safe()",
//   "create_megdnn_opr(comp_node)", "UniqPtrWithCN&") -- presumably an
//   extraction artifact; restore from the original commit before applying.
diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index 7d45232b46444d6db491785ff5afe02edacf2f5f..cf3b51f2fa6d42b043df368779dec535d9c6fe96 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -114,15 +114,44 @@ void apply_on_device_tensornd( SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs, SmallVector& output_descs, const bool& validated) { - SmallVector inp_tensornds(inputs.size()); + auto comp_node = inputs[0]->comp_node(); + using Mode = Elemwise::Mode; + using TensorND = megdnn::TensorND; + auto&& op_def = def.cast_final_safe(); + SmallVector inp_tensornds; + TensorShapeArray inp_shapes(inputs.size()); + inp_tensornds.reserve(inputs.size()); + + TensorLayout layout{inputs[0]->layout().dtype}; + bool is_empty = false; for (unsigned i = 0; i < inputs.size(); ++i) { - inp_tensornds[i] = inputs[i]->dev_tensor(); + if (inputs[i]->layout().is_empty()) { + is_empty = true; + } + inp_tensornds.push_back(inputs[i]->dnn_tensor()); + inp_shapes[i] = inputs[i]->layout(); + } + megdnn::Elemwise::deduce_shape(inp_shapes, layout); + layout.init_contiguous_stride(); + + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout); + if (is_empty) { + return {Tensor::make(out)}; + } + auto&& dnn_opr = opr::intl::create_megdnn_opr(comp_node); + + dnn_opr->param() = op_def.param(); + if (dnn_opr->param().mode == Mode::FUSE_MUL_ADD3 || + dnn_opr->param().mode == Mode::FUSE_MUL_ADD4 || + (inp_tensornds.size() && + inp_tensornds[0].layout.dtype.category() == DTypeCategory::QUANTIZED)) { + opr::Elemwise::perform_dnn(comp_node, out, inp_tensornds, dnn_opr); + } else { + dnn_opr->exec(inp_tensornds, out.as_megdnn()); } - DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag( - inp_tensornds[0].comp_node(), output_descs[0].layout); - SmallVector oup_tensornds = {out}; - apply_on_device_tensornd(def, inp_tensornds, 
&oup_tensornds); - return {Tensor::make(oup_tensornds[0])}; + + return {Tensor::make(out)}; } MGB_DEFINE_OPR_CLASS( diff --git a/imperative/src/impl/physical_tensor.cpp b/imperative/src/impl/physical_tensor.cpp index 2a0994d7bf47f852039d01e1e1468e2a8afd142e..bfda63687dd2eb04ed2c323243fb4353a9a23096 100644 --- a/imperative/src/impl/physical_tensor.cpp +++ b/imperative/src/impl/physical_tensor.cpp @@ -212,6 +212,11 @@ DeviceTensorND Tensor::dev_tensor(bool contiguous) { return ret; } +megdnn::TensorND Tensor::dnn_tensor() { + mgb_assert(m_blob, "uninitialized tensor."); + return {m_layout, {m_blob->storage().get(), m_offset}}; +} + void Tensor::fetch_value() { MGB_LOCK_GUARD(m_blob_mtx); MGB_LOCK_GUARD(m_value_mtx); diff --git a/imperative/src/include/megbrain/imperative/physical_tensor.h b/imperative/src/include/megbrain/imperative/physical_tensor.h index ca2232f00bc27a0be37469e2b29717be769032db..092bdcbfe0444e6978c5c455795d79b537519248 100644 --- a/imperative/src/include/megbrain/imperative/physical_tensor.h +++ b/imperative/src/include/megbrain/imperative/physical_tensor.h @@ -110,6 +110,8 @@ public: void assign_from_dev_tensor(DeviceTensorND); + megdnn::TensorND dnn_tensor(); + static TensorPtr make_scalar(DTypeScalar value, CompNode cn); TensorPtr make_scalar(DTypeScalar value) const { diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 3d05f401b77fe7b297a968b6cf6e733d6152224e..fc6c02b00ce8ba62b1e7690a5e09e91982467af6 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -268,6 +268,12 @@ void Elemwise::perform( call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr); } +void Elemwise::perform_dnn( + CompNode cn, DeviceTensorND& dest, megdnn::TensorNDArray& inputs, + intl::UniqPtrWithCN& opr) { + call_megdnn_opr_exec(cn, inputs, dest.as_megdnn(), opr.get(), nullptr); +} + TensorLayoutArray Elemwise::collective_collapse(const TensorLayoutArray& layouts) { TensorLayoutPtrArray 
inp(layouts.size()); TensorLayoutArray result(inp.size()); diff --git a/src/opr/include/megbrain/opr/basic_arith.h b/src/opr/include/megbrain/opr/basic_arith.h index 45f0123e1b38cb9b57ae1f6614507317b71cbbb9..771aff8104218a813eff6e21edeaaaf746463d0f 100644 --- a/src/opr/include/megbrain/opr/basic_arith.h +++ b/src/opr/include/megbrain/opr/basic_arith.h @@ -88,6 +88,10 @@ public: Mode mode, DeviceTensorND& dest, const SmallVector& inputs, intl::UniqPtrWithCN& opr); + MGE_WIN_DECLSPEC_FUC static void perform_dnn( + CompNode cn, DeviceTensorND& dest, megdnn::TensorNDArray& inputs, + intl::UniqPtrWithCN& opr); + using TensorLayoutPtrArray = SmallVector; /*!