matrix_mul_float_simt_cutlass_wrapper.cuinl 2.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
/**
 * \file
 * dnn/src/cuda/matrix_mul/matrix_mul_float_simt_cutlass_wrapper.cuinl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "cutlass/gemm/device/gemm.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"

using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;

/**
 * \brief Run the CUTLASS device-level GEMM operator \p Gemm on \p stream.
 *
 * Wraps the raw device pointers in CUTLASS tensor references, assembles the
 * `Gemm::Arguments`, checks that this Gemm instantiation supports them,
 * initializes the operator and enqueues the kernel on \p stream.
 *
 * \param d_A          device pointer to operand A (leading dimension \p lda)
 * \param lda          leading dimension of A
 * \param d_B          device pointer to operand B (leading dimension \p ldb)
 * \param ldb          leading dimension of B
 * \param d_C          device pointer the result D is written to
 *                     (leading dimension \p ldc)
 * \param ldc          leading dimension of D
 * \param workspace    scratch buffer forwarded to Gemm::initialize()
 * \param problem_size GEMM extents passed straight to CUTLASS
 * \param epilogue     parameters of the Gemm's epilogue output operator
 * \param stream       CUDA stream on which initialization work and the kernel
 *                     are enqueued
 *
 * NOTE(review): the source operand C is bound to nullptr, so callers
 * presumably must configure \p epilogue so that C is never read (e.g.
 * beta == 0) — confirm at call sites.
 * NOTE(review): leading dimensions are narrowed to int; values larger than
 * INT_MAX would be silently truncated.
 */
template <typename Gemm>
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper(
        const typename Gemm::ElementA* d_A, size_t lda,
        const typename Gemm::ElementB* d_B, size_t ldb,
        typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
        GemmCoord const& problem_size,
        typename Gemm::EpilogueOutputOp::Params const& epilogue,
        cudaStream_t stream) {
    typename Gemm::TensorRefA tensor_a{
            const_cast<typename Gemm::ElementA*>(d_A),
            typename Gemm::LayoutA{static_cast<int>(lda)}};
    typename Gemm::TensorRefB tensor_b{
            const_cast<typename Gemm::ElementB*>(d_B),
            typename Gemm::LayoutB{static_cast<int>(ldb)}};
    // No source matrix is bound: only the destination D (d_C) is given a
    // real pointer. The layout still carries ldc.
    typename Gemm::TensorRefC tensor_c{
            nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}};
    typename Gemm::TensorRefD tensor_d{
            d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}};

    // The trailing field is split_k_slices; this wrapper always uses a single
    // K partition.
    typename Gemm::Arguments arguments{problem_size,
                                       tensor_a,
                                       tensor_b,
                                       tensor_c,
                                       tensor_d.non_const_ref(),
                                       epilogue,
                                       1};

    // Reject argument combinations this Gemm instantiation cannot handle
    // before touching the workspace or launching anything.
    cutlass_check(Gemm::can_implement(arguments));

    Gemm gemm_op;
    // Forward the stream so any workspace initialization CUTLASS performs
    // (e.g. the cudaMemsetAsync used by serial split-K) is ordered on `stream`
    // instead of the legacy default stream.
    cutlass_check(gemm_op.initialize(arguments, workspace, stream));
    cutlass_check(gemm_op(stream));
    after_kernel_launch();
}

// vim: syntax=cuda.doxygen