helper.h 9.7 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/convolution3d/helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "./opr_impl.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"
#include "src/common/utils.h"
#include "src/common/algo_chooser.h"
#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {
namespace convolution3d {
    //! short alias for the canonized filter metadata of Convolution3DForward
    typedef Convolution3DForward::CanonizedFilterMeta CanonizedFilterMeta;

    /*!
     * \brief size arguments of a 3D convolution, expressed in the forward view
     *
     * The struct has no destructor, so it does not own or free the objects
     * its pointer members refer to. NOTE(review): callers presumably keep the
     * referenced handle/layouts alive while the struct is in use — confirm.
     */
    struct ForwardSizeArgs {
        HandleImpl* handle;                        //!< cuda handle implementation
        const TensorLayout* src_layout;            //!< input tensor layout
        CanonizedFilterMeta filter_meta;           //!< canonized filter description
        const TensorLayout* dst_layout;            //!< output tensor layout
        param::Convolution3D::DataType data_type;  //!< data type parameter
    };

    /*!
     * \brief check whether the cuDNN path can handle the convolution whose
     *      sizes and filter meta are described by \p args
     */
    bool is_cudnn_supported(const ForwardSizeArgs& args);

    //! cuDNN descriptors needed by a 3D forward convolution
    struct CUDNNForwardDescs {
        Tensor3DDesc src_desc;
        Tensor3DDesc dst_desc;
        Filter3DDesc filter_desc;
        Conv3DDesc conv_desc;

        /*!
         * \brief populate every descriptor
         *
         * The tensor descriptors come from \p src / \p dst, the filter
         * descriptor from \p filter, and the convolution descriptor from
         * \p param together with the filter's group count.
         */
        void set(const TensorLayout& src, const CanonizedFilterMeta& filter,
                 const TensorLayout& dst, const param::Convolution3D& param) {
            src_desc.set(src);
            filter_desc.set(filter);
            dst_desc.set(dst);
            conv_desc.set(param, filter.group);
        }
    };

    //! cuDNN descriptors needed by a 3D backward-data convolution
    struct CUDNNBwdDataDescs {
        Tensor3DDesc diff_desc;
        Tensor3DDesc grad_desc;
        Filter3DDesc filter_desc;
        Conv3DDesc conv_desc;

        /*!
         * \brief populate every descriptor
         *
         * The filter descriptor comes from \p filter, the tensor descriptors
         * from \p diff / \p grad, and the convolution descriptor from
         * \p param together with the filter's group count.
         */
        void set(const CanonizedFilterMeta& filter, const TensorLayout& diff,
                 const TensorLayout& grad, const param::Convolution3D& param) {
            filter_desc.set(filter);
            diff_desc.set(diff);
            grad_desc.set(grad);
            conv_desc.set(param, filter.group);
        }
    };

    //! cuDNN descriptors needed by a 3D backward-filter convolution
    struct CUDNNBwdFilterDescs {
        Tensor3DDesc diff_desc;
        Tensor3DDesc src_desc;
        Filter3DDesc grad_desc;
        Conv3DDesc conv_desc;

        /*!
         * \brief populate every descriptor
         *
         * The tensor descriptors come from \p src / \p diff, the filter
         * gradient descriptor from \p grad, and the convolution descriptor
         * from \p param together with \p grad's group count.
         */
        void set(const TensorLayout& src, const TensorLayout& diff,
                 const CanonizedFilterMeta& grad,
                 const param::Convolution3D& param) {
            src_desc.set(src);
            diff_desc.set(diff);
            grad_desc.set(grad);
            conv_desc.set(param, grad.group);
        }
    };

    /*!
     * \brief flip a convolution filter into the workspace
     *
     * The filter currently referenced by \p raw_ptr is flipped, the result is
     * written into \p workspace, and \p raw_ptr is then redirected to the
     * workspace buffer.
     *
     * \param args size arguments describing the convolution/filter
     * \param workspace buffer receiving the flipped filter
     * \param raw_ptr in: the filter data; out: the workspace buffer
     */
    void flip_filter(const ForwardSizeArgs& args, const Workspace& workspace,
                     void*& raw_ptr);

    inline bool cudnn_get_convolution_fwd_algo_helper(
95
            Handle* handle, const cudnnTensorDescriptor_t x_desc,
96 97 98 99
            const cudnnFilterDescriptor_t w_desc,
            const cudnnConvolutionDescriptor_t conv_desc,
            const cudnnTensorDescriptor_t y_desc,
            size_t workspace_limit_in_bytes, cudnnConvolutionFwdAlgo_t* algo,
100 101 102 103
            const AlgoAttribute& positive_attr,
            const AlgoAttribute& negative_attr) {
        MEGDNN_MARK_USED_VAR(positive_attr);
        MEGDNN_MARK_USED_VAR(negative_attr);
104
#if CUDNN_MAJOR >= 7
105
        auto& cudnn = static_cast<HandleImpl*>(handle)->cudnn();
106
        int algo_max_count = 0;
107 108
        cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(
                cuda::cudnn_handle(handle), &algo_max_count));
109 110
        SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(algo_max_count);
        int algo_count = 0;
111 112
        cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
                cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, algo_max_count,
113 114 115 116 117 118 119
                &algo_count, algo_perf.data()));
        for (int i = 0; i < algo_count; ++i) {
            if (algo_perf[i].algo ==
                cudnnConvolutionFwdAlgo_t::
                        CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)
                continue;
            size_t workspace_size = 0;
120 121
            cudnn_check(cudnn.GetConvolutionForwardWorkspaceSize(
                    cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
122 123
                    algo_perf[i].algo, &workspace_size));
            if (workspace_size > workspace_limit_in_bytes) continue;
124
            if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
125 126 127 128 129 130 131 132 133 134 135 136
                *algo = algo_perf[i].algo;
                return true;
            } else {
                if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
                    *algo = algo_perf[i].algo;
                    return true;
                }
            }
        }
        return false;
#else
        cudnn_check(cudnnGetConvolutionForwardAlgorithm(
137
                cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
138 139 140 141 142 143 144 145 146 147 148 149
                CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
                workspace_limit_in_bytes, algo));
        return true;
#endif
    }

    inline bool cudnn_get_convolution_bwd_data_algo_helper(
            cudnnHandle_t cudnn_handle, const cudnnFilterDescriptor_t w_desc,
            const cudnnTensorDescriptor_t dy_desc,
            const cudnnConvolutionDescriptor_t conv_desc,
            const cudnnTensorDescriptor_t dx_desc,
            size_t workspace_limit_in_bytes,
150 151 152 153 154
            cudnnConvolutionBwdDataAlgo_t* algo,
            const AlgoAttribute& positive_attr,
            const AlgoAttribute& negative_attr) {
        MEGDNN_MARK_USED_VAR(positive_attr);
        MEGDNN_MARK_USED_VAR(negative_attr);
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
#if CUDNN_MAJOR >= 7
        int algo_max_count = 0;
        cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
                cudnn_handle, &algo_max_count));
        SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(
                algo_max_count);
        int algo_count = 0;
        cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7(
                cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc,
                algo_max_count, &algo_count, algo_perf.data()));
        for (int i = 0; i < algo_count; ++i) {
            if (algo_perf[i].algo ==
                cudnnConvolutionBwdDataAlgo_t::
                        CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING)
                continue;
            size_t workspace_size = 0;
            cudnn_check(cudnnGetConvolutionBackwardDataWorkspaceSize(
                    cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc,
                    algo_perf[i].algo, &workspace_size));
            if (workspace_size > workspace_limit_in_bytes) continue;
175
            if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
                *algo = algo_perf[i].algo;
                return true;
            } else {
                if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
                    *algo = algo_perf[i].algo;
                    return true;
                }
            }
        }
        return false;
#else
        cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle,
                    w_desc, dy_desc, conv_desc, dx_desc,
                    CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
                    workspace_limit_in_bytes,
                    algo));
        return true;
#endif
    }

    inline bool cudnn_get_convolution_bwd_filter_algo_helper(
            cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc,
            const cudnnTensorDescriptor_t dy_desc,
            const cudnnConvolutionDescriptor_t conv_desc,
            const cudnnFilterDescriptor_t dw_desc,
            size_t workspace_limit_in_bytes,
202 203 204 205 206
            cudnnConvolutionBwdFilterAlgo_t* algo,
            const AlgoAttribute& positive_attr,
            const AlgoAttribute& negative_attr) {
        MEGDNN_MARK_USED_VAR(positive_attr);
        MEGDNN_MARK_USED_VAR(negative_attr);
207 208 209 210 211 212 213 214 215 216 217 218
#if CUDNN_MAJOR >= 7
        int algo_max_count = 0;
        cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
                cudnn_handle, &algo_max_count));
        SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(
                algo_max_count);
        int algo_count = 0;
        cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
                cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc,
                algo_max_count, &algo_count, algo_perf.data()));
        for (int i = 0; i < algo_count; ++i) {
            if (algo_perf[i].algo ==
219 220
                cudnnConvolutionBwdFilterAlgo_t::
                        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING)
221 222 223 224 225 226
                continue;
            size_t workspace_size = 0;
            cudnn_check(cudnnGetConvolutionBackwardFilterWorkspaceSize(
                    cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc,
                    algo_perf[i].algo, &workspace_size));
            if (workspace_size > workspace_limit_in_bytes) continue;
227
            if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
                *algo = algo_perf[i].algo;
                return true;
            } else {
                if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
                    *algo = algo_perf[i].algo;
                    return true;
                }
            }
        }
        return false;
#else
        cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm(
                cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc,
                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
                workspace_limit_in_bytes, algo));
        return true;
#endif
    }

} // namespace convolution3d
} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen