提交 2b80806f 编写于 作者: M Megvii Engine Team

perf(imperative/src): improve dot performance

GitOrigin-RevId: 35b5bd164ffde647125f8fa9c0ebf91195dc4f1f
上级 2f3bc2db
......@@ -150,7 +150,8 @@ public:
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
......
......@@ -33,7 +33,7 @@ void DotForward::check_exec(
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void DotForward::deduce_layout(
MGE_WIN_DECLSPEC_FUC void DotForward::deduce_layout(
const TensorLayout& A, const TensorLayout&, TensorLayout& C) {
C = TensorLayout(TensorShape{1}, A.dtype);
}
......
......@@ -39,6 +39,7 @@
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "../blob_manager_impl.h"
#include "../op_trait.h"
namespace mgb::imperative {
......@@ -319,7 +320,70 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
OperatorNodeConfig config{op.make_name()};
return opr::Dot::make(inputs[0], inputs[1], config);
}
OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback();
// std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
// auto* node = &node_->cast_final_safe<opr::Dot>();
// return Dot::make(node->param());
// }
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto a = inputs[0]->layout();
auto comp_node = inputs[0]->comp_node();
using TensorND = megdnn::TensorND;
SmallVector<TensorND> inp_tensornds;
inp_tensornds.reserve(inputs.size());
auto dnn_opr = opr::intl::create_megdnn_opr<megdnn::Dot>(comp_node);
for (unsigned i = 0; i < inputs.size(); ++i) {
auto dnn_ten = inputs[i]->dnn_tensor();
inp_tensornds.push_back(dnn_ten);
}
TensorLayout oup_layout{inputs[0]->dtype()};
auto inp1_tensor = inputs[0]->dnn_tensor();
auto inp2_tensor = inputs[1]->dnn_tensor();
dnn_opr->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout);
if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) {
auto fill_opr = opr::intl::create_megdnn_opr<megdnn::Fill>(comp_node);
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
fill_opr->param() = 0;
fill_opr->exec(out.as_megdnn(), {});
return {Tensor::make(out)};
}
auto wk_size = dnn_opr->get_workspace_in_bytes(
inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout);
DeviceTensorND out_devtensor =
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
TensorLayout wk_layout{TensorShape{wk_size}, inputs[0]->dtype()};
DeviceTensorND workspace =
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, wk_layout);
megdnn::Workspace dnn_wk(workspace.raw_ptr(), wk_size);
dnn_opr->exec(
inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk);
return {Tensor::make(out_devtensor)};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<Dot>();
SmallVector<LogicalTensorDesc> dests(1);
dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype);
dests[0].comp_node = inputs[0].comp_node;
return {dests, true};
}
OP_TRAIT_REG(Dot, Dot, opr::Dot)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace dot
} // namespace
......
......@@ -88,7 +88,7 @@ private:
/*!
* \brief dot product of two tensors
*/
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
Dot, cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolderImpl<megdnn::Dot>>) // {
public:
MGE_WIN_DECLSPEC_FUC Dot(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册