/**
 * \file dnn/src/cuda/matrix_mul/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/cuda/matrix_mul/opr_impl.h"
#include "./algos.h"
#include "src/common/algo_chooser.h"

#include <cuda.h>
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "src/cuda/matrix_mul/cublasLt_wrapper.h"

namespace megdnn {
namespace cuda {

std::vector<MatrixMulForwardImpl::Algorithm*>
MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
                                         const TensorLayout& B,
                                         const TensorLayout& C) {
    // Enumerate every registered matrix-mul algorithm applicable to the
    // given operand layouts by delegating to the common chooser helper.
    return megdnn::get_all_algorithms<MatrixMulForwardImpl>(
            AlgoBase::SizeArgs{this, A, B, C});
}

/**
 * Pick an algorithm for the given layouts without profiling.
 *
 * Preference order: cuBLAS, then cuBLASLt (CUDA >= 10.1), then the WMMA
 * uint4x4x32 kernel (CUDA >= 10.0); each is taken only if it satisfies
 * both \p attr and \p workspace_limit_in_bytes. Otherwise falls back to a
 * generic search over all registered algorithms.
 */
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
        const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
        size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
    AlgoBase::SizeArgs args{this, A, B, C};
    // Vendor BLAS is the preferred choice when usable under the constraints.
    if (sm_algo_pack.cublas.is_available_attribute(args, attr,
                                                   workspace_limit_in_bytes)) {
        return &sm_algo_pack.cublas;
    }
#if CUDA_VERSION >= 10010
    // cublasLt exists only since CUDA 10.1.
    if (sm_algo_pack.cublas_lt.is_available_attribute(
                args, attr, workspace_limit_in_bytes)) {
        return &sm_algo_pack.cublas_lt;
    }
#endif

#if CUDA_VERSION >= 10000
    // WMMA-based uint4x4x32 kernel requires CUDA 10.0+.
    if (sm_algo_pack.wmma_uint4x4x32.is_available_attribute(
                args, attr, workspace_limit_in_bytes)) {
        return &sm_algo_pack.wmma_uint4x4x32;
    }
#endif

    // No preferred algorithm fits: search all registered algorithms,
    // honoring the attribute filter when one is requested.
    if (attr != AlgoAttribute::DEFAULT) {
        return megdnn::get_algo_with_attribute<MatrixMulForwardImpl>(
                sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
                "matrix mul forward", attr);
    } else {
        return megdnn::get_usable_algo<MatrixMulForwardImpl>(
                sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
                "matrix mul forward");
    }
}

size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
                                                    const TensorLayout& B,
                                                    const TensorLayout& C) {
    // The workspace requirement is whatever the chosen algorithm needs
    // for these operand layouts.
    auto algo = megdnn::get_algorithm(this, A, B, C);
    return algo->get_workspace_in_bytes(AlgoBase::SizeArgs{this, A, B, C});
}

void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
                                _megdnn_tensor_out C,
                                _megdnn_workspace workspace) {
    // Validate layouts and workspace size, select an algorithm for the
    // layouts, verify its workspace requirement, then run it.
    check_exec(A.layout, B.layout, C.layout, workspace.size);
    auto&& selected = get_algorithm(this, A.layout, B.layout, C.layout);
    AlgoBase::ExecArgs exec_args(this, A, B, C, workspace);
    selected->check_workspace(exec_args, workspace).exec(exec_args);
}

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen