Commit 409c9881 authored by Megvii Engine Team

fix(imperative): add matmul apply_on_varnode

GitOrigin-RevId: 2cf6bf237cb573f0c78fcb5cacc0257f99ebcecb
Parent d52ba79d
......@@ -36,7 +36,7 @@ public:
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType A, DType B, DType& C);
MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
......@@ -73,7 +73,7 @@ public:
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType A, DType B, DType& C);
MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
......
......@@ -44,216 +44,6 @@ def _elwise(*args, mode):
return _elwise_apply(args, mode)
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
@subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
def extentedMatrixMulOp(inputs, f, c):
assert len(inputs) == 2
inp1, inp2 = inputs
_dim1, _dim2 = dim1, dim2
def build_shape_head(shape, idx=-1):
# shape[:idx]
return f(
builtin.Subtensor(items=[[0, False, True, False, False]]),
shape,
c(idx, "int32"),
)
def build_shape_tail(shape, idx=-1):
# shape[idx:]
return f(
builtin.Subtensor(items=[[0, True, False, False, False]]),
shape,
c(idx, "int32"),
)
remove_row, remove_col = False, False
if _dim1 == 1:
_dim1 = 2
remove_row = True
if _dim2 == 1:
_dim2 = 2
remove_col = True
if remove_row:
inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
if remove_col:
inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
shape1 = f(builtin.GetVarShape(), inp1)
shape2 = f(builtin.GetVarShape(), inp2)
if _dim1 > 2:
inp1 = f(
builtin.Reshape(),
inp1,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
build_shape_tail(shape1),
),
)
if _dim2 > 2:
inp2 = f(
builtin.Reshape(),
inp2,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
build_shape_tail(shape2),
),
)
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=strategy.value,
)
result = f(op, inp1, inp2)
result_shape = f(builtin.GetVarShape(), result)
if _dim1 > 2:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape1),
build_shape_tail(result_shape),
),
)
if _dim2 > 2:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape2),
build_shape_tail(result_shape),
),
)
maxdim = _dim1 if _dim1 > _dim2 else _dim2
if remove_row:
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
if remove_col:
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
return (result,), (True,)
return extentedMatrixMulOp
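(The subgraph removed above — and the new C++ apply_on_var_node later in this commit — lowers an n-d matmul onto the 2-d MatrixMul kernel: 1-d operands are promoted to 2-d, leading dims are flattened into the row dimension, and the original shape is restored afterwards. A plain-numpy sketch of that shape logic; matmul_nd and the sample tensors are illustrative only, and the symmetric b-side flattening is omitted for brevity:)

    import numpy as np

    def matmul_nd(a, b):
        # promote 1-d operands to 2-d (AddAxis in the subgraph)
        remove_row, remove_col = a.ndim == 1, b.ndim == 1
        if remove_row:
            a = a[None, :]
        if remove_col:
            b = b[:, None]
        head = a.shape[:-1]                      # shape1[:-1]
        if a.ndim > 2:                           # flatten leading dims
            a = a.reshape(int(np.prod(head)), a.shape[-1])
        out = a @ b                              # plain 2-d MatrixMul
        if len(head) > 1:                        # restore leading dims
            out = out.reshape(*head, out.shape[-1])
        if remove_row:                           # RemoveAxis(maxdim - 2)
            out = out.squeeze(-2)
        if remove_col:                           # RemoveAxis(maxdim - 1)
            out = out.squeeze(-1)
        return out

    x, y = np.random.rand(4, 5, 6), np.random.rand(6, 7)
    assert np.allclose(matmul_nd(x, y), x @ y)   # result shape (4, 5, 7)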
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
@subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
def extentedBatchedMatrixMulOp(inputs, f, c):
assert len(inputs) == 2
inp1, inp2 = inputs
_dim1, _dim2 = dim1, dim2
def build_shape_head(shape, idx=-2):
# shape[:idx]
return f(
builtin.Subtensor(items=[[0, False, True, False, False]]),
shape,
c(idx, "int32"),
)
def build_shape_tail(shape, idx=-2):
# shape[idx:]
return f(
builtin.Subtensor(items=[[0, True, False, False, False]]),
shape,
c(idx, "int32"),
)
remove_row, remove_col = False, False
if _dim1 == 1:
_dim1 = 2
remove_row = True
if _dim2 == 1:
_dim2 = 2
remove_col = True
if remove_row:
inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
if remove_col:
inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
shape1 = f(builtin.GetVarShape(), inp1)
shape2 = f(builtin.GetVarShape(), inp2)
maxdim = _dim1 if _dim1 > _dim2 else _dim2
if _dim1 > _dim2:
# broadcast
shape2 = f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape1, idx=-_dim2), # shape1[:-_dim2]
shape2,
)
inp2 = f(builtin.Broadcast(), inp2, shape2)
batch_shape = build_shape_head(shape1)
if _dim2 > _dim1:
# broadcast
shape1 = f(
builtin.Concat(axis=0, comp_node=device),
build_shape_head(shape2, idx=-_dim1), # shape2[:-_dim1]
shape1,
)
inp1 = f(builtin.Broadcast(), inp1, shape1)
batch_shape = build_shape_head(shape2)
if _dim1 == _dim2:
batch_shape = build_shape_head(shape1)
# compress inputs to 3d
if maxdim > 3:
inp1 = f(
builtin.Reshape(),
inp1,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), batch_shape),
build_shape_tail(shape1),
),
)
inp2 = f(
builtin.Reshape(),
inp2,
f(
builtin.Concat(axis=0, comp_node=device),
f(builtin.Reduce(mode="product", axis=0), batch_shape),
build_shape_tail(shape2),
),
)
op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=strategy.value,
)
result = f(op, inp1, inp2)
if maxdim > 3:
result = f(
builtin.Reshape(),
result,
f(
builtin.Concat(axis=0, comp_node=device),
batch_shape,
build_shape_tail(f(builtin.GetVarShape(), result)),
),
)
if remove_row:
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
if remove_col:
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
return (result,), (True,)
return extentedBatchedMatrixMulOp
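(The batched variant instead broadcasts the shorter operand's batch dims against the longer one's, flattens all batch dims into a single one so the 3-d BatchedMatrixMul kernel applies, and unflattens the result. The same logic in plain numpy; batched_matmul_nd is an illustrative name, and the 1-d promotion handled above is omitted:)

    import numpy as np

    def batched_matmul_nd(a, b):
        # broadcast the shorter operand's batch dims (Broadcast in the subgraph)
        if a.ndim > b.ndim:
            b = np.broadcast_to(b, a.shape[:a.ndim - b.ndim] + b.shape)
        elif b.ndim > a.ndim:
            a = np.broadcast_to(a, b.shape[:b.ndim - a.ndim] + a.shape)
        batch = a.shape[:-2]                     # shared batch shape
        n = int(np.prod(batch))
        a3 = a.reshape(n, *a.shape[-2:])         # compress inputs to 3-d
        b3 = b.reshape(n, *b.shape[-2:])
        out = np.stack([a3[i] @ b3[i] for i in range(n)])  # BatchedMatrixMul
        return out.reshape(*batch, *out.shape[-2:])        # restore batch shape

    x, y = np.random.rand(2, 3, 4, 5), np.random.rand(5, 6)
    assert np.allclose(batched_matmul_nd(x, y), x @ y)     # (2, 3, 4, 6)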
class _Hashable:
def __init__(self, value) -> None:
self.value = value
......@@ -267,42 +57,6 @@ class _Hashable:
return self.value == o.value
def symbolicMatrixMul(
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
):
extentedMatrixMulOp = _get_extentedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(strategy),
)
(result,) = apply(extentedMatrixMulOp(), inp1, inp2)
return result
def symbolicBatchedMatrixMul(
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
):
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
inp1.device,
inp1.dtype,
dim1,
dim2,
transpose_a,
transpose_b,
compute_mode,
format,
strategy=_Hashable(strategy),
)
(result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
return result
def _matmul(
inp1,
inp2,
......@@ -342,11 +96,8 @@ def _matmul(
transpose_a,
transpose_b,
compute_mode,
format,
_config._benchmark_kernel,
_config._deterministic_kernel,
strategy,
symbolicMatrixMul,
)
else: # dispatch to BatchedMatrixMul
# nx1(transpose_a=True), n>=3
......@@ -362,11 +113,8 @@ def _matmul(
transpose_a,
transpose_b,
compute_mode,
format,
_config._benchmark_kernel,
_config._deterministic_kernel,
strategy,
symbolicBatchedMatrixMul,
)
......
......@@ -32,7 +32,7 @@ from ..core.ops.builtin import (
TypeCvt,
)
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import (
astensor1d,
cast_tensors,
......@@ -49,7 +49,7 @@ from ..utils.deprecation import deprecated_func
from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import _elwise, exp, log, log1p, maximum, minimum
from .math import matmul, max, sum
from .math import max, sum
from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
__all__ = [
......@@ -127,7 +127,7 @@ def linear(
bias: bias with shape `(out_features,)`. Default: None
"""
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
ret = _matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
if bias is not None:
if amp._enabled:
bias = bias.astype("float16")
......
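(The public entry point is unchanged by this hunk; only the internal call switches from matmul to _matmul. A minimal usage sketch, assuming a MegEngine build with this patch:)

    import numpy as np
    import megengine
    import megengine.functional as F

    x = megengine.tensor(np.random.rand(8, 16).astype("float32"))
    w = megengine.tensor(np.random.rand(32, 16).astype("float32"))
    b = megengine.tensor(np.zeros(32, dtype="float32"))
    y = F.nn.linear(x, w, b)   # internally _matmul(x, w, transpose_b=True) + b
    print(y.shape)             # (8, 32)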
......@@ -1494,73 +1494,61 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py::object _matmul_cpp(
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
py::handle format, py::handle profile, py::handle determistic,
py::handle strategy, py::handle func) {
if (enable_fastpath(inp1)) {
::megdnn::param::MatrixMul::ComputeMode mode =
::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
}
::megdnn::param::ExecutionPolicy::Strategy cstrategy;
if (profile.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
} else {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
}
if (determistic.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
}
std::shared_ptr<OpDef> op = MatrixMul::make(
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
py::object Op = py::cast(op);
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
return ret[0];
py::handle profile, py::handle determistic) {
::megdnn::param::MatrixMul::ComputeMode mode =
::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
}
::megdnn::param::ExecutionPolicy::Strategy cstrategy =
static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0);
if (profile.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
} else {
// fallback to traceable implementation
return func(
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format,
strategy);
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
}
if (determistic.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
}
std::shared_ptr<OpDef> op = MatrixMul::make(
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX,
dim1.cast<uint32_t>(), dim2.cast<uint32_t>());
py::object Op = py::cast(op);
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
return ret[0];
}
py::object _batched_matmul_cpp(
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
py::handle format, py::handle profile, py::handle determistic,
py::handle strategy, py::handle func) {
if (enable_fastpath(inp1)) {
::megdnn::param::MatrixMul::ComputeMode mode =
::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
}
::megdnn::param::ExecutionPolicy::Strategy cstrategy;
if (profile.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
} else {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
}
if (determistic.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
}
std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
py::object Op = py::cast(op);
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
return ret[0];
py::handle profile, py::handle determistic) {
::megdnn::param::MatrixMul::ComputeMode mode =
::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
}
::megdnn::param::ExecutionPolicy::Strategy cstrategy =
static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0);
if (profile.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
} else {
// fallback to traceable implementation
return func(
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format,
strategy);
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
}
if (determistic.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
}
std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX,
dim1.cast<uint32_t>(), dim2.cast<uint32_t>());
py::object Op = py::cast(op);
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
return ret[0];
}
py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) {
......@@ -1671,7 +1659,7 @@ PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _matmul_cpp(
args[0], args[1], args[2], args[3], args[4], args[5], args[6],
args[7], args[8], args[9], args[10], args[11])
args[7], args[8])
.release()
.ptr();
}
......@@ -1682,7 +1670,7 @@ PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs
try {
return _batched_matmul_cpp(
args[0], args[1], args[2], args[3], args[4], args[5], args[6],
args[7], args[8], args[9], args[10], args[11])
args[7], args[8])
.release()
.ptr();
}
......
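(Both bindings now build the execution strategy the same way instead of falling back to the Python subgraphs. A toy restatement of that selection; the numeric flag values are illustrative stand-ins for megdnn::param::ExecutionPolicy::Strategy:)

    # toy bit flags standing in for ExecutionPolicy::Strategy
    HEURISTIC, PROFILE, REPRODUCIBLE = 1, 2, 4

    def make_strategy(profile: bool, deterministic: bool) -> int:
        s = PROFILE if profile else HEURISTIC    # pick the algo-search mode
        if deterministic:
            s |= REPRODUCIBLE                    # restrict to reproducible algos
        return s

    assert make_strategy(profile=False, deterministic=True) == HEURISTIC | REPRODUCIBLE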
......@@ -20,7 +20,6 @@ import megengine.optimizer as optim
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
from megengine.traced_module import trace_module
@contextlib.contextmanager
......
......@@ -2,8 +2,12 @@
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/graph/symbol_var.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "../algo_chooser.h"
......@@ -12,12 +16,93 @@ namespace imperative {
namespace {
namespace matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& matmul = def.cast_final_safe<MatrixMul>();
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
return opr::MatrixMul::make(
inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]};
auto dim1 = matmul.dimA, dim2 = matmul.dimB;
auto cn = inputs[0]->comp_node();
using Desc = opr::AxisAddRemove::AxisDesc;
using IndexDesc = opr::Subtensor::IndexDesc;
OperatorNodeConfig config{matmul.make_name(), cn};
DTypeScalar vi{-1};
auto graph = inputs[0]->owner_graph();
bool remove_row = false, remove_col = false;
if (dim1 == 1) {
dim1 = 2;
remove_row = true;
inp1 = inp1.add_axis(0);
}
if (dim2 == 1) {
dim2 = 2;
remove_col = true;
inp2 = inp2.add_axis(1);
}
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
if (dim1 > 2) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto shp1 = inp1.symshape();
IndexDesc head_desc(1);
head_desc[0].end = idx;
shp1_head = opr::Subtensor::make(shp1, head_desc);
auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0});
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
shp1_tail = opr::Subtensor::make(shp1, tail_desc);
auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn);
inp1 = inp1.reshape(tshp);
}
if (dim2 > 2) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto shp2 = inp2.symshape();
IndexDesc head_desc(1);
head_desc[0].end = idx;
shp2_head = opr::Subtensor::make(shp2, head_desc);
auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0});
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp2_tail = opr::Subtensor::make(shp2, tail_desc);
auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn);
inp2 = inp2.reshape(tshp);
}
auto result =
opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config);
if (dim1 > 2) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto result_shape = result.symshape();
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn);
result = result.reshape(tshp);
}
if (dim2 > 2) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto result_shape = result.symshape();
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn);
result = result.reshape(tshp);
}
auto maxdim = dim1 > dim2 ? dim1 : dim2;
if (remove_row) {
std::vector<Desc> remove_param;
remove_param.push_back(Desc::make_remove(maxdim - 2));
result = opr::AxisAddRemove::make(result, remove_param);
}
if (remove_col) {
std::vector<Desc> remove_param;
remove_param.push_back(Desc::make_remove(maxdim - 1));
result = opr::AxisAddRemove::make(result, remove_param);
}
return result;
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
......@@ -27,8 +112,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
auto layout2 = inputs[1].layout;
size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
DType dst_dtype;
DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node);
dnn_opr.op->param() = matmul.param();
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
if (dim1 == 0 || dim2 == 0) {
return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false};
return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
}
if (matmul.transposeA)
......@@ -37,7 +128,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
std::swap(layout2[0], layout2[1]);
mgb_assert(layout1[dim1 - 1] == layout2[0]);
TensorLayout dst_layout(layout1.dtype);
TensorLayout dst_layout(dst_dtype);
size_t ci = 0;
for (size_t i = 0; i < dim1 - 1; i++)
dst_layout[ci++] = layout1[i];
......@@ -61,6 +153,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
SmallVector<TensorND> inp_tensornds(inputs.size());
TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout();
DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn);
dnn_opr.op->param() = matmul.param();
DType dst_dtype;
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
// only matters when layout1 has dim 2
if (matmul.transposeA)
std::swap(layout1.shape[0], layout1.shape[1]);
......@@ -69,7 +167,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
std::swap(layout2.shape[0], layout2.shape[1]);
size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
TensorLayout real_dst_layout(layout1.dtype);
TensorLayout real_dst_layout(dst_dtype);
if (validated) {
real_dst_layout = output_descs[0].layout;
} else {
......@@ -126,12 +224,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
inp_tensornds[1] = inputs[1]->dnn_tensor();
}
TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype);
TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype);
dst_layout.init_contiguous_stride();
DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn);
dnn_opr.op->param() = matmul.param();
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
size_t sz = setup_algo<megdnn::MatrixMul>(
......@@ -167,9 +262,99 @@ namespace batched_matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& matmul = def.cast_final_safe<BatchedMatrixMul>();
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
return opr::BatchedMatrixMul::make(
inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]};
auto dim1 = matmul.dimA, dim2 = matmul.dimB;
auto cn = inputs[0]->comp_node();
using Desc = opr::AxisAddRemove::AxisDesc;
using IndexDesc = opr::Subtensor::IndexDesc;
OperatorNodeConfig config{matmul.make_name(), cn};
DTypeScalar vi{-2};
auto graph = inputs[0]->owner_graph();
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
bool remove_row = false, remove_col = false;
if (dim1 == 1) {
dim1 = 2;
remove_row = true;
inp1 = inp1.add_axis(0);
}
if (dim2 == 1) {
dim2 = 2;
remove_col = true;
inp2 = inp2.add_axis(1);
}
auto shp1 = inp1.symshape();
auto shp2 = inp2.symshape();
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
SymbolVar batch_shape;
if (dim1 > dim2) {
HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = -dim2;
IndexDesc head_desc(1);
head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config);
shp1_head = opr::Subtensor::make(shp1, head_desc);
shp2 = opr::Concat::make({shp1_head, shp2}, 0, cn);
inp2 = inp2.broadcast(shp2);
head_desc[0].end = idx;
batch_shape = opr::Subtensor::make(shp1, head_desc);
}
if (dim2 > dim1) {
HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = -dim1;
IndexDesc head_desc(1);
head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config);
shp2_head = opr::Subtensor::make(shp2, head_desc);
shp1 = opr::Concat::make({shp2_head, shp1}, 0, cn);
inp1 = inp1.broadcast(shp1);
head_desc[0].end = idx;
batch_shape = opr::Subtensor::make(shp2, head_desc);
}
if (dim1 == dim2) {
IndexDesc head_desc(1);
head_desc[0].end = idx;
batch_shape = opr::Subtensor::make(shp1, head_desc);
}
auto maxdim = dim1 > dim2 ? dim1 : dim2;
if (maxdim > 3) {
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
shp1_tail = opr::Subtensor::make(shp1, tail_desc);
auto batch = opr::Reduce::make(batch_shape, {Reduce::Mode::PRODUCT, 0});
shp1 = opr::Concat::make({batch, shp1_tail}, 0, cn);
inp1 = inp1.reshape(shp1);
shp2_tail = opr::Subtensor::make(shp2, tail_desc);
shp2 = opr::Concat::make({batch, shp2_tail}, 0, cn);
inp2 = inp2.reshape(shp2);
}
auto result = opr::BatchedMatrixMul::make(
inp1, inp2, matmul.param(), matmul.policy(), config);
if (maxdim > 3) {
auto result_shp = result.symshape();
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp_tail = opr::Subtensor::make(result_shp, tail_desc);
result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn);
result = result.reshape(result_shp);
}
if (remove_row) {
std::vector<Desc> remove_param;
remove_param.push_back(Desc::make_remove(maxdim - 2));
result = opr::AxisAddRemove::make(result, remove_param);
}
if (remove_col) {
std::vector<Desc> remove_param;
remove_param.push_back(Desc::make_remove(maxdim - 1));
result = opr::AxisAddRemove::make(result, remove_param);
}
return result;
}
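(With the two apply_on_var_node lowerings above in place, graph-mode matmul no longer needs the removed Python subgraphs; every rank combination reaches the kernels through the var-node path. A quick check through the public API, assuming a build with this patch:)

    import numpy as np
    import megengine
    import megengine.functional as F

    x = megengine.tensor(np.random.rand(2, 3, 4, 5).astype("float32"))
    y = megengine.tensor(np.random.rand(5, 6).astype("float32"))
    out = F.matmul(x, y)       # batch dims broadcast, as in the numpy sketches
    print(out.shape)           # (2, 3, 4, 6)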
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
......@@ -178,8 +363,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout;
size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
DType dst_dtype;
DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node);
dnn_opr.op->param() = matmul.param();
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
if (dim1 == 0 || dim2 == 0) {
return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false};
return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
}
if (matmul.transposeA)
......@@ -187,7 +378,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
if (matmul.transposeB)
std::swap(layout2[dim2 - 1], layout2[dim2 - 2]);
TensorLayout dst_layout(layout1.dtype);
TensorLayout dst_layout(dst_dtype);
size_t di = 0;
if (dim1 > dim2) {
for (size_t i = 0; i < dim1 - 2; i++)
......@@ -217,6 +408,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout();
size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn);
dnn_opr.op->param() = matmul.param();
DType dst_dtype;
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
bool remove_row = false, remove_col = false;
if (dim1 == 1) {
dim1 = 2;
......@@ -234,6 +430,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
TensorShape tshp, batch_shp;
size_t j = 0;
auto inp1 = inputs[0], inp2 = inputs[1];
if (dim1 > dim2) {
for (size_t i = 0; i < dim1 - 2; i++)
tshp[j++] = layout1.shape[i];
......@@ -266,7 +463,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
shp2.ndim += 2;
size_t maxdim = dim1 > dim2 ? dim1 : dim2;
size_t nbatch = batch_shp[0];
auto inp1 = inputs[0], inp2 = inputs[1];
if (maxdim > 3) {
nbatch = std::accumulate(
batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1,
......@@ -274,29 +470,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
TensorLayout layout_a;
// batched_matmul does not support memory forwarding, so ensure contiguous
// manually
TensorShape nl1 = TensorShape(
{nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]});
if (!layout1.try_reshape(layout_a, nl1)) {
inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1);
inp1->to_contiguous_inplace();
layout1 = inp1->layout();
}
inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1);
inp1->to_contiguous_inplace();
layout1 = inp1->layout();
layout_a = layout1.reshape(nl1);
layout1 = layout_a;
TensorShape nl2 = TensorShape(
{nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]});
if (!layout2.try_reshape(layout_a, nl2)) {
inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2);
inp2->to_contiguous_inplace();
layout2 = inp2->layout();
}
inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2);
inp2->to_contiguous_inplace();
layout2 = inp2->layout();
layout_a = layout2.reshape(nl2);
layout2 = layout_a;
}
TensorLayout dst_layout(
{nbatch, matmul.transposeA ? layout1[2] : layout1[1],
matmul.transposeB ? layout2[1] : layout2[2]},
layout1.dtype);
dst_dtype);
dst_layout.init_contiguous_stride();
if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) {
......@@ -317,9 +513,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn);
dnn_opr.op->param() = matmul.param();
size_t sz = setup_algo<megdnn::BatchedMatrixMul>(
{layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn,
matmul.policy(), false);
......
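(The recurring deduce_dtype edits in this file exist because the output dtype can legitimately differ from the inputs' — quantized int8 matmuls, for instance, accumulate into a 32-bit dtype — so hard-coding layout1.dtype was wrong. A plain-numpy analogy of that widening:)

    import numpy as np

    a = np.random.randint(-128, 128, (4, 5), dtype=np.int8)
    b = np.random.randint(-128, 128, (5, 6), dtype=np.int8)
    c = a.astype(np.int32) @ b.astype(np.int32)  # accumulate in a wider dtype
    print(c.dtype)                               # int32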
......@@ -246,7 +246,12 @@ private:
it.name, enumMember.substr(0, d));
body += " break;\n";
}
body += " default: break;\n";
body += " default:\n";
body +=
formatv(" props_.emplace_back(\"{0}\", "
"\"INVALID\");\n",
it.name);
body += " break;\n";
body += " }\n";
} else {
auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
......
......@@ -89,19 +89,35 @@ void OpDefEmitter::emit_header() {
gen_ctor("", "", " = default;");
if (!op.getMgbAttributes().empty()) {
std::string strategy_val = "";
std::vector<std::string> paramList, initList;
for (auto&& i : op.getMgbAttributes()) {
if (attr_to_ctype(i.attr).compare("Strategy") == 0) {
strategy_val = i.name;
}
paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
initList.push_back(formatv("{0}({0}_)", i.name));
}
paramList.push_back("std::string scope_ = {}");
gen_ctor(
llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
" { set_scope(scope_); }");
if (!strategy_val.empty()) {
gen_ctor(
llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
formatv(" {"
"\n set_scope(scope_);"
"\n mgb_assert(static_cast<uint32_t>({0}) <= "
"uint32_t(8));"
"\n }",
strategy_val));
} else {
gen_ctor(
llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
" { set_scope(scope_); }");
}
}
auto packedParams = op.getPackedParams();
if (!packedParams.empty()) {
std::string strategy_val = "";
std::vector<std::string> paramList, initList;
for (auto&& p : packedParams) {
auto&& paramFields = p.getFields();
......@@ -111,6 +127,9 @@ void OpDefEmitter::emit_header() {
paramFields.empty() ? paramType.str()
: formatv("{0} {1}", paramType, paramName));
for (auto&& i : paramFields) {
if (i.name.compare("strategy") == 0) {
strategy_val = i.name;
}
initList.push_back(formatv("{0}({1}.{0})", i.name, paramName));
}
}
......@@ -118,9 +137,20 @@ void OpDefEmitter::emit_header() {
paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
initList.push_back(formatv("{0}({0}_)", i.name));
}
gen_ctor(
llvm::join(paramList, ", "),
initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}");
if (!strategy_val.empty()) {
gen_ctor(
llvm::join(paramList, ", "),
initList.empty() ? "" : ": " + llvm::join(initList, ", "),
formatv(" {"
"\n mgb_assert(static_cast<uint32_t>({0}) <= "
"uint32_t(8));"
"\n }",
strategy_val));
} else {
gen_ctor(
llvm::join(paramList, ", "),
initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}");
}
}
if (!packedParams.empty()) {
......
......@@ -43,9 +43,19 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>;
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
let extraArguments = (ins
MgbUI32Attr:$dimA,
MgbUI32Attr:$dimB
);
}
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
let extraArguments = (ins
MgbUI32Attr:$dimA,
MgbUI32Attr:$dimB
);
}
def Dot: MgbHashableOp<"Dot", [EmptyParam]>;
......