diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 3285fd6eeb5ce1db46d31aa328661a7c2f71b582..45a6230035071b942233313d33ef07ff94610830 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -20,9 +20,10 @@ from .._imperative_rt.core2 import ( Tensor, apply, astype_cpp, + batched_matmul_cpp, broadcast_cpp, - dtype_promotion, getitem_cpp, + matmul_cpp, ) from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp @@ -266,6 +267,42 @@ class _Hashable: return self.value == o.value +def symbolicMatrixMul( + inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy +): + extentedMatrixMulOp = _get_extentedMatrixMulOp( + inp1.device, + inp1.dtype, + dim1, + dim2, + transpose_a, + transpose_b, + compute_mode, + format, + strategy=_Hashable(strategy), + ) + (result,) = apply(extentedMatrixMulOp(), inp1, inp2) + return result + + +def symbolicBatchedMatrixMul( + inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy +): + extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( + inp1.device, + inp1.dtype, + dim1, + dim2, + transpose_a, + transpose_b, + compute_mode, + format, + strategy=_Hashable(strategy), + ) + (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) + return result + + def _matmul( inp1, inp2, @@ -274,16 +311,6 @@ def _matmul( compute_mode="default", format="default", ): - if amp._enabled: - compute_mode = "float32" - inp1, inp2 = cast_tensors(inp1, inp2) - else: - dtype = dtype_promotion(inp1, inp2) - if inp1.dtype != dtype: - inp1 = inp1.astype(dtype) - if inp2.dtype != dtype: - inp2 = inp2.astype(dtype) - dim1, dim2 = inp1.ndim, inp2.ndim assert dim1 > 0 and dim2 > 0 maxdim = dim1 if dim1 > dim2 else dim2 @@ -301,34 +328,46 @@ def _matmul( if dim1 == 1 and dim2 == 1: # dispatch to Dot (result,) = apply(builtin.Dot(), inp1, inp2) return result - elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul - extentedMatrixMulOp = _get_extentedMatrixMulOp( - inp1.device, - inp1.dtype, + elif maxdim <= 2 or (dim2 <= 2 and not transpose_a): # dispath to MatrixMul + # 2x1 + # 1x2 + # 2x2 + # nx1(transpose_a=False), n>=3 + # nx2(transpose_a=False), n>=3 + return matmul_cpp( + inp1, + inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, - strategy=_Hashable(strategy), + _config._benchmark_kernel, + _config._deterministic_kernel, + strategy, + symbolicMatrixMul, ) - (result,) = apply(extentedMatrixMulOp(), inp1, inp2) - return result else: # dispath to BatchedMatrixMul - extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( - inp1.device, - inp1.dtype, + # nx1(transpose_a=True), n>=3 + # nx2(transpose_a=True), n>=3 + # nxm,n>=3,m>=3 + # 1xm,m>=3 + # 2xm,m>=3 + return batched_matmul_cpp( + inp1, + inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, - strategy=_Hashable(strategy), + _config._benchmark_kernel, + _config._deterministic_kernel, + strategy, + symbolicBatchedMatrixMul, ) - (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) - return result def _unary_elwise(mode): diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index e066f79083b54215ed6311659c336c440c1f859b..6ff37249030df36b37fb0ebb3d81463bc9d5fdda 100644 --- a/imperative/python/megengine/functional/math.py +++ 
b/imperative/python/megengine/functional/math.py @@ -10,7 +10,7 @@ import collections import math from typing import Iterable, Optional, Sequence, Tuple, Union -from ..core._imperative_rt.core2 import Const, apply, dtype_promotion +from ..core._imperative_rt.core2 import Const, apply from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core.ops import builtin from ..core.tensor.array_method import _matmul diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 0e6ea5345aa3ae062b47cf541cf6bc47eb4712c1..8594bf564f5eaa82b46244ddb6c6c4f95c6fa5c7 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -17,7 +17,6 @@ from ..core._imperative_rt.core2 import ( apply, dtype_promotion, ) -from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed from ..core.ops import builtin from ..core.ops.builtin import ( @@ -177,16 +176,6 @@ def conv1d( assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" assert inp.ndim == 3, "the input dimension of conv1d should be 3" assert weight.ndim == 3, "the weight dimension of conv1d should be 3" - if amp._enabled: - compute_mode = "float32" - inp, weight, bias = cast_tensors(inp, weight, bias) - else: - dtype = dtype_promotion(inp, weight) - if inp.dtype != dtype: - inp = inp.astype(dtype) - if weight.dtype != dtype: - weight = weight.astype(dtype) - if bias is not None: assert bias.ndim == 3, "the bias dimension of conv1d should be 3" @@ -522,12 +511,6 @@ def local_conv2d( pad_h, pad_w = expand_hw(padding) dilate_h, dilate_w = expand_hw(dilation) - dtype = dtype_promotion(inp, weight) - if inp.dtype != dtype: - inp = inp.astype(dtype) - if weight.dtype != dtype: - weight = weight.astype(dtype) - # local conv only support "dense" mode, but weight could contain group dimension. 
op = builtin.GroupLocal( stride_h=stride_h, diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index a1c6aaa7e65dff5fe0d1abaa5988d95229f66d7e..834a7af5f322d6a78a9ca9b9c5b70641fda84950 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -433,6 +433,8 @@ WRAP_FUNC_PY35(reshape_cpp); WRAP_FUNC_PY35(adaptive_pool2d_cpp); WRAP_FUNC_PY35(Const); WRAP_FUNC_PY35(astype_cpp); +WRAP_FUNC_PY35(matmul_cpp); +WRAP_FUNC_PY35(batched_matmul_cpp); WRAP_FUNC_PY35(convert_single_value_cpp); WRAP_FUNC_PY35(convert_inputs_cpp); WRAP_FUNC_PY35(astensor1d_cpp); @@ -588,6 +590,8 @@ void init_tensor(py::module m) { MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp), MGE_PY_INTERFACE(Const, Const), MGE_PY_INTERFACE(astype_cpp, astype_cpp), + MGE_PY_INTERFACE(matmul_cpp, matmul_cpp), + MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp), MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp), MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp), MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp), diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 05c0d4a9d8c3202de3e742c59c50df10d26809dc..70375d574a05ceae0c4a9a1544a136f23c623034 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -1490,6 +1490,78 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { return ret[0]; } +py::object _matmul_cpp( + py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, + py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, + py::handle format, py::handle profile, py::handle determistic, + py::handle strategy, py::handle func) { + if (enable_fastpath(inp1)) { + ::megdnn::param::MatrixMul::ComputeMode mode = + ::megdnn::param::MatrixMul::ComputeMode::DEFAULT; + if (compute_mode.cast().compare(std::string("float32")) == 0) { + mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; + } + ::megdnn::param::ExecutionPolicy::Strategy cstrategy; + if (profile.cast()) { + cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; + } else { + cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; + } + if (determistic.cast()) { + cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; + } + std::shared_ptr op = MatrixMul::make( + transpose_a.cast(), transpose_b.cast(), mode, + ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); + + py::object Op = py::cast(op); + PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; + py::tuple ret = py::reinterpret_steal(py_apply(NULL, p, 3)); + return ret[0]; + } else { + // fallback to traceable implementation + return func( + inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, + strategy); + } +} + +py::object _batched_matmul_cpp( + py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, + py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, + py::handle format, py::handle profile, py::handle determistic, + py::handle strategy, py::handle func) { + if (enable_fastpath(inp1)) { + ::megdnn::param::MatrixMul::ComputeMode mode = + ::megdnn::param::MatrixMul::ComputeMode::DEFAULT; + if (compute_mode.cast().compare(std::string("float32")) == 0) { + mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; + } + ::megdnn::param::ExecutionPolicy::Strategy cstrategy; + if (profile.cast()) { + cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; + } else { + cstrategy |= 
::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; + } + if (determistic.cast()) { + cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; + } + std::shared_ptr op = BatchedMatrixMul::make( + transpose_a.cast(), transpose_b.cast(), mode, + ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); + + py::object Op = py::cast(op); + PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; + py::tuple ret = py::reinterpret_steal(py_apply(NULL, p, 3)); + return ret[0]; + } else { + // fallback to traceable implementation + return func( + inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, + strategy); + } +} + PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { try { return _make_shape_tuple(args[0]).release().ptr(); @@ -1574,6 +1646,28 @@ PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) { PYEXT17_TRANSLATE_EXC_RET(nullptr) } +PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _matmul_cpp( + args[0], args[1], args[2], args[3], args[4], args[5], args[6], + args[7], args[8], args[9], args[10], args[11]) + .release() + .ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + +PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _batched_matmul_cpp( + args[0], args[1], args[2], args[3], args[4], args[5], args[6], + args[7], args[8], args[9], args[10], args[11]) + .release() + .ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + PyObject* convert_single_value_cpp( PyObject* self, PyObject* const* args, size_t nargs) { try { diff --git a/imperative/python/src/tensor_utils.h b/imperative/python/src/tensor_utils.h index ab832f669f5035f3737c45b1f478243ff7767b32..6ceaa6bd6d284c9b8227cb15698f6b183b909f37 100644 --- a/imperative/python/src/tensor_utils.h +++ b/imperative/python/src/tensor_utils.h @@ -30,6 +30,10 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs); PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs); +PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs); + +PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs); + PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs); PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs); diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manager.py similarity index 100% rename from imperative/python/test/unit/autodiff/test_grad_manger.py rename to imperative/python/test/unit/autodiff/test_grad_manager.py diff --git a/imperative/src/impl/ops/dot.cpp b/imperative/src/impl/ops/dot.cpp deleted file mode 100644 index 183e338ed6b0afa123e2aedbbf90a3b19bf87780..0000000000000000000000000000000000000000 --- a/imperative/src/impl/ops/dot.cpp +++ /dev/null @@ -1,87 +0,0 @@ -#include "megbrain/imperative/opr_utility.h" -#include "megbrain/imperative/ops/autogen.h" -#include "megbrain/imperative/utils/stats.h" -#include "megbrain/opr/basic_arith.h" -#include "megbrain/opr/blas.h" -#include "megbrain/opr/utility.h" - -#include "../blob_manager_impl.h" -#include "../dnn_op_helper.h" -#include "../op_trait.h" - -namespace mgb { -namespace imperative { - -namespace { -namespace dot { - -auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& op = def.cast_final_safe(); - mgb_assert(inputs.size() == 2); - OperatorNodeConfig 
config{op.make_name()}; - return opr::Dot::make(inputs[0], inputs[1], config); -} - -SmallVector apply_on_physical_tensor( - const OpDef& def, const SmallVector& inputs, - SmallVector& output_descs, const bool& validated) { - auto comp_node = inputs[0]->comp_node(); - using TensorND = megdnn::TensorND; - SmallVector inp_tensornds; - inp_tensornds.reserve(inputs.size()); - DnnOprCaller dnn_opr(comp_node); - for (unsigned i = 0; i < inputs.size(); ++i) { - auto dnn_ten = inputs[i]->dnn_tensor(); - inp_tensornds.push_back(dnn_ten); - } - TensorLayout oup_layout{inputs[0]->dtype()}; - auto inp1_tensor = inputs[0]->dnn_tensor(); - auto inp2_tensor = inputs[1]->dnn_tensor(); - dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout); - - if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) { - DnnOprCaller fill_opr(comp_node); - DeviceTensorND out = - BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); - fill_opr.op->param() = 0; - fill_opr.op->exec(out.as_megdnn(), {}); - return {Tensor::make(out)}; - } - - auto sz = dnn_opr.op->get_workspace_in_bytes( - inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout); - - DeviceTensorND out_devtensor = - BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); - - TensorLayout w_layout({sz}, dtype::Byte()); - auto dnn_wk = dnn_opr.create_workspace(w_layout); - - dnn_opr.op->exec( - inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk); - - return {Tensor::make(out_devtensor)}; -} - -std::tuple, bool> infer_output_attrs_fallible( - const OpDef& def, const SmallVector& inputs) { - mgb_assert( - inputs.size() == 2, "Dot expects 2 inputs; got %lu actually", - inputs.size()); - SmallVector dests(1); - dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype); - dests[0].comp_node = inputs[0].comp_node; - bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0; - return {dests, validated}; -} - -OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot) - .apply_on_var_node(apply_on_var_node) - .infer_output_attrs_fallible(infer_output_attrs_fallible) - .apply_on_physical_tensor(apply_on_physical_tensor) - .fallback(); - -} // namespace dot -} // anonymous namespace -} // namespace imperative -} // namespace mgb \ No newline at end of file diff --git a/imperative/src/impl/ops/matmul.cpp b/imperative/src/impl/ops/matmul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..88494c34d285384bfa0840178504ef90b9ca1c75 --- /dev/null +++ b/imperative/src/impl/ops/matmul.cpp @@ -0,0 +1,435 @@ +#include +#include "../blob_manager_impl.h" +#include "../dnn_op_helper.h" +#include "../op_trait.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/blas.h" + +#include "../algo_chooser.h" + +namespace mgb { +namespace imperative { + +namespace { +namespace matrix_mul { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& matmul = def.cast_final_safe(); + mgb_assert(inputs.size() == 2); + OperatorNodeConfig config{matmul.make_name()}; + return opr::MatrixMul::make( + inputs[0], inputs[1], matmul.param(), matmul.policy(), config); +} + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + auto&& matmul = def.cast_final_safe(); + auto layout1 = inputs[0].layout; + auto layout2 = inputs[1].layout; + size_t dim1 = layout1.ndim, dim2 = layout2.ndim; + + if (dim1 == 0 || dim2 == 0) { + return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; 
+ } + + if (matmul.transposeA) + std::swap(layout1[0], layout1[1]); + if (matmul.transposeB) + std::swap(layout2[0], layout2[1]); + + mgb_assert(layout1[dim1 - 1] == layout2[0]); + TensorLayout dst_layout(layout1.dtype); + size_t ci = 0; + for (size_t i = 0; i < dim1 - 1; i++) + dst_layout[ci++] = layout1[i]; + if (dim2 == 2) + dst_layout[ci++] = layout2[1]; + dst_layout.ndim = ci; + dst_layout.init_contiguous_stride(); + + SmallVector out_descs(1u); + out_descs[0] = {dst_layout, inputs[0].comp_node}; + return {out_descs, true}; +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto&& matmul = def.cast_final_safe(); + auto&& cn = inputs[0]->comp_node(); + + using TensorND = megdnn::TensorND; + SmallVector inp_tensornds(inputs.size()); + TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); + + // only matters when layout1 has dim 2 + if (matmul.transposeA) + std::swap(layout1.shape[0], layout1.shape[1]); + // only matters when layout2 has dim 2 + if (matmul.transposeB) + std::swap(layout2.shape[0], layout2.shape[1]); + + size_t dim1 = layout1.ndim, dim2 = layout2.ndim; + TensorLayout real_dst_layout(layout1.dtype); + if (validated) { + real_dst_layout = output_descs[0].layout; + } else { + size_t ri = 0; + for (size_t i = 0; i < dim1 - 2; i++) + real_dst_layout[ri++] = layout1[i]; + real_dst_layout[ri++] = layout1[dim1 - 2]; + if (dim2 == 2) + real_dst_layout[ri++] = layout2[dim2 - 1]; + real_dst_layout.ndim = ri; + real_dst_layout.init_contiguous_stride(); + } + + if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(cn, real_dst_layout); + if (!out.empty()) { + dev_tensor_memset(out, 0); + } + return {Tensor::make(out)}; + } + + TensorLayout layout_a = layout1, layout_b = layout2; + if (dim1 == 1) { + layout_a.add_axis_cont_inplace(0); + inp_tensornds[0] = inputs[0]->dnn_tensor(); + inp_tensornds[0].layout = layout_a; + } else if (dim1 > 2) { + size_t batch = std::accumulate( + layout1.shape, layout1.shape + dim1 - 1, (size_t)1, + std::multiplies()); + + TensorShape na = TensorShape{batch, layout1[dim1 - 1]}; + auto inp1 = inputs[0]; + if (!layout1.try_reshape(layout_a, na)) { + inp1 = Tensor::make(inp1->blob(), inp1->offset(), layout1); + inp1->to_contiguous_inplace(); + layout1 = inp1->layout(); + layout_a = TensorLayout{{batch, layout1[dim1 - 1]}, layout1.dtype}; + } + + layout_a.init_contiguous_stride(); + inp_tensornds[0] = inp1->dnn_tensor(); + inp_tensornds[0].layout = layout_a; + } else { + inp_tensornds[0] = inputs[0]->dnn_tensor(); + } + + if (dim2 == 1) { + layout_b.add_axis_inplace(1, 1, 1); + inp_tensornds[1] = inputs[1]->dnn_tensor(); + inp_tensornds[1].layout = layout_b; + } else { + inp_tensornds[1] = inputs[1]->dnn_tensor(); + } + + TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype); + dst_layout.init_contiguous_stride(); + + DnnOprCaller dnn_opr(cn); + dnn_opr.op->param() = matmul.param(); + + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); + size_t sz = setup_algo( + {layout_a, layout_b, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, + matmul.policy(), false); + TensorLayout w_layout({sz}, dtype::Byte()); + auto dnn_wk = dnn_opr.create_workspace(w_layout); + + dnn_opr.op->exec(inp_tensornds[0], inp_tensornds[1], out.as_megdnn(), dnn_wk); + return 
{Tensor::make(out.sub(SubTensorSpec::make_from_layout(real_dst_layout)))}; +} + +SmallVector get_input_layout_constraint( + const OpDef& def, const SmallVector& inputs) { + SmallVector layout_checker(inputs.size()); + layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) { + return layout.is_contiguous(); + }; + return layout_checker; +} + +OP_TRAIT_REG(MatrixMul, MatrixMul) + .apply_on_var_node(apply_on_var_node) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) + .get_input_layout_constraint(get_input_layout_constraint) + .fallback(); +} // namespace matrix_mul +} // namespace + +namespace { +namespace batched_matrix_mul { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& matmul = def.cast_final_safe(); + mgb_assert(inputs.size() == 2); + OperatorNodeConfig config{matmul.make_name()}; + return opr::BatchedMatrixMul::make( + inputs[0], inputs[1], matmul.param(), matmul.policy(), config); +} + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + auto&& matmul = def.cast_final_safe(); + TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout; + size_t dim1 = layout1.ndim, dim2 = layout2.ndim; + + if (dim1 == 0 || dim2 == 0) { + return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; + } + + if (matmul.transposeA) + std::swap(layout1[dim1 - 1], layout1[dim1 - 2]); + if (matmul.transposeB) + std::swap(layout2[dim2 - 1], layout2[dim2 - 2]); + + TensorLayout dst_layout(layout1.dtype); + size_t di = 0; + if (dim1 > dim2) { + for (size_t i = 0; i < dim1 - 2; i++) + dst_layout[di++] = layout1[i]; + } else { + for (size_t i = 0; i < dim2 - 2; i++) + dst_layout[di++] = layout2[i]; + } + if (dim1 > 1) + dst_layout[di++] = layout1[dim1 - 2]; + if (dim2 > 1) + dst_layout[di++] = layout2[dim2 - 1]; + dst_layout.ndim = di; + dst_layout.init_contiguous_stride(); + + SmallVector out_descs(1u); + out_descs[0] = {dst_layout, inputs[0].comp_node}; + return {out_descs, true}; +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto&& matmul = def.cast_final_safe(); + auto&& cn = inputs[0]->comp_node(); + + TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); + size_t dim1 = layout1.ndim, dim2 = layout2.ndim; + + bool remove_row = false, remove_col = false; + if (dim1 == 1) { + dim1 = 2; + remove_row = true; + } + if (dim2 == 1) { + dim2 = 2; + remove_col = true; + } + + if (remove_row) + layout1.add_axis_cont_inplace(0); + if (remove_col) + layout2.add_axis_inplace(1, 1, 1); + + TensorShape tshp, batch_shp; + size_t j = 0; + if (dim1 > dim2) { + for (size_t i = 0; i < dim1 - 2; i++) + tshp[j++] = layout1.shape[i]; + batch_shp = tshp; + batch_shp.ndim = dim1 - 2; + tshp[j++] = layout2[layout2.ndim - 2]; + tshp[j++] = layout2[layout2.ndim - 1]; + tshp.ndim = j; + layout2 = layout2.broadcast(tshp); + } + if (dim2 > dim1) { + for (size_t i = 0; i < dim2 - 2; i++) + tshp[j++] = layout2.shape[i]; + batch_shp = tshp; + batch_shp.ndim = dim2 - 2; + tshp[j++] = layout1[layout1.ndim - 2]; + tshp[j++] = layout1[layout1.ndim - 1]; + tshp.ndim = j; + layout1 = layout1.broadcast(tshp); + } + if (dim1 == dim2) { + for (size_t i = 0; i < dim1 - 2; i++) + tshp[j++] = layout1.shape[i]; + batch_shp = tshp; + batch_shp.ndim = dim1 - 2; + } + + TensorShape shp1 = batch_shp, shp2 = batch_shp; + shp1.ndim += 2; + shp2.ndim += 2; + size_t 
maxdim = dim1 > dim2 ? dim1 : dim2; + size_t nbatch = batch_shp[0]; + auto inp1 = inputs[0], inp2 = inputs[1]; + if (maxdim > 3) { + nbatch = std::accumulate( + batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1, + std::multiplies()); + + TensorLayout layout_a; + + TensorShape nl1 = TensorShape( + {nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]}); + if (!layout1.try_reshape(layout_a, nl1)) { + inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1); + inp1->to_contiguous_inplace(); + layout1 = inp1->layout(); + } + layout1 = layout_a; + + TensorShape nl2 = TensorShape( + {nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]}); + if (!layout2.try_reshape(layout_a, nl2)) { + inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2); + inp2->to_contiguous_inplace(); + layout2 = inp2->layout(); + } + layout2 = layout_a; + } + + TensorLayout dst_layout( + {nbatch, matmul.transposeA ? layout1[2] : layout1[1], + matmul.transposeB ? layout2[1] : layout2[2]}, + layout1.dtype); + dst_layout.init_contiguous_stride(); + + if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); + if (!out.empty()) { + dev_tensor_memset(out, 0); + } + return {Tensor::make(out)}; + } + + using TensorND = megdnn::TensorND; + TensorND inp_nd1 = inp1->dnn_tensor(); + inp_nd1.layout = layout1; + TensorND inp_nd2 = inp2->dnn_tensor(); + inp_nd2.layout = layout2; + + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); + + DnnOprCaller dnn_opr(cn); + dnn_opr.op->param() = matmul.param(); + + size_t sz = setup_algo( + {layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, + matmul.policy(), false); + TensorLayout w_layout({sz}, dtype::Byte()); + auto dnn_wk = dnn_opr.create_workspace(w_layout); + dnn_opr.op->exec(inp_nd1, inp_nd2, out.as_megdnn(), dnn_wk); + + shp1[shp1.ndim - 2] = dst_layout[dst_layout.ndim - 2]; + shp1[shp1.ndim - 1] = dst_layout[dst_layout.ndim - 1]; + if (maxdim > 3) { + dst_layout = dst_layout.reshape(shp1); + } + if (remove_row) { + dst_layout = dst_layout.remove_axis(maxdim - 2); + } + if (remove_col) { + dst_layout = dst_layout.remove_axis(maxdim - 1); + } + return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))}; +} + +SmallVector get_input_layout_constraint( + const OpDef& def, const SmallVector& inputs) { + SmallVector layout_checker(inputs.size()); + layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) { + return layout.is_contiguous(); + }; + return layout_checker; +} + +OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) + .apply_on_var_node(apply_on_var_node) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .get_input_layout_constraint(get_input_layout_constraint) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); +} // namespace batched_matrix_mul +} // namespace + +namespace { +namespace dot { + +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); + mgb_assert(inputs.size() == 2); + OperatorNodeConfig config{op.make_name()}; + return opr::Dot::make(inputs[0], inputs[1], config); +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto comp_node = inputs[0]->comp_node(); + using TensorND = megdnn::TensorND; + SmallVector inp_tensornds; + inp_tensornds.reserve(inputs.size()); + 
DnnOprCaller dnn_opr(comp_node); + for (unsigned i = 0; i < inputs.size(); ++i) { + auto dnn_ten = inputs[i]->dnn_tensor(); + inp_tensornds.push_back(dnn_ten); + } + TensorLayout oup_layout{inputs[0]->dtype()}; + auto inp1_tensor = inputs[0]->dnn_tensor(); + auto inp2_tensor = inputs[1]->dnn_tensor(); + dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout); + + if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) { + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); + if (!out.empty()) { + dev_tensor_memset(out, 0); + } + return {Tensor::make(out)}; + } + + auto sz = dnn_opr.op->get_workspace_in_bytes( + inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout); + + DeviceTensorND out_devtensor = + BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); + + TensorLayout w_layout({sz}, dtype::Byte()); + auto dnn_wk = dnn_opr.create_workspace(w_layout); + + dnn_opr.op->exec( + inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk); + + return {Tensor::make(out_devtensor)}; +} + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + mgb_assert( + inputs.size() == 2, "Dot expects 2 inputs; got %lu actually", + inputs.size()); + SmallVector dests(1); + dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype); + dests[0].comp_node = inputs[0].comp_node; + bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0; + return {dests, validated}; +} + +OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot) + .apply_on_var_node(apply_on_var_node) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); + +} // namespace dot +} // anonymous namespace + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp index ffd2400e0f72e9ca5573fa7ccc392ac28655a5c8..35a49ef4de801623df13428049ee21d35a072ae5 100644 --- a/imperative/src/impl/ops/reduce.cpp +++ b/imperative/src/impl/ops/reduce.cpp @@ -123,7 +123,6 @@ SmallVector apply_on_physical_tensor( inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src); auto mode = op_def.param().mode; - DnnOprCaller fill_op(comp_node); if (!keepdim && src.ndim > 1) { layout.remove_axis_inplace(axis); @@ -135,12 +134,12 @@ SmallVector apply_on_physical_tensor( switch (mode) { case Reduce::Mode::SUM: if (!out.empty()) { - fill_op.op->param() = 0; - fill_op.op->exec(out.as_megdnn(), {}); + dev_tensor_memset(out, 0); } break; case Reduce::Mode::PRODUCT: if (!out.empty()) { + DnnOprCaller fill_op(comp_node); fill_op.op->param() = 1; fill_op.op->exec(out.as_megdnn(), {}); } diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 0afcaef8d0cce3189006274a9d64e09f43b07ba1..ff82a840764643b3b6a7600a3285bbac89911ab6 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -319,34 +319,6 @@ OP_TRAIT_REG(BatchConvBias, BatchConvBias) } // namespace batch_conv_bias } // namespace -namespace { -namespace matrix_mul { -auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& matmul = static_cast(def); - mgb_assert(inputs.size() == 2); - OperatorNodeConfig config{matmul.make_name()}; - return opr::MatrixMul::make( - inputs[0], inputs[1], matmul.param(), matmul.policy(), config); -} -OP_TRAIT_REG(MatrixMul, 
MatrixMul).apply_on_var_node(apply_on_var_node).fallback(); -} // namespace matrix_mul -} // namespace - -namespace { -namespace batched_matrix_mul { -auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& matmul = static_cast(def); - mgb_assert(inputs.size() == 2); - OperatorNodeConfig config{matmul.make_name()}; - return opr::BatchedMatrixMul::make( - inputs[0], inputs[1], matmul.param(), matmul.policy(), config); -} -OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) - .apply_on_var_node(apply_on_var_node) - .fallback(); -} // namespace batched_matrix_mul -} // namespace - namespace { namespace argsort { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp index d39b6565a48a516a6b8733bc4a576c803a66f9ae..5501bcf988a103add1dfd57dbdeef28e8ed3cdb0 100644 --- a/imperative/src/impl/transformations/dtype_promote.cpp +++ b/imperative/src/impl/transformations/dtype_promote.cpp @@ -183,6 +183,57 @@ ValueRefList convolution_rule(const OpDef& op, Span inputs) { return imperative::apply(op, converted); } +ValueRefList matmul_rule(const OpDef& op, Span inputs) { + auto&& conv_op = const_cast(op.cast_final_safe()); + SmallVector dtypes = get_value_dtypes(inputs); + mgb::DType target_dtype; + + if (DTypePromoteCfg::amp_dtype_autocast_enabled) { + conv_op.compute_mode = MatrixMul::ComputeMode::FLOAT32; + target_dtype = DTypePromoteCfg::amp_low_prec_dtype; + } else { + target_dtype = get_promoted_dtype(dtypes); + } + + ValueRefList converted(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + if (dtypes[i] != target_dtype) { + converted[i] = imperative::apply( + ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; + } else { + converted[i] = inputs[i]; + } + } + + return imperative::apply(op, converted); +} + +ValueRefList batch_matmul_rule(const OpDef& op, Span inputs) { + auto&& conv_op = + const_cast(op.cast_final_safe()); + SmallVector dtypes = get_value_dtypes(inputs); + mgb::DType target_dtype; + + if (DTypePromoteCfg::amp_dtype_autocast_enabled) { + conv_op.compute_mode = BatchedMatrixMul::ComputeMode::FLOAT32; + target_dtype = DTypePromoteCfg::amp_low_prec_dtype; + } else { + target_dtype = get_promoted_dtype(dtypes); + } + + ValueRefList converted(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + if (dtypes[i] != target_dtype) { + converted[i] = imperative::apply( + ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; + } else { + converted[i] = inputs[i]; + } + } + + return imperative::apply(op, converted); +} + // differ from Convolution, ConvolutionBackwardData is used in both // functional.conv_transpose2d and quantize.conv_transpose2d ValueRefList convolution_backward_rule(const OpDef& op, Span inputs) { @@ -259,8 +310,11 @@ struct DTypePromoteRuleRegistry { DTypePromoteRuleRegistry() { register_dtype_promote_rule(elemwise_rule); register_dtype_promote_rule(naive_promote_rule); + register_dtype_promote_rule(naive_promote_rule); register_dtype_promote_rule(reduce_rule); register_dtype_promote_rule(convolution_rule); + register_dtype_promote_rule(matmul_rule); + register_dtype_promote_rule(batch_matmul_rule); register_dtype_promote_rule(convolution_backward_rule); register_dtype_promote_rule(batch_norm_rule); register_dtype_promote_rule(naive_promote_rule);
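
The patch above moves matmul dtype promotion and AMP handling out of the Python layer into the C++ DTypePromoteTransformation rules, and routes _matmul through the new matmul_cpp / batched_matmul_cpp fast paths (falling back to the traceable symbolicMatrixMul / symbolicBatchedMatrixMul helpers). A minimal usage sketch of the resulting dispatch behaviour follows, assuming a MegEngine build that includes this patch; the shapes are chosen only to exercise the three branches of _matmul.

    # Sketch of the dispatch rules implemented by _matmul above,
    # assuming a MegEngine build that contains this patch.
    import numpy as np
    import megengine as mge
    import megengine.functional as F

    a = mge.tensor(np.random.rand(3).astype("float32"))
    b = mge.tensor(np.random.rand(3).astype("float32"))
    print(F.matmul(a, b).shape)    # 1-D x 1-D: dispatched to Dot, scalar result

    m = mge.tensor(np.random.rand(4, 3).astype("float32"))
    n = mge.tensor(np.random.rand(3, 5).astype("float32"))
    print(F.matmul(m, n).shape)    # 2-D x 2-D: MatrixMul fast path, shape (4, 5)

    p = mge.tensor(np.random.rand(8, 4, 3).astype("float32"))
    q = mge.tensor(np.random.rand(8, 3, 5).astype("float32"))
    print(F.matmul(p, q).shape)    # >=3-D operands: BatchedMatrixMul, shape (8, 4, 5)

    # Mixed dtypes are now promoted by the C++ dtype_promote rules rather than
    # by dtype_promotion() in Python: float16 x float32 is computed in float32.
    x16 = mge.tensor(np.random.rand(4, 3).astype("float16"))
    print(F.matmul(x16, n).dtype)  # float32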