diff --git a/dnn/include/megdnn/oprs/linalg.h b/dnn/include/megdnn/oprs/linalg.h index 54931275be00eebd7e39dcf9d1c9c968d27fdab5..10722bb33f0f71bcb30713a4e21dc4db887fddcc 100644 --- a/dnn/include/megdnn/oprs/linalg.h +++ b/dnn/include/megdnn/oprs/linalg.h @@ -150,7 +150,8 @@ public: virtual void exec( _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); + MGE_WIN_DECLSPEC_FUC void deduce_layout( + const TensorLayout& A, const TensorLayout& B, TensorLayout& C); virtual size_t get_workspace_in_bytes( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; diff --git a/dnn/src/common/dot.cpp b/dnn/src/common/dot.cpp index ddae6c0cd8296bfa95bcb0d8a56aadc422ddce8a..4b001f427d47428cc3fb68041a2164cbed4cdc1c 100644 --- a/dnn/src/common/dot.cpp +++ b/dnn/src/common/dot.cpp @@ -33,7 +33,7 @@ void DotForward::check_exec( megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void DotForward::deduce_layout( +MGE_WIN_DECLSPEC_FUC void DotForward::deduce_layout( const TensorLayout& A, const TensorLayout&, TensorLayout& C) { C = TensorLayout(TensorShape{1}, A.dtype); } diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 80868a87ed0e76043ec8c5bbeda40802afc88e31..dae2eee12f5f9f58bd1a2bd315c5a3aab25a8d6f 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -39,6 +39,7 @@ #include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/utility.h" +#include "../blob_manager_impl.h" #include "../op_trait.h" namespace mgb::imperative { @@ -319,7 +320,70 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { OperatorNodeConfig config{op.make_name()}; return opr::Dot::make(inputs[0], inputs[1], config); } -OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback(); + +// std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { +// auto* node = &node_->cast_final_safe(); +// return Dot::make(node->param()); +// } + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto a = inputs[0]->layout(); + auto comp_node = inputs[0]->comp_node(); + using TensorND = megdnn::TensorND; + SmallVector inp_tensornds; + inp_tensornds.reserve(inputs.size()); + auto dnn_opr = opr::intl::create_megdnn_opr(comp_node); + for (unsigned i = 0; i < inputs.size(); ++i) { + auto dnn_ten = inputs[i]->dnn_tensor(); + inp_tensornds.push_back(dnn_ten); + } + TensorLayout oup_layout{inputs[0]->dtype()}; + auto inp1_tensor = inputs[0]->dnn_tensor(); + auto inp2_tensor = inputs[1]->dnn_tensor(); + dnn_opr->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout); + + if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) { + auto fill_opr = opr::intl::create_megdnn_opr(comp_node); + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); + fill_opr->param() = 0; + fill_opr->exec(out.as_megdnn(), {}); + return {Tensor::make(out)}; + } + + auto wk_size = dnn_opr->get_workspace_in_bytes( + inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout); + + DeviceTensorND out_devtensor = + BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); + TensorLayout wk_layout{TensorShape{wk_size}, inputs[0]->dtype()}; + DeviceTensorND workspace = + BlobManager::inst()->alloc_workspace_with_defrag(comp_node, wk_layout); + megdnn::Workspace dnn_wk(workspace.raw_ptr(), wk_size); + + dnn_opr->exec( + inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk); + + return {Tensor::make(out_devtensor)}; +} + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + auto&& op_def = def.cast_final_safe(); + SmallVector dests(1); + dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype); + dests[0].comp_node = inputs[0].comp_node; + return {dests, true}; +} + +OP_TRAIT_REG(Dot, Dot, opr::Dot) + .apply_on_var_node(apply_on_var_node) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); + } // namespace dot } // namespace diff --git a/src/opr/include/megbrain/opr/blas.h b/src/opr/include/megbrain/opr/blas.h index e5983ab83b63f1f5cfe52abd8bd9374ce7f5f354..2d89d2e2befc37a733e52cc9c9c632633629edec 100644 --- a/src/opr/include/megbrain/opr/blas.h +++ b/src/opr/include/megbrain/opr/blas.h @@ -88,7 +88,7 @@ private: /*! * \brief dot product of two tensors */ -MGB_DEFINE_OPR_CLASS( +MGB_DEFINE_OPR_CLASS_WITH_EXPORT( Dot, cg::SingleCNOperatorNodeBaseT>) // { public: MGE_WIN_DECLSPEC_FUC Dot(