/**
 * \file
 * dnn/src/cuda/matrix_mul/matrix_mul_float_simt_cutlass_wrapper.cuinl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"

using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;

/**
 * \brief Configure and run one CUTLASS device-level GEMM of type \p Gemm.
 *
 * The raw pointers and leading dimensions are wrapped into
 * cutlass::TensorRef views, packed into a Gemm::Arguments object and handed
 * to a default-constructed \p Gemm instance, which is initialized with
 * \p workspace and then invoked on \p stream. Both CUTLASS status codes are
 * passed through cutlass_check(), and after_kernel_launch() is called last.
 *
 * \tparam Gemm CUTLASS GEMM type; must expose ElementA/B/C, LayoutA/B/C,
 *      EpilogueOutputOp and Arguments.
 * \param d_A pointer to operand A (read-only)
 * \param lda leading dimension of A; narrowed to int for the layout
 * \param d_B pointer to operand B (read-only)
 * \param ldb leading dimension of B; narrowed to int for the layout
 * \param d_C pointer to the output; used as the destination tensor
 * \param ldc leading dimension of C; narrowed to int for the layout
 * \param workspace scratch pointer forwarded to Gemm::initialize()
 * \param problem_size GEMM extents, forwarded unchanged
 * \param epilogue parameters of the epilogue output operator
 * \param stream CUDA stream passed to the Gemm call operator
 * \param split_k_slices forwarded as the last field of Gemm::Arguments
 *
 * NOTE(review): lda/ldb/ldc are converted with static_cast<int> without a
 * range check; callers are presumably expected to keep them within int
 * range -- confirm.
 */
template <typename Gemm>
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper(
        const typename Gemm::ElementA* d_A, size_t lda,
        const typename Gemm::ElementB* d_B, size_t ldb,
        typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
        GemmCoord const& problem_size,
        typename Gemm::EpilogueOutputOp::Params const& epilogue,
        cudaStream_t stream, int split_k_slices) {
    using TensorRefA = cutlass::TensorRef<typename Gemm::ElementA const,
                                          typename Gemm::LayoutA>;
    using TensorRefB = cutlass::TensorRef<typename Gemm::ElementB const,
                                          typename Gemm::LayoutB>;
    using TensorRefC = cutlass::TensorRef<typename Gemm::ElementC const,
                                          typename Gemm::LayoutC>;
    using TensorRefD =
            cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>;

    // The input refs are views over const elements, so the const-qualified
    // pointers can be used directly without casting constness away first.
    TensorRefA tensor_a{d_A, typename Gemm::LayoutA{static_cast<int>(lda)}};
    TensorRefB tensor_b{d_B, typename Gemm::LayoutB{static_cast<int>(ldb)}};
    // The source operand C is given a null pointer; only its layout is set.
    TensorRefC tensor_c{nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}};
    // The destination D writes into the caller's output buffer.
    TensorRefD tensor_d{d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}};

    typename Gemm::Arguments arguments{problem_size,
                                       tensor_a,
                                       tensor_b,
                                       tensor_c,
                                       tensor_d.non_const_ref(),
                                       epilogue,
                                       split_k_slices};

    Gemm gemm_op;
    // Report a failure from either the setup or the launch step.
    cutlass_check(gemm_op.initialize(arguments, workspace));
    cutlass_check(gemm_op(stream));
    after_kernel_launch();
}

// vim: syntax=cuda.doxygen