提交 f7e10ea8 编写于 作者: M Megvii Engine Team

perf(imperative): improve matmul/batch_matmul

GitOrigin-RevId: 4ceb2eb60148113dd789416d604f0e4f76a4ec7c
上级 1c2a323e
......@@ -20,9 +20,10 @@ from .._imperative_rt.core2 import (
Tensor,
apply,
astype_cpp,
batched_matmul_cpp,
broadcast_cpp,
dtype_promotion,
getitem_cpp,
matmul_cpp,
)
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp
......@@ -266,6 +267,42 @@ class _Hashable:
return self.value == o.value
def symbolicMatrixMul(
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
):
extentedMatrixMulOp = _get_extentedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(strategy),
)
(result,) = apply(extentedMatrixMulOp(), inp1, inp2)
return result
def symbolicBatchedMatrixMul(
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
):
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(strategy),
)
(result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
return result
def _matmul(
inp1,
inp2,
......@@ -274,16 +311,6 @@ def _matmul(
compute_mode="default",
format="default",
):
if amp._enabled:
compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2)
else:
dtype = dtype_promotion(inp1, inp2)
if inp1.dtype != dtype:
inp1 = inp1.astype(dtype)
if inp2.dtype != dtype:
inp2 = inp2.astype(dtype)
dim1, dim2 = inp1.ndim, inp2.ndim
assert dim1 > 0 and dim2 > 0
maxdim = dim1 if dim1 > dim2 else dim2
......@@ -301,34 +328,46 @@ def _matmul(
if dim1 == 1 and dim2 == 1: # dispatch to Dot
(result,) = apply(builtin.Dot(), inp1, inp2)
return result
elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul
extentedMatrixMulOp = _get_extentedMatrixMulOp(
inp1.device,
inp1.dtype,
elif maxdim <= 2 or (dim2 <= 2 and not transpose_a): # dispath to MatrixMul
# 2x1
# 1x2
# 2x2
# nx1(transpose_a=False), n>=3
# nx2(transpose_a=False), n>=3
return matmul_cpp(
inp1,
inp2,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(strategy),
_config._benchmark_kernel,
_config._deterministic_kernel,
strategy,
symbolicMatrixMul,
)
(result,) = apply(extentedMatrixMulOp(), inp1, inp2)
return result
else: # dispath to BatchedMatrixMul
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
inp1.device,
inp1.dtype,
# nx1(transpose_a=True), n>=3
# nx2(transpose_a=True), n>=3
# nxm,n>=3,m>=3
# 1xm,m>=3
# 2xm,m>=3
return batched_matmul_cpp(
inp1,
inp2,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(strategy),
_config._benchmark_kernel,
_config._deterministic_kernel,
strategy,
symbolicBatchedMatrixMul,
)
(result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
return result
def _unary_elwise(mode):
......
......@@ -10,7 +10,7 @@ import collections
import math
from typing import Iterable, Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
from ..core._imperative_rt.core2 import Const, apply
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.tensor.array_method import _matmul
......
......@@ -17,7 +17,6 @@ from ..core._imperative_rt.core2 import (
apply,
dtype_promotion,
)
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core.ops import builtin
from ..core.ops.builtin import (
......@@ -177,16 +176,6 @@ def conv1d(
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
assert inp.ndim == 3, "the input dimension of conv1d should be 3"
assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
if bias is not None:
assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
......@@ -522,12 +511,6 @@ def local_conv2d(
pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation)
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
# local conv only support "dense" mode, but weight could contain group dimension.
op = builtin.GroupLocal(
stride_h=stride_h,
......
......@@ -433,6 +433,8 @@ WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(adaptive_pool2d_cpp);
WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
WRAP_FUNC_PY35(matmul_cpp);
WRAP_FUNC_PY35(batched_matmul_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);
......@@ -588,6 +590,8 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp),
MGE_PY_INTERFACE(Const, Const),
MGE_PY_INTERFACE(astype_cpp, astype_cpp),
MGE_PY_INTERFACE(matmul_cpp, matmul_cpp),
MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp),
MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
......
......@@ -1490,6 +1490,78 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
return ret[0];
}
py::object _matmul_cpp(
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
py::handle format, py::handle profile, py::handle determistic,
py::handle strategy, py::handle func) {
if (enable_fastpath(inp1)) {
::megdnn::param::MatrixMul::ComputeMode mode =
::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
}
::megdnn::param::ExecutionPolicy::Strategy cstrategy;
if (profile.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
} else {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
}
if (determistic.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
}
std::shared_ptr<OpDef> op = MatrixMul::make(
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
py::object Op = py::cast(op);
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
return ret[0];
} else {
// fallback to traceable implementation
return func(
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format,
strategy);
}
}
py::object _batched_matmul_cpp(
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
py::handle format, py::handle profile, py::handle determistic,
py::handle strategy, py::handle func) {
if (enable_fastpath(inp1)) {
::megdnn::param::MatrixMul::ComputeMode mode =
::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
}
::megdnn::param::ExecutionPolicy::Strategy cstrategy;
if (profile.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
} else {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
}
if (determistic.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
}
std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
py::object Op = py::cast(op);
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
return ret[0];
} else {
// fallback to traceable implementation
return func(
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format,
strategy);
}
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(args[0]).release().ptr();
......@@ -1574,6 +1646,28 @@ PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _matmul_cpp(
args[0], args[1], args[2], args[3], args[4], args[5], args[6],
args[7], args[8], args[9], args[10], args[11])
.release()
.ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _batched_matmul_cpp(
args[0], args[1], args[2], args[3], args[4], args[5], args[6],
args[7], args[8], args[9], args[10], args[11])
.release()
.ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* convert_single_value_cpp(
PyObject* self, PyObject* const* args, size_t nargs) {
try {
......
......@@ -30,6 +30,10 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs);
......
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/utility.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace {
namespace dot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Dot>();
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::Dot::make(inputs[0], inputs[1], config);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto comp_node = inputs[0]->comp_node();
using TensorND = megdnn::TensorND;
SmallVector<TensorND> inp_tensornds;
inp_tensornds.reserve(inputs.size());
DnnOprCaller<megdnn::Dot> dnn_opr(comp_node);
for (unsigned i = 0; i < inputs.size(); ++i) {
auto dnn_ten = inputs[i]->dnn_tensor();
inp_tensornds.push_back(dnn_ten);
}
TensorLayout oup_layout{inputs[0]->dtype()};
auto inp1_tensor = inputs[0]->dnn_tensor();
auto inp2_tensor = inputs[1]->dnn_tensor();
dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout);
if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) {
DnnOprCaller<megdnn::Fill> fill_opr(comp_node);
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
fill_opr.op->param() = 0;
fill_opr.op->exec(out.as_megdnn(), {});
return {Tensor::make(out)};
}
auto sz = dnn_opr.op->get_workspace_in_bytes(
inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout);
DeviceTensorND out_devtensor =
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
TensorLayout w_layout({sz}, dtype::Byte());
auto dnn_wk = dnn_opr.create_workspace(w_layout);
dnn_opr.op->exec(
inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk);
return {Tensor::make(out_devtensor)};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(
inputs.size() == 2, "Dot expects 2 inputs; got %lu actually",
inputs.size());
SmallVector<LogicalTensorDesc> dests(1);
dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype);
dests[0].comp_node = inputs[0].comp_node;
bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0;
return {dests, validated};
}
OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace dot
} // anonymous namespace
} // namespace imperative
} // namespace mgb
\ No newline at end of file
#include <numeric>
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/blas.h"
#include "../algo_chooser.h"
namespace mgb {
namespace imperative {
namespace {
namespace matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& matmul = def.cast_final_safe<MatrixMul>();
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
return opr::MatrixMul::make(
inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& matmul = def.cast_final_safe<MatrixMul>();
auto layout1 = inputs[0].layout;
auto layout2 = inputs[1].layout;
size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
if (dim1 == 0 || dim2 == 0) {
return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false};
}
if (matmul.transposeA)
std::swap(layout1[0], layout1[1]);
if (matmul.transposeB)
std::swap(layout2[0], layout2[1]);
mgb_assert(layout1[dim1 - 1] == layout2[0]);
TensorLayout dst_layout(layout1.dtype);
size_t ci = 0;
for (size_t i = 0; i < dim1 - 1; i++)
dst_layout[ci++] = layout1[i];
if (dim2 == 2)
dst_layout[ci++] = layout2[1];
dst_layout.ndim = ci;
dst_layout.init_contiguous_stride();
SmallVector<LogicalTensorDesc> out_descs(1u);
out_descs[0] = {dst_layout, inputs[0].comp_node};
return {out_descs, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& matmul = def.cast_final_safe<MatrixMul>();
auto&& cn = inputs[0]->comp_node();
using TensorND = megdnn::TensorND;
SmallVector<TensorND> inp_tensornds(inputs.size());
TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout();
// only matters when layout1 has dim 2
if (matmul.transposeA)
std::swap(layout1.shape[0], layout1.shape[1]);
// only matters when layout2 has dim 2
if (matmul.transposeB)
std::swap(layout2.shape[0], layout2.shape[1]);
size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
TensorLayout real_dst_layout(layout1.dtype);
if (validated) {
real_dst_layout = output_descs[0].layout;
} else {
size_t ri = 0;
for (size_t i = 0; i < dim1 - 2; i++)
real_dst_layout[ri++] = layout1[i];
real_dst_layout[ri++] = layout1[dim1 - 2];
if (dim2 == 2)
real_dst_layout[ri++] = layout2[dim2 - 1];
real_dst_layout.ndim = ri;
real_dst_layout.init_contiguous_stride();
}
if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) {
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(cn, real_dst_layout);
if (!out.empty()) {
dev_tensor_memset(out, 0);
}
return {Tensor::make(out)};
}
TensorLayout layout_a = layout1, layout_b = layout2;
if (dim1 == 1) {
layout_a.add_axis_cont_inplace(0);
inp_tensornds[0] = inputs[0]->dnn_tensor();
inp_tensornds[0].layout = layout_a;
} else if (dim1 > 2) {
size_t batch = std::accumulate(
layout1.shape, layout1.shape + dim1 - 1, (size_t)1,
std::multiplies<size_t>());
TensorShape na = TensorShape{batch, layout1[dim1 - 1]};
auto inp1 = inputs[0];
if (!layout1.try_reshape(layout_a, na)) {
inp1 = Tensor::make(inp1->blob(), inp1->offset(), layout1);
inp1->to_contiguous_inplace();
layout1 = inp1->layout();
layout_a = TensorLayout{{batch, layout1[dim1 - 1]}, layout1.dtype};
}
layout_a.init_contiguous_stride();
inp_tensornds[0] = inp1->dnn_tensor();
inp_tensornds[0].layout = layout_a;
} else {
inp_tensornds[0] = inputs[0]->dnn_tensor();
}
if (dim2 == 1) {
layout_b.add_axis_inplace(1, 1, 1);
inp_tensornds[1] = inputs[1]->dnn_tensor();
inp_tensornds[1].layout = layout_b;
} else {
inp_tensornds[1] = inputs[1]->dnn_tensor();
}
TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype);
dst_layout.init_contiguous_stride();
DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn);
dnn_opr.op->param() = matmul.param();
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
size_t sz = setup_algo<megdnn::MatrixMul>(
{layout_a, layout_b, dst_layout}, dnn_opr.op.get(), 0, false, false, cn,
matmul.policy(), false);
TensorLayout w_layout({sz}, dtype::Byte());
auto dnn_wk = dnn_opr.create_workspace(w_layout);
dnn_opr.op->exec(inp_tensornds[0], inp_tensornds[1], out.as_megdnn(), dnn_wk);
return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(real_dst_layout)))};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
return layout_checker;
}
OP_TRAIT_REG(MatrixMul, MatrixMul)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace matrix_mul
} // namespace
namespace {
namespace batched_matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& matmul = def.cast_final_safe<BatchedMatrixMul>();
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
return opr::BatchedMatrixMul::make(
inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& matmul = def.cast_final_safe<BatchedMatrixMul>();
TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout;
size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
if (dim1 == 0 || dim2 == 0) {
return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false};
}
if (matmul.transposeA)
std::swap(layout1[dim1 - 1], layout1[dim1 - 2]);
if (matmul.transposeB)
std::swap(layout2[dim2 - 1], layout2[dim2 - 2]);
TensorLayout dst_layout(layout1.dtype);
size_t di = 0;
if (dim1 > dim2) {
for (size_t i = 0; i < dim1 - 2; i++)
dst_layout[di++] = layout1[i];
} else {
for (size_t i = 0; i < dim2 - 2; i++)
dst_layout[di++] = layout2[i];
}
if (dim1 > 1)
dst_layout[di++] = layout1[dim1 - 2];
if (dim2 > 1)
dst_layout[di++] = layout2[dim2 - 1];
dst_layout.ndim = di;
dst_layout.init_contiguous_stride();
SmallVector<LogicalTensorDesc> out_descs(1u);
out_descs[0] = {dst_layout, inputs[0].comp_node};
return {out_descs, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& matmul = def.cast_final_safe<BatchedMatrixMul>();
auto&& cn = inputs[0]->comp_node();
TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout();
size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
bool remove_row = false, remove_col = false;
if (dim1 == 1) {
dim1 = 2;
remove_row = true;
}
if (dim2 == 1) {
dim2 = 2;
remove_col = true;
}
if (remove_row)
layout1.add_axis_cont_inplace(0);
if (remove_col)
layout2.add_axis_inplace(1, 1, 1);
TensorShape tshp, batch_shp;
size_t j = 0;
if (dim1 > dim2) {
for (size_t i = 0; i < dim1 - 2; i++)
tshp[j++] = layout1.shape[i];
batch_shp = tshp;
batch_shp.ndim = dim1 - 2;
tshp[j++] = layout2[layout2.ndim - 2];
tshp[j++] = layout2[layout2.ndim - 1];
tshp.ndim = j;
layout2 = layout2.broadcast(tshp);
}
if (dim2 > dim1) {
for (size_t i = 0; i < dim2 - 2; i++)
tshp[j++] = layout2.shape[i];
batch_shp = tshp;
batch_shp.ndim = dim2 - 2;
tshp[j++] = layout1[layout1.ndim - 2];
tshp[j++] = layout1[layout1.ndim - 1];
tshp.ndim = j;
layout1 = layout1.broadcast(tshp);
}
if (dim1 == dim2) {
for (size_t i = 0; i < dim1 - 2; i++)
tshp[j++] = layout1.shape[i];
batch_shp = tshp;
batch_shp.ndim = dim1 - 2;
}
TensorShape shp1 = batch_shp, shp2 = batch_shp;
shp1.ndim += 2;
shp2.ndim += 2;
size_t maxdim = dim1 > dim2 ? dim1 : dim2;
size_t nbatch = batch_shp[0];
auto inp1 = inputs[0], inp2 = inputs[1];
if (maxdim > 3) {
nbatch = std::accumulate(
batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1,
std::multiplies<size_t>());
TensorLayout layout_a;
TensorShape nl1 = TensorShape(
{nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]});
if (!layout1.try_reshape(layout_a, nl1)) {
inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1);
inp1->to_contiguous_inplace();
layout1 = inp1->layout();
}
layout1 = layout_a;
TensorShape nl2 = TensorShape(
{nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]});
if (!layout2.try_reshape(layout_a, nl2)) {
inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2);
inp2->to_contiguous_inplace();
layout2 = inp2->layout();
}
layout2 = layout_a;
}
TensorLayout dst_layout(
{nbatch, matmul.transposeA ? layout1[2] : layout1[1],
matmul.transposeB ? layout2[1] : layout2[2]},
layout1.dtype);
dst_layout.init_contiguous_stride();
if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) {
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
if (!out.empty()) {
dev_tensor_memset(out, 0);
}
return {Tensor::make(out)};
}
using TensorND = megdnn::TensorND;
TensorND inp_nd1 = inp1->dnn_tensor();
inp_nd1.layout = layout1;
TensorND inp_nd2 = inp2->dnn_tensor();
inp_nd2.layout = layout2;
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn);
dnn_opr.op->param() = matmul.param();
size_t sz = setup_algo<megdnn::BatchedMatrixMul>(
{layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn,
matmul.policy(), false);
TensorLayout w_layout({sz}, dtype::Byte());
auto dnn_wk = dnn_opr.create_workspace(w_layout);
dnn_opr.op->exec(inp_nd1, inp_nd2, out.as_megdnn(), dnn_wk);
shp1[shp1.ndim - 2] = dst_layout[dst_layout.ndim - 2];
shp1[shp1.ndim - 1] = dst_layout[dst_layout.ndim - 1];
if (maxdim > 3) {
dst_layout = dst_layout.reshape(shp1);
}
if (remove_row) {
dst_layout = dst_layout.remove_axis(maxdim - 2);
}
if (remove_col) {
dst_layout = dst_layout.remove_axis(maxdim - 1);
}
return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
return layout_checker;
}
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.get_input_layout_constraint(get_input_layout_constraint)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace batched_matrix_mul
} // namespace
namespace {
namespace dot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Dot>();
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::Dot::make(inputs[0], inputs[1], config);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto comp_node = inputs[0]->comp_node();
using TensorND = megdnn::TensorND;
SmallVector<TensorND> inp_tensornds;
inp_tensornds.reserve(inputs.size());
DnnOprCaller<megdnn::Dot> dnn_opr(comp_node);
for (unsigned i = 0; i < inputs.size(); ++i) {
auto dnn_ten = inputs[i]->dnn_tensor();
inp_tensornds.push_back(dnn_ten);
}
TensorLayout oup_layout{inputs[0]->dtype()};
auto inp1_tensor = inputs[0]->dnn_tensor();
auto inp2_tensor = inputs[1]->dnn_tensor();
dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout);
if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) {
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
if (!out.empty()) {
dev_tensor_memset(out, 0);
}
return {Tensor::make(out)};
}
auto sz = dnn_opr.op->get_workspace_in_bytes(
inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout);
DeviceTensorND out_devtensor =
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
TensorLayout w_layout({sz}, dtype::Byte());
auto dnn_wk = dnn_opr.create_workspace(w_layout);
dnn_opr.op->exec(
inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk);
return {Tensor::make(out_devtensor)};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(
inputs.size() == 2, "Dot expects 2 inputs; got %lu actually",
inputs.size());
SmallVector<LogicalTensorDesc> dests(1);
dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype);
dests[0].comp_node = inputs[0].comp_node;
bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0;
return {dests, validated};
}
OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace dot
} // anonymous namespace
} // namespace imperative
} // namespace mgb
......@@ -123,7 +123,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src);
auto mode = op_def.param().mode;
DnnOprCaller<megdnn::Fill> fill_op(comp_node);
if (!keepdim && src.ndim > 1) {
layout.remove_axis_inplace(axis);
......@@ -135,12 +134,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
switch (mode) {
case Reduce::Mode::SUM:
if (!out.empty()) {
fill_op.op->param() = 0;
fill_op.op->exec(out.as_megdnn(), {});
dev_tensor_memset(out, 0);
}
break;
case Reduce::Mode::PRODUCT:
if (!out.empty()) {
DnnOprCaller<megdnn::Fill> fill_op(comp_node);
fill_op.op->param() = 1;
fill_op.op->exec(out.as_megdnn(), {});
}
......
......@@ -319,34 +319,6 @@ OP_TRAIT_REG(BatchConvBias, BatchConvBias)
} // namespace batch_conv_bias
} // namespace
namespace {
namespace matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& matmul = static_cast<const MatrixMul&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
return opr::MatrixMul::make(
inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}
OP_TRAIT_REG(MatrixMul, MatrixMul).apply_on_var_node(apply_on_var_node).fallback();
} // namespace matrix_mul
} // namespace
namespace {
namespace batched_matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
return opr::BatchedMatrixMul::make(
inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace batched_matrix_mul
} // namespace
namespace {
namespace argsort {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
......
......@@ -183,6 +183,57 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, converted);
}
ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>());
SmallVector<DType> dtypes = get_value_dtypes(inputs);
mgb::DType target_dtype;
if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
conv_op.compute_mode = MatrixMul::ComputeMode::FLOAT32;
target_dtype = DTypePromoteCfg::amp_low_prec_dtype;
} else {
target_dtype = get_promoted_dtype(dtypes);
}
ValueRefList converted(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
if (dtypes[i] != target_dtype) {
converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
} else {
converted[i] = inputs[i];
}
}
return imperative::apply(op, converted);
}
ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
auto&& conv_op =
const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>());
SmallVector<DType> dtypes = get_value_dtypes(inputs);
mgb::DType target_dtype;
if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
conv_op.compute_mode = BatchedMatrixMul::ComputeMode::FLOAT32;
target_dtype = DTypePromoteCfg::amp_low_prec_dtype;
} else {
target_dtype = get_promoted_dtype(dtypes);
}
ValueRefList converted(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
if (dtypes[i] != target_dtype) {
converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
} else {
converted[i] = inputs[i];
}
}
return imperative::apply(op, converted);
}
// differ from Convolution, ConvolutionBackwardData is used in both
// functional.conv_transpose2d and quantize.conv_transpose2d
ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) {
......@@ -259,8 +310,11 @@ struct DTypePromoteRuleRegistry {
DTypePromoteRuleRegistry() {
register_dtype_promote_rule<Elemwise>(elemwise_rule);
register_dtype_promote_rule<Concat>(naive_promote_rule);
register_dtype_promote_rule<GroupLocal>(naive_promote_rule);
register_dtype_promote_rule<Reduce>(reduce_rule);
register_dtype_promote_rule<Convolution>(convolution_rule);
register_dtype_promote_rule<MatrixMul>(matmul_rule);
register_dtype_promote_rule<BatchedMatrixMul>(batch_matmul_rule);
register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule);
register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册