/** * \file dnn/src/naive/matrix_mul/opr_impl.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "src/naive/matrix_mul/opr_impl.h" #include "src/common/utils.h" #include "src/naive/handle.h" #include "./matrix_mul_helper.h" #include "midout.h" MIDOUT_DECL(megdnn_naive_matmul) namespace megdnn { namespace naive { size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B, const TensorLayout&) { if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) { return (A.span().dist_elem() + B.span().dist_elem()) * sizeof(uint8_t); } return 0; } template void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace, const MatrixMul::Param& param) { auto M = C.layout.shape[0], N = C.layout.shape[1]; auto K = A.layout.shape[param.transposeA ? 0 : 1]; auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], LDC = C.layout.stride[0]; #define cb(_itype, _otype, _comp_type) \ if (param.format == param::MatrixMul::Format::DEFAULT) { \ return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ A.layout.dtype, B.layout.dtype); \ } else if (param.format == param::MatrixMul::Format::MK4) { \ return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ A.layout.dtype, B.layout.dtype); \ } else if (param.format == param::MatrixMul::Format::MK8) { \ return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ A.layout.dtype, B.layout.dtype); \ } if (A.layout.dtype == dtype::Float32()) { cb(dt_float32, dt_float32, dt_float32); #if !MEGDNN_DISABLE_FLOAT16 } else if (A.layout.dtype == dtype::Float16()) { using Param = MatrixMul::Param; if (param.compute_mode == Param::ComputeMode::DEFAULT) { cb(dt_float16, dt_float16, dt_float16); } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { cb(dt_float16, dt_float16, dt_float32); } } else if (A.layout.dtype == dtype::BFloat16()) { using Param = MatrixMul::Param; if (param.compute_mode == Param::ComputeMode::DEFAULT) { cb(dt_bfloat16, dt_bfloat16, dt_bfloat16); } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { cb(dt_bfloat16, dt_bfloat16, dt_float32); } #endif } else if (A.layout.dtype == dtype::Int8() && C.layout.dtype == dtype::Int16()) { cb(dt_int8, dt_int16, dt_int16); } else if (A.layout.dtype == dtype::Int16() && C.layout.dtype == dtype::Int32()) { cb(dt_int16, dt_int32, dt_int32); } else if ((A.layout.dtype == dtype::Int8() || A.layout.dtype.enumv() == DTypeEnum::QuantizedS8) && (C.layout.dtype == dtype::Int32() || C.layout.dtype.enumv() == DTypeEnum::QuantizedS32)) { cb(dt_int8, dt_int32, dt_int32); } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && C.layout.dtype.enumv() == DTypeEnum::QuantizedS32) { cb(uint8_t, dt_int32, dt_int32); } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm && C.layout.dtype.enumv() == DTypeEnum::QuantizedS32 && param.format == param::MatrixMul::Format::DEFAULT) { exec_matrix_mul_quint4x4x32_helper(A, B, C, workspace, param); return; } #undef cb megdnn_throw(ssprintf( "unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)", A.layout.dtype.name(), B.layout.dtype.name(), C.layout.dtype.name(), static_cast(param.compute_mode))); } void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace, const Param& param) { #define DISPATCH(TA, TB) \ if (param.transposeA == TA && param.transposeB == TB) { \ dispatch_ta_tb(A, B, C, workspace, param); \ return; \ } DISPATCH(true, true); DISPATCH(true, false); DISPATCH(false, true); DISPATCH(false, false); #undef DISPATCH megdnn_assert_internal(0); } void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) { MIDOUT_BEGIN(megdnn_naive_matmul) { check_exec(A.layout, B.layout, C.layout, workspace.size); auto p = param(); MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(A, B, C, workspace, p)); } MIDOUT_END(); } } // namespace naive } // namespace megdnn // vim: syntax=cpp.doxygen