opr_impl.cpp 12.6 KB
Newer Older
M
Megvii Engine Team 已提交
1
#include "src/cuda/conv_bias/opr_impl.h"
2
#include "megdnn/dtype.h"
3
#include "src/cuda/conv_bias/algo.h"
4
#include "src/cuda/conv_bias/helper.h"
5 6 7 8
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

#include "src/common/algo_chooser.h"
M
Megvii Engine Team 已提交
9
#include "src/common/conv_bias.h"
10 11 12 13 14 15

#include "src/cuda/cudnn_with_check.h"

namespace megdnn {
namespace cuda {

M
Megvii Engine Team 已提交
16 17 18 19 20 21 22 23 24 25 26
//! Run the forward conv-bias operator.
//!
//! Validates the operand layouts and workspace size first, then bundles the
//! operands into an ExecArgs object, resolves the algorithm for these layouts
//! through get_algorithm() and dispatches execution to it.
void ConvBiasForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_in z, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    // Layout / workspace validation happens before anything else is built.
    check_exec_allow_noncontiguous(
            src.layout, filter.layout, bias.layout, z.layout, dst.layout,
            workspace.size, preprocessed_filter);

    // Execution arguments are constructed before the algorithm lookup.
    AlgoBase::ExecArgs exec_args(
            this, src, filter, bias, z, dst, workspace, preprocessed_filter);

    auto chosen_algo = get_algorithm(
            this, src.layout, filter.layout, bias.layout, z.layout, dst.layout);
    chosen_algo->exec(exec_args);
}

M
Megvii Engine Team 已提交
30 31 32
std::vector<ConvBiasForward::Algorithm*> ConvBiasForwardImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
33 34 35 36
    return megdnn::get_all_algorithms<ConvBiasForwardImpl>(
            {this, src, filter, bias, z, dst});
}

M
Megvii Engine Team 已提交
37 38 39
std::vector<ConvBiasForward::Algorithm*> ConvBiasForwardImpl::get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
40 41 42 43
    return megdnn::get_all_algorithms_safe<ConvBiasForwardImpl>(
            {this, src, filter, bias, z, dst});
}

44
//! Heuristically pick a forward conv-bias algorithm.
//!
//! Candidates are probed with is_available_attribute() using the given
//! workspace limit and positive/negative attribute filters; the first one that
//! reports itself available is returned. The probing order is:
//!   1. (CUDNN_VERSION >= 8004 only) the cuDNN v8 conv algorithm, then the
//!      cuDNN v8 conv-bias-activation algorithm;
//!   2. for channel-wise convolutions: implicit batched-matmul, large-filter
//!      depthwise, the in-house channel-wise kernels, or plain cuDNN conv,
//!      depending on the shape predicates computed below;
//!   3. cuDNN conv-bias-activation algorithms;
//!   4. plain cuDNN conv algorithms (using the re-deduced dst layout);
//!   5. batched matmul, group conv (when group > 1), the NCHW qs8 fallback and
//!      the int1 simple algorithm;
//!   6. a generic attribute match over either the bfloat16 algorithm list or
//!      the non-cuDNN algorithm list, depending on the src dtype.
//!
//! \param workspace_limit_in_bytes upper bound forwarded to every
//!        availability probe
//! \param positive_attr attribute filter forwarded to every probe
//! \param negative_attr attribute filter forwarded to every probe
//! \return pointer to the selected algorithm (the result of the final
//!         get_algo_match_attribute() call if nothing matched earlier)
ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    using namespace conv_bias;
    AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
#if CUDNN_VERSION >= 8004
    //! the cuDNN v8 algorithms take precedence over everything else
    if (sm_algo_pack.cudnn_conv_v8.is_available_attribute(
                args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
        return &sm_algo_pack.cudnn_conv_v8;
    }
    if (sm_algo_pack.cudnn_conv_bias_activation_v8.is_available_attribute(
                args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
        return &sm_algo_pack.cudnn_conv_bias_activation_v8;
    }
#endif

    //! local copy of the dst layout; when the dst and bias dtypes differ the
    //! dtype is reset and deduced again from the src and filter dtypes
    auto dst_layout = *args.dst_layout;
    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(
                args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
    }
    //! conv_args is captured by reference by get_cudnn_algo below, so later
    //! assignments to conv_args.dst_layout are visible to the cuDNN query
    auto conv_args = args;

    //! map a cuDNN forward algo enum to the conv-bias-activation wrapper
    auto cudnn_conv_bias_act_from_enum_wrapper =
            [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
        return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo);
    };

    //! map a cuDNN forward algo enum to the plain conv wrapper
    auto cudnn_conv_from_enum_wrapper =
            [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
        return sm_algo_pack.cudnn_conv_from_enum(algo);
    };

    //! ask cuDNN for forward algorithm candidates for conv_args, translate
    //! each one through \p cb and return the first translated algorithm that
    //! is available for args; nullptr when none qualifies
    auto get_cudnn_algo =
            [this, &conv_args, &args, workspace_limit_in_bytes, positive_attr,
             negative_attr](
                    const thin_function<AlgoBase*(cudnnConvolutionFwdAlgo_t)>& cb)
            -> AlgoBase* {
        auto cudnn_handle = cuda::cudnn_handle(this->handle());
        CUDNNForwardDescs desc;
        conv_args.init_conv_desc(desc);
#if CUDNN_MAJOR >= 7
        int max_count = 0;
        cudnn_check(
                cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_count));
        SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(max_count);
        int ret_count = 0;
        cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
                cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
                desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count, &ret_count,
                algo_perf.data()));
        //! candidates are visited in the order cuDNN reported them
        for (int i = 0; i < ret_count; ++i) {
            auto conv_bias_algo = cb(algo_perf[i].algo);
            //! guard against a callback that yields no algorithm object
            if (conv_bias_algo &&
                conv_bias_algo->is_available_attribute(
                        args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
                return conv_bias_algo;
            }
        }
#else
        cudnnConvolutionFwdAlgo_t algo;
        cudnn_check(cudnnGetConvolutionForwardAlgorithm(
                cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
                desc.conv_desc.conv_desc, desc.dst_desc.desc,
                CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_limit_in_bytes,
                &algo));

        auto conv_bias_algo = cb(algo);
        //! guard against a callback that yields no algorithm object
        if (conv_bias_algo &&
            conv_bias_algo->is_available_attribute(
                    args, positive_attr, negative_attr, workspace_limit_in_bytes))
            return conv_bias_algo;
#endif
        return nullptr;
    };

    //! return the batched-matmul algorithm when it is available for
    //! \p size_arg, nullptr otherwise
    auto get_1x1_algo = [workspace_limit_in_bytes, positive_attr,
                         negative_attr](const AlgoBase::SizeArgs& size_arg)
            -> ConvBiasForwardImpl::AlgoBase* {
        if (sm_algo_pack.batched_matmul.is_available_attribute(
                    size_arg, positive_attr, negative_attr, workspace_limit_in_bytes)) {
            return &sm_algo_pack.batched_matmul;
        }
        return nullptr;
    };

    //! group count equals src[1] scaled by 1 / 4 / 32 for NCHW / NCHW4 /
    //! NCHW32 (presumably src[1] is the possibly packed channel dimension --
    //! TODO confirm)
    const bool is_chanwise = (args.filter_meta.format == Param::Format::NCHW &&
                              args.filter_meta.group == src[1]) ||
                             (args.filter_meta.format == Param::Format::NCHW4 &&
                              args.filter_meta.group == src[1] * 4) ||
                             (args.filter_meta.format == Param::Format::NCHW32 &&
                              args.filter_meta.group == src[1] * 32);
    // prefer the special chanwise impl, since the group conv of cudnn
    // whose version is lower than v7.5.0 is still slower than our
    // implementation in many channel-wise cases
    const bool slow_cudnn_chanwise_impl =
            CUDNN_MAJOR < 7 || (CUDNN_MAJOR == 7 && CUDNN_MINOR < 5);
    //! choose CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM default for large image
    //! (kept as size_t: narrowing the extent product to int could wrap for
    //! very large inputs and wrongly satisfy the "< 512" test below)
    const size_t hw_size = src[2] * src[3];
    //! choose dnn when stride != 1, may need calibrate for different cudnn
    //! version
    const bool prefer_dnn_chanwise = slow_cudnn_chanwise_impl ||
                                     args.filter_meta.stride[0] != 1 ||
                                     args.filter_meta.stride[1] != 1 || hw_size < 512;
    //! choose for large kernel cases
    size_t fh = args.filter_meta.spatial[0], fw = args.filter_meta.spatial[1];
    size_t hi = src[2], wi = src[3];
    const bool prefer_dnn_lk_implbmm = hi <= 2 * fh && wi <= 2 * fw;
    //! filter size > 9, choose large kernel cases
    const bool prefer_direct_lk = fh > 9 && fw > 9;
    //! avoid bad case in cudnn, check dnn chanwise impl first
    if (is_chanwise) {
        if (prefer_dnn_lk_implbmm) {
#if CUDA_VERSION >= 10020
            if (sm_algo_pack.f16_implicit_bmm[0].is_available_attribute(
                        args, positive_attr, negative_attr, workspace_limit_in_bytes))
                return &sm_algo_pack.f16_implicit_bmm[0];
#endif
            if (sm_algo_pack.f32_implicit_bmm[0].is_available_attribute(
                        args, positive_attr, negative_attr, workspace_limit_in_bytes))
                return &sm_algo_pack.f32_implicit_bmm[0];
        } else if (
                prefer_direct_lk &&
                sm_algo_pack.depthwise_large_filter.is_available_attribute(
                        args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
            return &sm_algo_pack.depthwise_large_filter;
        } else if (prefer_dnn_chanwise) {
            if (sm_algo_pack.chanwise.is_available_attribute(
                        args, positive_attr, negative_attr, workspace_limit_in_bytes))
                return &sm_algo_pack.chanwise;
            if (sm_algo_pack.chanwise8x8x32.is_available_attribute(
                        args, positive_attr, negative_attr, workspace_limit_in_bytes))
                return &sm_algo_pack.chanwise8x8x32;
        } else {
            //! plain cuDNN conv is queried with the re-deduced dst layout
            conv_args.dst_layout = &dst_layout;
            if (is_cudnn_supported(conv_args)) {
                if (auto algo = get_cudnn_algo(cudnn_conv_from_enum_wrapper)) {
                    return algo;
                }
            }
        }
    }

    //! Prefer CUDNN CONVBIAS.
    bool cudnn_conv_bias_act_supported = false;
    for (auto&& algo : sm_algo_pack.cudnn_conv_bias_activations) {
        if (algo.is_available_attribute(
                    args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
            cudnn_conv_bias_act_supported = true;
            break;
        }
    }

    if (cudnn_conv_bias_act_supported) {
        if (auto algo = get_cudnn_algo(cudnn_conv_bias_act_from_enum_wrapper))
            return algo;
    }

    // modify conv_args dst_layout
    conv_args.dst_layout = &dst_layout;
    if (is_cudnn_supported(conv_args)) {
        if (auto algo = get_cudnn_algo(cudnn_conv_from_enum_wrapper))
            return algo;
    }

    if (auto algo = get_1x1_algo(args)) {
        return algo;
    }

    if (args.filter_meta.group > 1 &&
        sm_algo_pack.group.is_available_attribute(
                args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
        return &sm_algo_pack.group;
    }

    if (sm_algo_pack.fallback_nchw_qs8.is_available_attribute(
                args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
        return &sm_algo_pack.fallback_nchw_qs8;
    }

    if (sm_algo_pack.int1_simple.is_available_attribute(
                args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
        return &sm_algo_pack.int1_simple;
    }

    //! last resort: generic attribute match over the dtype-specific list
    if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
        return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>(
                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
                "cuda convbias fwd", positive_attr, negative_attr);
    } else {
        return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>(
                sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
                "cuda convbias fwd", positive_attr, negative_attr);
    }
}

//! Name identifying the algorithm set of this operator implementation.
const char* ConvBiasForwardImpl::get_algorithm_set_name() const {
    static constexpr char kAlgoSetName[] = "CONV_BIAS_CUDA";
    return kAlgoSetName;
}

M
Megvii Engine Team 已提交
244
size_t ConvBiasForwardImpl::get_workspace_in_bytes(
M
Megvii Engine Team 已提交
245 246
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst,
M
Megvii Engine Team 已提交
247
        const PreprocessedFilter* preprocessed_filter) {
248
    TensorLayoutArray layouts{src, filter, bias, z, dst};
249
    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
M
Megvii Engine Team 已提交
250 251
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
252
    auto rst = AlgorithmCache::instance().get(key);
253 254 255 256
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }

M
Megvii Engine Team 已提交
257 258
    AlgoBase::SizeArgs args{this, src, filter, bias, z, dst, preprocessed_filter};
    return get_algorithm(this, src, filter, bias, z, dst)->get_workspace_in_bytes(args);
259 260
};

M
Megvii Engine Team 已提交
261
size_t ConvBiasForwardImpl::get_preprocess_workspace_in_bytes(
M
Megvii Engine Team 已提交
262 263
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
M
Megvii Engine Team 已提交
264 265 266 267 268
    AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
    return get_algorithm(this, src, filter, bias, z, dst)
            ->get_preprocess_workspace_in_bytes(args);
}

M
Megvii Engine Team 已提交
269 270 271
//! Layouts of the preprocessed filter tensors, as reported by the algorithm
//! that get_algorithm() selects for these layouts.
SmallVector<TensorLayout> ConvBiasForwardImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    AlgoBase::SizeArgs size_args{this, src, filter, bias, z, dst};
    auto algo = get_algorithm(this, src, filter, bias, z, dst);
    return algo->deduce_preprocessed_filter_layout(size_args);
}

void ConvBiasForwardImpl::exec_preprocess(
        const TensorLayout& src_layout, _megdnn_tensor_in filter,
279
        _megdnn_tensor_in bias, const TensorLayout& z_layout,
M
Megvii Engine Team 已提交
280 281
        const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
M
Megvii Engine Team 已提交
282 283 284 285 286
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout}, z{nullptr, z_layout};
    AlgoBase::ExecArgs args(
            this, src, filter, bias, z, dst, workspace, preprocessed_filter);
    auto algo = get_algorithm(
            this, src.layout, filter.layout, bias.layout, z.layout, dst.layout);
M
Megvii Engine Team 已提交
287 288 289
    return algo->exec_preprocess(args);
}

290 291 292 293
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen