diff --git a/dnn/include/megdnn/oprs/linalg.h b/dnn/include/megdnn/oprs/linalg.h
index 10722bb33f0f71bcb30713a4e21dc4db887fddcc..0bc51b8ef1c4ed9fe4ad25f632fe15dae5066e2a 100644
--- a/dnn/include/megdnn/oprs/linalg.h
+++ b/dnn/include/megdnn/oprs/linalg.h
@@ -36,7 +36,7 @@ public:
     virtual void exec(
             _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
             _megdnn_workspace workspace) = 0;
-    void deduce_dtype(DType A, DType B, DType& C);
+    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
     void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
     virtual size_t get_workspace_in_bytes(
             const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
@@ -73,7 +73,7 @@ public:
     virtual void exec(
             _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
             _megdnn_workspace workspace) = 0;
-    void deduce_dtype(DType A, DType B, DType& C);
+    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
     void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
     virtual size_t get_workspace_in_bytes(
             const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index 45a6230035071b942233313d33ef07ff94610830..b055f31f2619bc21ab46f3b4c49fe3c3c0c23192 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -44,216 +44,6 @@ def _elwise(*args, mode):
     return _elwise_apply(args, mode)
 
 
-@lru_cache(maxsize=None)
-def _get_extentedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-1):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-1):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-
-        shape1 = f(builtin.GetVarShape(), inp1)
-        shape2 = f(builtin.GetVarShape(), inp2)
-        if _dim1 > 2:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
-                    build_shape_tail(shape1),
-                ),
-            )
-        if _dim2 > 2:
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.MatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-        result_shape = f(builtin.GetVarShape(), result)
-        if _dim1 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape1),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        if _dim2 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape2),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedMatrixMulOp
-
-
-@lru_cache(maxsize=None)
-def _get_extentedBatchedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedBatchedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-2):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-2):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-        shape1 = f(builtin.GetVarShape(), inp1)
-        shape2 = f(builtin.GetVarShape(), inp2)
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if _dim1 > _dim2:
-            # broadcast
-            shape2 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
-                shape2,
-            )
-            inp2 = f(builtin.Broadcast(), inp2, shape2)
-            batch_shape = build_shape_head(shape1)
-        if _dim2 > _dim1:
-            # broadcast
-            shape1 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
-                shape1,
-            )
-            inp1 = f(builtin.Broadcast(), inp1, shape1)
-            batch_shape = build_shape_head(shape2)
-        if _dim1 == _dim2:
-            batch_shape = build_shape_head(shape1)
-
-        # compress inputs to 3d
-        if maxdim > 3:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape1),
-                ),
-            )
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-
-        if maxdim > 3:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    batch_shape,
-                    build_shape_tail(f(builtin.GetVarShape(), result)),
-                ),
-            )
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedBatchedMatrixMulOp
-
-
 class _Hashable:
     def __init__(self, value) -> None:
         self.value = value
@@ -267,42 +57,6 @@ class _Hashable:
         return self.value == o.value
 
 
-def symbolicMatrixMul(
-    inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
-):
-    extentedMatrixMulOp = _get_extentedMatrixMulOp(
-        inp1.device,
-        inp1.dtype,
-        dim1,
-        dim2,
-        transpose_a,
-        transpose_b,
-        compute_mode,
-        format,
-        strategy=_Hashable(strategy),
-    )
-    (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
-    return result
-
-
-def symbolicBatchedMatrixMul(
-    inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
-):
-    extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
-        inp1.device,
-        inp1.dtype,
-        dim1,
-        dim2,
-        transpose_a,
-        transpose_b,
-        compute_mode,
-        format,
-        strategy=_Hashable(strategy),
-    )
-    (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
-    return result
-
-
 def _matmul(
     inp1,
     inp2,
@@ -342,11 +96,8 @@ def _matmul(
             transpose_a,
             transpose_b,
             compute_mode,
-            format,
             _config._benchmark_kernel,
             _config._deterministic_kernel,
-            strategy,
-            symbolicMatrixMul,
         )
     else:  # dispath to BatchedMatrixMul
         # nx1(transpose_a=True), n>=3
@@ -362,11 +113,8 @@ def _matmul(
             transpose_a,
             transpose_b,
             compute_mode,
-            format,
             _config._benchmark_kernel,
             _config._deterministic_kernel,
-            strategy,
-            symbolicBatchedMatrixMul,
         )
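The deleted `extented*MatrixMulOp` subgraph builders and their `symbolic*` wrappers performed rank normalization in Python: 1-D operands are promoted to matrices, leading dims are collapsed so a plain 2-D `MatrixMul` can run, and the result shape is restored afterwards. That logic now lives in the C++ `apply_on_var_node` implementations further down in this patch. A minimal NumPy sketch of the same shape algebra, for reference only (`matmul_sketch` is a hypothetical helper, not MegEngine API, and it covers only the common case of a higher-rank left operand):

```python
import numpy as np

def matmul_sketch(a, b):
    """Rough sketch of the rank handling the deleted subgraph expressed
    with AddAxis/Reshape/Concat/RemoveAxis ops."""
    remove_row, remove_col = a.ndim == 1, b.ndim == 1
    if remove_row:                 # (k,) -> (1, k)
        a = a[None, :]
    if remove_col:                 # (k,) -> (k, 1)
        b = b[:, None]
    lead = a.shape[:-1]            # leading dims to restore later
    if a.ndim > 2:                 # collapse to 2-D for a plain MatrixMul
        a = a.reshape(-1, a.shape[-1])
    out = a @ b
    if len(lead) > 1:              # restore the collapsed leading dims
        out = out.reshape(*lead, b.shape[-1])
    if remove_row:                 # squeeze the axes added above
        out = np.squeeze(out, axis=-2)
    if remove_col:
        out = np.squeeze(out, axis=-1)
    return out

assert matmul_sketch(np.ones((2, 4, 5)), np.ones((5, 3))).shape == (2, 4, 3)
assert matmul_sketch(np.ones(5), np.ones((5, 3))).shape == (3,)
```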
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 0137b6d7d51c431eaa40980d87ce4c9297915452..75f84a0deb3b9758d6acee9d29cb50be79af1c29 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -32,7 +32,7 @@ from ..core.ops.builtin import (
     TypeCvt,
 )
 from ..core.tensor import amp, megbrain_graph
-from ..core.tensor.array_method import _elwise_apply
+from ..core.tensor.array_method import _matmul
 from ..core.tensor.utils import (
     astensor1d,
     cast_tensors,
@@ -49,7 +49,7 @@ from ..utils.deprecation import deprecated_func
 from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum
-from .math import matmul, max, sum
+from .math import max, sum
 from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
 
 __all__ = [
@@ -127,7 +127,7 @@ def linear(
         bias: bias with shape `(out_features,)`. Default: None
     """
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
+    ret = _matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
     if bias is not None:
         if amp._enabled:
             bias = bias.astype("float16")
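`linear` now calls the core `_matmul` helper directly instead of the functional `matmul` wrapper. Semantically it still computes `inp @ weight.T + bias`; a NumPy sketch of the intended math (illustrative only, not the actual MegEngine code path):

```python
import numpy as np

def linear_sketch(x, weight, bias=None):
    # weight has shape (out_features, in_features), hence transpose_b=True
    out = x @ weight.T
    if bias is not None:
        out = out + bias
    return out

x = np.random.rand(8, 16)   # (batch, in_features)
w = np.random.rand(4, 16)   # (out_features, in_features)
b = np.zeros(4)
assert linear_sketch(x, w, b).shape == (8, 4)
```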
diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index 7171566dcc8b33ab828253a7fe5a08bb93b6fae6..9f5bcc6204bb100c0a068d595e83d68f9e6a9d55 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -1494,73 +1494,61 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
 py::object _matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle format, py::handle profile, py::handle determistic,
-        py::handle strategy, py::handle func) {
-    if (enable_fastpath(inp1)) {
-        ::megdnn::param::MatrixMul::ComputeMode mode =
-                ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
-        if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
-            mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
-        }
-        ::megdnn::param::ExecutionPolicy::Strategy cstrategy;
-        if (profile.cast<bool>()) {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
-        } else {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
-        }
-        if (determistic.cast<bool>()) {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
-        }
-        std::shared_ptr<OpDef> op = MatrixMul::make(
-                transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
-                ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
-
-        py::object Op = py::cast(op);
-        PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
-        py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 3));
-        return ret[0];
+        py::handle profile, py::handle determistic) {
+    ::megdnn::param::MatrixMul::ComputeMode mode =
+            ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
+    if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
+        mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
+    }
+    ::megdnn::param::ExecutionPolicy::Strategy cstrategy =
+            static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0);
+    if (profile.cast<bool>()) {
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
     } else {
-        // fallback to traceable implementation
-        return func(
-                inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format,
-                strategy);
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
+    if (determistic.cast<bool>()) {
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
+    }
+    std::shared_ptr<OpDef> op = MatrixMul::make(
+            transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
+            ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX,
+            dim1.cast<uint32_t>(), dim2.cast<uint32_t>());
+
+    py::object Op = py::cast(op);
+    PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
+    py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 3));
+    return ret[0];
 }
 
 py::object _batched_matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle format, py::handle profile, py::handle determistic,
-        py::handle strategy, py::handle func) {
-    if (enable_fastpath(inp1)) {
-        ::megdnn::param::MatrixMul::ComputeMode mode =
-                ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
-        if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
-            mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
-        }
-        ::megdnn::param::ExecutionPolicy::Strategy cstrategy;
-        if (profile.cast<bool>()) {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
-        } else {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
-        }
-        if (determistic.cast<bool>()) {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
-        }
-        std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
-                transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
-                ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
-
-        py::object Op = py::cast(op);
-        PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
-        py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 3));
-        return ret[0];
+        py::handle profile, py::handle determistic) {
+    ::megdnn::param::MatrixMul::ComputeMode mode =
+            ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
+    if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
+        mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
+    }
+    ::megdnn::param::ExecutionPolicy::Strategy cstrategy =
+            static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0);
+    if (profile.cast<bool>()) {
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
     } else {
-        // fallback to traceable implementation
-        return func(
-                inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format,
-                strategy);
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
+    if (determistic.cast<bool>()) {
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
+    }
+    std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
+            transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
+            ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX,
+            dim1.cast<uint32_t>(), dim2.cast<uint32_t>());
+
+    py::object Op = py::cast(op);
+    PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
+    py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 3));
+    return ret[0];
 }
 
 py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) {
@@ -1671,7 +1659,7 @@ PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _matmul_cpp(
                        args[0], args[1], args[2], args[3], args[4], args[5], args[6],
-                       args[7], args[8], args[9], args[10], args[11])
+                       args[7], args[8])
                 .release()
                 .ptr();
     }
@@ -1682,7 +1670,7 @@ PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs
     try {
         return _batched_matmul_cpp(
                        args[0], args[1], args[2], args[3], args[4], args[5], args[6],
-                       args[7], args[8], args[9], args[10], args[11])
+                       args[7], args[8])
                 .release()
                 .ptr();
     }
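With the traceable Python fallback removed, `_matmul_cpp`/`_batched_matmul_cpp` always build the op directly; `cstrategy` is now explicitly zero-initialized before the strategy bits are OR-ed in, and the operand ranks `dim1`/`dim2` are forwarded into the op itself. A sketch of how those strategy bits compose, as a Python `IntFlag` whose names mirror megdnn's `ExecutionPolicy::Strategy` — the numeric values here are illustrative assumptions, not the real enum encoding:

```python
from enum import IntFlag

class Strategy(IntFlag):
    # assumed bit values; the real encoding lives in megdnn
    HEURISTIC = 1
    PROFILE = 2
    REPRODUCIBLE = 4
    OPTIMIZED = 8

def make_strategy(profile: bool, deterministic: bool) -> Strategy:
    # profiling and heuristic selection are mutually exclusive base modes
    s = Strategy.PROFILE if profile else Strategy.HEURISTIC
    if deterministic:
        s |= Strategy.REPRODUCIBLE  # further restrict to reproducible algos
    return s

assert make_strategy(True, True) == Strategy.PROFILE | Strategy.REPRODUCIBLE
```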
"megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/blas.h" +#include "megbrain/opr/io.h" +#include "megbrain/opr/tensor_manip.h" #include "../algo_chooser.h" @@ -12,12 +16,93 @@ namespace imperative { namespace { namespace matrix_mul { + auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& matmul = def.cast_final_safe(); mgb_assert(inputs.size() == 2); - OperatorNodeConfig config{matmul.make_name()}; - return opr::MatrixMul::make( - inputs[0], inputs[1], matmul.param(), matmul.policy(), config); + auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]}; + auto dim1 = matmul.dimA, dim2 = matmul.dimB; + + auto cn = inputs[0]->comp_node(); + using Desc = opr::AxisAddRemove::AxisDesc; + using IndexDesc = opr::Subtensor::IndexDesc; + OperatorNodeConfig config{matmul.make_name(), cn}; + + DTypeScalar vi{-1}; + auto graph = inputs[0]->owner_graph(); + + bool remove_row = false, remove_col = false; + if (dim1 == 1) { + dim1 = 2; + remove_row = true; + inp1 = inp1.add_axis(0); + } + if (dim2 == 1) { + dim2 = 2; + remove_col = true; + inp2 = inp2.add_axis(1); + } + + SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; + if (dim1 > 2) { + auto idx = opr::ImmutableTensor::make(*graph, vi, config); + auto shp1 = inp1.symshape(); + IndexDesc head_desc(1); + head_desc[0].end = idx; + shp1_head = opr::Subtensor::make(shp1, head_desc); + auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0}); + IndexDesc tail_desc(1); + tail_desc[0].begin = idx; + shp1_tail = opr::Subtensor::make(shp1, tail_desc); + auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn); + inp1 = inp1.reshape(tshp); + } + if (dim2 > 2) { + auto idx = opr::ImmutableTensor::make(*graph, vi, config); + auto shp2 = inp2.symshape(); + IndexDesc head_desc(1); + head_desc[0].end = idx; + shp2_head = opr::Subtensor::make(shp2, head_desc); + auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0}); + IndexDesc tail_desc(1); + tail_desc[0].begin = idx; + auto shp2_tail = opr::Subtensor::make(shp2, tail_desc); + auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn); + inp2 = inp2.reshape(tshp); + } + auto result = + opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config); + if (dim1 > 2) { + auto idx = opr::ImmutableTensor::make(*graph, vi, config); + auto result_shape = result.symshape(); + IndexDesc tail_desc(1); + tail_desc[0].begin = idx; + auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); + auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn); + result = result.reshape(tshp); + } + if (dim2 > 2) { + auto idx = opr::ImmutableTensor::make(*graph, vi, config); + auto result_shape = result.symshape(); + IndexDesc tail_desc(1); + tail_desc[0].begin = idx; + auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); + auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn); + result = result.reshape(tshp); + } + + auto maxdim = dim1 > dim2 ? 
dim1 : dim2; + if (remove_row) { + std::vector remove_param; + remove_param.push_back(Desc::make_remove(maxdim - 2)); + result = opr::AxisAddRemove::make(result, remove_param); + } + if (remove_col) { + std::vector remove_param; + remove_param.push_back(Desc::make_remove(maxdim - 1)); + result = opr::AxisAddRemove::make(result, remove_param); + } + return result; } std::tuple, bool> infer_output_attrs_fallible( @@ -27,8 +112,14 @@ std::tuple, bool> infer_output_attrs_fallible( auto layout2 = inputs[1].layout; size_t dim1 = layout1.ndim, dim2 = layout2.ndim; + DType dst_dtype; + + DnnOprCaller dnn_opr(inputs[0].comp_node); + dnn_opr.op->param() = matmul.param(); + dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); + if (dim1 == 0 || dim2 == 0) { - return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; + return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false}; } if (matmul.transposeA) @@ -37,7 +128,8 @@ std::tuple, bool> infer_output_attrs_fallible( std::swap(layout2[0], layout2[1]); mgb_assert(layout1[dim1 - 1] == layout2[0]); - TensorLayout dst_layout(layout1.dtype); + + TensorLayout dst_layout(dst_dtype); size_t ci = 0; for (size_t i = 0; i < dim1 - 1; i++) dst_layout[ci++] = layout1[i]; @@ -61,6 +153,12 @@ SmallVector apply_on_physical_tensor( SmallVector inp_tensornds(inputs.size()); TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); + DnnOprCaller dnn_opr(cn); + dnn_opr.op->param() = matmul.param(); + + DType dst_dtype; + dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); + // only matters when layout1 has dim 2 if (matmul.transposeA) std::swap(layout1.shape[0], layout1.shape[1]); @@ -69,7 +167,7 @@ SmallVector apply_on_physical_tensor( std::swap(layout2.shape[0], layout2.shape[1]); size_t dim1 = layout1.ndim, dim2 = layout2.ndim; - TensorLayout real_dst_layout(layout1.dtype); + TensorLayout real_dst_layout(dst_dtype); if (validated) { real_dst_layout = output_descs[0].layout; } else { @@ -126,12 +224,9 @@ SmallVector apply_on_physical_tensor( inp_tensornds[1] = inputs[1]->dnn_tensor(); } - TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype); + TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype); dst_layout.init_contiguous_stride(); - DnnOprCaller dnn_opr(cn); - dnn_opr.op->param() = matmul.param(); - DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); size_t sz = setup_algo( @@ -167,9 +262,99 @@ namespace batched_matrix_mul { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& matmul = def.cast_final_safe(); mgb_assert(inputs.size() == 2); - OperatorNodeConfig config{matmul.make_name()}; - return opr::BatchedMatrixMul::make( - inputs[0], inputs[1], matmul.param(), matmul.policy(), config); + auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]}; + auto dim1 = matmul.dimA, dim2 = matmul.dimB; + + auto cn = inputs[0]->comp_node(); + using Desc = opr::AxisAddRemove::AxisDesc; + using IndexDesc = opr::Subtensor::IndexDesc; + OperatorNodeConfig config{matmul.make_name(), cn}; + + DTypeScalar vi{-2}; + auto graph = inputs[0]->owner_graph(); + auto idx = opr::ImmutableTensor::make(*graph, vi, config); + + bool remove_row = false, remove_col = false; + if (dim1 == 1) { + dim1 = 2; + remove_row = true; + inp1 = inp1.add_axis(0); + } + if (dim2 == 1) { + dim2 = 2; + remove_col = true; + inp2 = inp2.add_axis(1); + } + + auto shp1 = inp1.symshape(); + auto shp2 = inp2.symshape(); + SymbolVar 
shp1_head, shp1_tail, shp2_head, shp2_tail; + SymbolVar batch_shape; + if (dim1 > dim2) { + HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32()); + auto* ptr = hv.ptr(); + ptr[0] = -dim2; + IndexDesc head_desc(1); + head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config); + shp1_head = opr::Subtensor::make(shp1, head_desc); + shp2 = opr::Concat::make({shp1_head, shp2}, 0, cn); + inp2 = inp2.broadcast(shp2); + head_desc[0].end = idx; + batch_shape = opr::Subtensor::make(shp1, head_desc); + } + if (dim2 > dim1) { + HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32()); + auto* ptr = hv.ptr(); + ptr[0] = -dim1; + IndexDesc head_desc(1); + head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config); + shp2_head = opr::Subtensor::make(shp2, head_desc); + shp1 = opr::Concat::make({shp2_head, shp1}, 0, cn); + inp1 = inp1.broadcast(shp1); + head_desc[0].end = idx; + batch_shape = opr::Subtensor::make(shp2, head_desc); + } + if (dim1 == dim2) { + IndexDesc head_desc(1); + head_desc[0].end = idx; + batch_shape = opr::Subtensor::make(shp1, head_desc); + } + + auto maxdim = dim1 > dim2 ? dim1 : dim2; + if (maxdim > 3) { + IndexDesc tail_desc(1); + tail_desc[0].begin = idx; + shp1_tail = opr::Subtensor::make(shp1, tail_desc); + auto batch = opr::Reduce::make(batch_shape, {Reduce::Mode::PRODUCT, 0}); + shp1 = opr::Concat::make({batch, shp1_tail}, 0, cn); + inp1 = inp1.reshape(shp1); + shp2_tail = opr::Subtensor::make(shp2, tail_desc); + shp2 = opr::Concat::make({batch, shp2_tail}, 0, cn); + inp2 = inp2.reshape(shp2); + } + + auto result = opr::BatchedMatrixMul::make( + inp1, inp2, matmul.param(), matmul.policy(), config); + + if (maxdim > 3) { + auto result_shp = result.symshape(); + IndexDesc tail_desc(1); + tail_desc[0].begin = idx; + auto shp_tail = opr::Subtensor::make(result_shp, tail_desc); + result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn); + result = result.reshape(result_shp); + } + if (remove_row) { + std::vector remove_param; + remove_param.push_back(Desc::make_remove(maxdim - 2)); + result = opr::AxisAddRemove::make(result, remove_param); + } + if (remove_col) { + std::vector remove_param; + remove_param.push_back(Desc::make_remove(maxdim - 1)); + result = opr::AxisAddRemove::make(result, remove_param); + } + return result; } std::tuple, bool> infer_output_attrs_fallible( @@ -178,8 +363,14 @@ std::tuple, bool> infer_output_attrs_fallible( TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout; size_t dim1 = layout1.ndim, dim2 = layout2.ndim; + DType dst_dtype; + + DnnOprCaller dnn_opr(inputs[0].comp_node); + dnn_opr.op->param() = matmul.param(); + dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); + if (dim1 == 0 || dim2 == 0) { - return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; + return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false}; } if (matmul.transposeA) @@ -187,7 +378,7 @@ std::tuple, bool> infer_output_attrs_fallible( if (matmul.transposeB) std::swap(layout2[dim2 - 1], layout2[dim2 - 2]); - TensorLayout dst_layout(layout1.dtype); + TensorLayout dst_layout(dst_dtype); size_t di = 0; if (dim1 > dim2) { for (size_t i = 0; i < dim1 - 2; i++) @@ -217,6 +408,11 @@ SmallVector apply_on_physical_tensor( TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); size_t dim1 = layout1.ndim, dim2 = layout2.ndim; + DnnOprCaller dnn_opr(cn); + dnn_opr.op->param() = matmul.param(); + DType dst_dtype; + dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); + bool remove_row = 
false, remove_col = false; if (dim1 == 1) { dim1 = 2; @@ -234,6 +430,7 @@ SmallVector apply_on_physical_tensor( TensorShape tshp, batch_shp; size_t j = 0; + auto inp1 = inputs[0], inp2 = inputs[1]; if (dim1 > dim2) { for (size_t i = 0; i < dim1 - 2; i++) tshp[j++] = layout1.shape[i]; @@ -266,7 +463,6 @@ SmallVector apply_on_physical_tensor( shp2.ndim += 2; size_t maxdim = dim1 > dim2 ? dim1 : dim2; size_t nbatch = batch_shp[0]; - auto inp1 = inputs[0], inp2 = inputs[1]; if (maxdim > 3) { nbatch = std::accumulate( batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1, @@ -274,29 +470,29 @@ SmallVector apply_on_physical_tensor( TensorLayout layout_a; + // batched_matmul does not support memory forwarding, so ensure contiguous + // manually TensorShape nl1 = TensorShape( {nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]}); - if (!layout1.try_reshape(layout_a, nl1)) { - inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1); - inp1->to_contiguous_inplace(); - layout1 = inp1->layout(); - } + inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1); + inp1->to_contiguous_inplace(); + layout1 = inp1->layout(); + layout_a = layout1.reshape(nl1); layout1 = layout_a; TensorShape nl2 = TensorShape( {nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]}); - if (!layout2.try_reshape(layout_a, nl2)) { - inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2); - inp2->to_contiguous_inplace(); - layout2 = inp2->layout(); - } + inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2); + inp2->to_contiguous_inplace(); + layout2 = inp2->layout(); + layout_a = layout2.reshape(nl2); layout2 = layout_a; } TensorLayout dst_layout( {nbatch, matmul.transposeA ? layout1[2] : layout1[1], matmul.transposeB ? 
layout2[1] : layout2[2]}, - layout1.dtype); + dst_dtype); dst_layout.init_contiguous_stride(); if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { @@ -317,9 +513,6 @@ SmallVector apply_on_physical_tensor( DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); - DnnOprCaller dnn_opr(cn); - dnn_opr.op->param() = matmul.param(); - size_t sz = setup_algo( {layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, matmul.policy(), false); diff --git a/imperative/tablegen/helper.h b/imperative/tablegen/helper.h index 1117ad24178ba43b7a442f3f42f601022aca1ca0..9fda8d36f9a4684335dbc300a9bc5f3b70fc1d47 100644 --- a/imperative/tablegen/helper.h +++ b/imperative/tablegen/helper.h @@ -246,7 +246,12 @@ private: it.name, enumMember.substr(0, d)); body += " break;\n"; } - body += " default: break;\n"; + body += " default:\n"; + body += + formatv(" props_.emplace_back(\"{0}\", " + "\"INVALID\");\n", + it.name); + body += " break;\n"; body += " }\n"; } else { auto&& attr = llvm::cast(it.attr); diff --git a/imperative/tablegen/targets/cpp_class.cpp b/imperative/tablegen/targets/cpp_class.cpp index 6f02ab55d57c6813169d6fc86e0e779b7dc41406..64c10b55ceb8d08df353a9ce06167cd438b3644c 100644 --- a/imperative/tablegen/targets/cpp_class.cpp +++ b/imperative/tablegen/targets/cpp_class.cpp @@ -89,19 +89,35 @@ void OpDefEmitter::emit_header() { gen_ctor("", "", " = default;"); if (!op.getMgbAttributes().empty()) { + std::string strategy_val = ""; std::vector paramList, initList; for (auto&& i : op.getMgbAttributes()) { + if (attr_to_ctype(i.attr).compare("Strategy") == 0) { + strategy_val = i.name; + } paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name)); initList.push_back(formatv("{0}({0}_)", i.name)); } paramList.push_back("std::string scope_ = {}"); - gen_ctor( - llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), - " { set_scope(scope_); }"); + if (!strategy_val.empty()) { + gen_ctor( + llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), + formatv(" {" + "\n set_scope(scope_);" + "\n mgb_assert(static_cast({0}) <= " + "uint32_t(8));" + "\n }", + strategy_val)); + } else { + gen_ctor( + llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), + " { set_scope(scope_); }"); + } } auto packedParams = op.getPackedParams(); if (!packedParams.empty()) { + std::string strategy_val = ""; std::vector paramList, initList; for (auto&& p : packedParams) { auto&& paramFields = p.getFields(); @@ -111,6 +127,9 @@ void OpDefEmitter::emit_header() { paramFields.empty() ? paramType.str() : formatv("{0} {1}", paramType, paramName)); for (auto&& i : paramFields) { + if (i.name.compare("strategy") == 0) { + strategy_val = i.name; + } initList.push_back(formatv("{0}({1}.{0})", i.name, paramName)); } } @@ -118,9 +137,20 @@ void OpDefEmitter::emit_header() { paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name)); initList.push_back(formatv("{0}({0}_)", i.name)); } - gen_ctor( - llvm::join(paramList, ", "), - initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}"); + if (!strategy_val.empty()) { + gen_ctor( + llvm::join(paramList, ", "), + initList.empty() ? "" : ": " + llvm::join(initList, ", "), + formatv(" {" + "\n mgb_assert(static_cast({0}) <= " + "uint32_t(8));" + "\n }", + strategy_val)); + } else { + gen_ctor( + llvm::join(paramList, ", "), + initList.empty() ? 
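The batched path above broadcasts the leading batch dims of the lower-rank operand and then collapses all batch dims to one, so the kernel only ever sees `(batch, m, k) x (batch, k, n)`. A NumPy sketch of the same semantics (illustrative only: it always collapses, whereas the C++ only does so for `maxdim > 3`, and it omits 1-D promotion):

```python
import numpy as np

def batched_matmul_sketch(a, b):
    dim1, dim2 = a.ndim, b.ndim
    if dim1 > dim2:    # prepend a's extra batch dims to b
        b = np.broadcast_to(b, a.shape[:dim1 - dim2] + b.shape)
    elif dim2 > dim1:  # or vice versa
        a = np.broadcast_to(a, b.shape[:dim2 - dim1] + a.shape)
    batch_shape = a.shape[:-2]
    a3 = a.reshape(-1, *a.shape[-2:])  # collapse batch dims to one
    b3 = b.reshape(-1, *b.shape[-2:])
    out = np.stack([x @ y for x, y in zip(a3, b3)])
    return out.reshape(*batch_shape, *out.shape[-2:])

a = np.random.rand(2, 3, 4, 5)
b = np.random.rand(5, 6)  # rank-2 operand is broadcast across the batch
assert batched_matmul_sketch(a, b).shape == (2, 3, 4, 6)
```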
"" : ": " + llvm::join(initList, ", "), " {}"); + } } if (!packedParams.empty()) { diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 13148873b9a0274d66c0808cc2b756cd3acd6924..47f1c826e48da759d2c7f998579cfd196975ed79 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -43,9 +43,19 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> { def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>; -def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; +def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> { + let extraArguments = (ins + MgbUI32Attr:$dimA, + MgbUI32Attr:$dimB + ); +} -def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; +def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> { + let extraArguments = (ins + MgbUI32Attr:$dimA, + MgbUI32Attr:$dimB + ); +} def Dot: MgbHashableOp<"Dot", [EmptyParam]>;