提交 84466261 编写于 作者: M Megvii Engine Team

perf(imperative/src): improve elemwise

GitOrigin-RevId: 78aa487277b20bf08698ef1e100a7d4b0cc4df15
上级 e400b7ff
...@@ -114,15 +114,44 @@ void apply_on_device_tensornd( ...@@ -114,15 +114,44 @@ void apply_on_device_tensornd(
SmallVector<TensorPtr> apply_on_physical_tensor( SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs, const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<DeviceTensorND> inp_tensornds(inputs.size()); auto comp_node = inputs[0]->comp_node();
using Mode = Elemwise::Mode;
using TensorND = megdnn::TensorND;
auto&& op_def = def.cast_final_safe<Elemwise>();
SmallVector<TensorND> inp_tensornds;
TensorShapeArray inp_shapes(inputs.size());
inp_tensornds.reserve(inputs.size());
TensorLayout layout{inputs[0]->layout().dtype};
bool is_empty = false;
for (unsigned i = 0; i < inputs.size(); ++i) { for (unsigned i = 0; i < inputs.size(); ++i) {
inp_tensornds[i] = inputs[i]->dev_tensor(); if (inputs[i]->layout().is_empty()) {
is_empty = true;
}
inp_tensornds.push_back(inputs[i]->dnn_tensor());
inp_shapes[i] = inputs[i]->layout();
}
megdnn::Elemwise::deduce_shape(inp_shapes, layout);
layout.init_contiguous_stride();
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);
if (is_empty) {
return {Tensor::make(out)};
} }
DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag( auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(comp_node);
inp_tensornds[0].comp_node(), output_descs[0].layout);
SmallVector<DeviceTensorND> oup_tensornds = {out}; dnn_opr->param() = op_def.param();
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds); if (dnn_opr->param().mode == Mode::FUSE_MUL_ADD3 ||
return {Tensor::make(oup_tensornds[0])}; dnn_opr->param().mode == Mode::FUSE_MUL_ADD4 ||
(inp_tensornds.size() &&
inp_tensornds[0].layout.dtype.category() == DTypeCategory::QUANTIZED)) {
opr::Elemwise::perform_dnn(comp_node, out, inp_tensornds, dnn_opr);
} else {
dnn_opr->exec(inp_tensornds, out.as_megdnn());
}
return {Tensor::make(out)};
} }
MGB_DEFINE_OPR_CLASS( MGB_DEFINE_OPR_CLASS(
......
...@@ -212,6 +212,11 @@ DeviceTensorND Tensor::dev_tensor(bool contiguous) { ...@@ -212,6 +212,11 @@ DeviceTensorND Tensor::dev_tensor(bool contiguous) {
return ret; return ret;
} }
megdnn::TensorND Tensor::dnn_tensor() {
mgb_assert(m_blob, "uninitialized tensor.");
return {m_layout, {m_blob->storage().get(), m_offset}};
}
void Tensor::fetch_value() { void Tensor::fetch_value() {
MGB_LOCK_GUARD(m_blob_mtx); MGB_LOCK_GUARD(m_blob_mtx);
MGB_LOCK_GUARD(m_value_mtx); MGB_LOCK_GUARD(m_value_mtx);
......
...@@ -110,6 +110,8 @@ public: ...@@ -110,6 +110,8 @@ public:
void assign_from_dev_tensor(DeviceTensorND); void assign_from_dev_tensor(DeviceTensorND);
megdnn::TensorND dnn_tensor();
static TensorPtr make_scalar(DTypeScalar value, CompNode cn); static TensorPtr make_scalar(DTypeScalar value, CompNode cn);
TensorPtr make_scalar(DTypeScalar value) const { TensorPtr make_scalar(DTypeScalar value) const {
......
...@@ -268,6 +268,12 @@ void Elemwise::perform( ...@@ -268,6 +268,12 @@ void Elemwise::perform(
call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr); call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr);
} }
void Elemwise::perform_dnn(
CompNode cn, DeviceTensorND& dest, megdnn::TensorNDArray& inputs,
intl::UniqPtrWithCN<megdnn::Elemwise>& opr) {
call_megdnn_opr_exec(cn, inputs, dest.as_megdnn(), opr.get(), nullptr);
}
TensorLayoutArray Elemwise::collective_collapse(const TensorLayoutArray& layouts) { TensorLayoutArray Elemwise::collective_collapse(const TensorLayoutArray& layouts) {
TensorLayoutPtrArray inp(layouts.size()); TensorLayoutPtrArray inp(layouts.size());
TensorLayoutArray result(inp.size()); TensorLayoutArray result(inp.size());
......
...@@ -88,6 +88,10 @@ public: ...@@ -88,6 +88,10 @@ public:
Mode mode, DeviceTensorND& dest, const SmallVector<DeviceTensorND>& inputs, Mode mode, DeviceTensorND& dest, const SmallVector<DeviceTensorND>& inputs,
intl::UniqPtrWithCN<megdnn::Elemwise>& opr); intl::UniqPtrWithCN<megdnn::Elemwise>& opr);
MGE_WIN_DECLSPEC_FUC static void perform_dnn(
CompNode cn, DeviceTensorND& dest, megdnn::TensorNDArray& inputs,
intl::UniqPtrWithCN<megdnn::Elemwise>& opr);
using TensorLayoutPtrArray = SmallVector<TensorLayout*>; using TensorLayoutPtrArray = SmallVector<TensorLayout*>;
/*! /*!
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册