/**
 * \file
 * dnn/src/cuda/matrix_mul/matrix_mul_float_simt_cutlass_wrapper.cuinl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "cutlass/gemm/device/gemm.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"

using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;

/**
 * \brief Run one GEMM described by the \p Gemm operator type on \p stream.
 *
 * Wraps the raw operand pointers and leading dimensions into the tensor
 * references expected by \p Gemm, builds the operator arguments, initializes
 * the operator with \p workspace and invokes it on \p stream.
 *
 * \tparam Gemm GEMM operator type; it must provide ElementA/B/C,
 *         LayoutA/B/C, TensorRefA/B/C/D, EpilogueOutputOp and Arguments, and
 *         be callable with a stream after initialize().
 * \param d_A pointer to operand A (only read through a non-const TensorRef)
 * \param lda leading dimension of A, narrowed to int for the layout
 * \param d_B pointer to operand B (only read through a non-const TensorRef)
 * \param ldb leading dimension of B, narrowed to int for the layout
 * \param d_C pointer to the output matrix
 * \param ldc leading dimension of the output, narrowed to int for the layout
 * \param workspace scratch buffer forwarded to Gemm::initialize()
 * \param problem_size GEMM extents forwarded to Gemm::Arguments
 * \param epilogue parameters of the epilogue output operator
 * \param stream stream on which the operator is invoked
 *
 * NOTE(review): the pointers are presumably device pointers and
 * cutlass_check / after_kernel_launch presumably report failures; both are
 * declared outside this file -- confirm against their definitions.
 */
template <typename Gemm>
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper(
        const typename Gemm::ElementA* d_A, size_t lda,
        const typename Gemm::ElementB* d_B, size_t ldb,
        typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
        GemmCoord const& problem_size,
        typename Gemm::EpilogueOutputOp::Params const& epilogue,
        cudaStream_t stream) {
    // The TensorRef types hold non-const element pointers, hence the
    // const_cast on the input operands.
    typename Gemm::TensorRefA tensor_a{
            const_cast<typename Gemm::ElementA*>(d_A),
            typename Gemm::LayoutA{static_cast<int>(lda)}};
    typename Gemm::TensorRefB tensor_b{
            const_cast<typename Gemm::ElementB*>(d_B),
            typename Gemm::LayoutB{static_cast<int>(ldb)}};
    // No source accumulator is supplied: the C reference carries a null
    // pointer and only the layout. NOTE(review): this relies on the epilogue
    // not reading the source operand -- confirm for each instantiation.
    typename Gemm::TensorRefC tensor_c{
            nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}};
    // The result is written to d_C through the D reference.
    typename Gemm::TensorRefD tensor_d{
            d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}};

    // NOTE(review): the trailing 1 is presumably the split-k slice count
    // (i.e. no split-k) -- confirm against Gemm::Arguments.
    typename Gemm::Arguments arguments{problem_size,
                                       tensor_a,
                                       tensor_b,
                                       tensor_c,
                                       tensor_d.non_const_ref(),
                                       epilogue,
                                       1};

    Gemm gemm_op;
    cutlass_check(gemm_op.initialize(arguments, workspace));
    cutlass_check(gemm_op(stream));
    after_kernel_launch();
}

// vim: syntax=cuda.doxygen