helper.h 4.9 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/conv_bias/helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "./opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/common/utils.h"
#include "src/common/algo_chooser.h"

namespace megdnn {
namespace cuda {

/*!
 * \brief Holder for the cuDNN convolution and activation descriptors used by
 * the conv_bias CUDA algorithms.
 *
 * Construction/destruction and the two setters are defined out of line;
 * presumably the ctor creates and the dtor destroys both descriptors —
 * NOTE(review): confirm against the implementation file.
 */
class ConvBiasDesc {
public:
    ConvBiasDesc();
    ~ConvBiasDesc();

    //! configure descriptors for a fused convolution + bias (+ activation)
    void set_conv_bias(DType data_type, const param::ConvBias& param,
                       size_t nr_group);
    //! configure descriptors for a plain convolution
    void set_conv(DType data_type, const param::ConvBias& param,
                  size_t nr_group);

    cudnnConvolutionDescriptor_t conv_desc;
    cudnnActivationDescriptor_t act_desc;
};

namespace conv_bias {
    typedef ConvBiasForward::CanonizedFilterMeta CanonizedFilterMeta;

    /*!
     * \brief Size/shape view of a conv-bias forward invocation.
     *
     * All layout members are raw pointers; this struct does not manage their
     * lifetime, so the pointed-to layouts must outlive any use of it
     * (NOTE(review): confirm at call sites).
     */
    struct BiasForwardSizeArgs {
        HandleImpl* handle;
        const TensorLayout *src_layout, *filter_layout, *bias_layout,
                *z_layout;
        CanonizedFilterMeta filter_meta;
        const TensorLayout* dst_layout;
        param::ConvBias::NonlineMode nonlinear_mode;
    };

    //! whether cudnn is supported for a filter meta
    bool is_cudnn_supported(const BiasForwardSizeArgs& args);

    //! get workspace bundle for matmul algo
53
    SmallVector<size_t> matmul_get_workspace_bundle(
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
            const BiasForwardSizeArgs& args);

    /*!
     * \brief flip conv filter
     *
     * Flip conv filter pointed by \p raw_ptr, store result in workspace, and
     * change \p raw_ptr to workspace.
     */
    void flip_filter(const BiasForwardSizeArgs& args,
                     const Workspace& workspace, void*& raw_ptr);

    struct CUDNNForwardDescs {
        TensorDesc src_desc, dst_desc, bias_desc, z_desc;
        FilterDesc<param::ConvBias> filter_desc;
        ConvBiasDesc conv_desc;

        void set_conv_bias(const TensorLayout& src,
                           const CanonizedFilterMeta& filter,
                           const TensorLayout& dst, const TensorLayout& bias,
                           const TensorLayout& z,
                           const param::ConvBias& param) {
75 76 77 78 79 80 81 82
            using Format = param::ConvBias::Format;
            Format src_format, dst_format;
            src_format = dst_format = param.format;
            if (param.format == Format::NCHW4_NCHW) {
                src_format = Format::NCHW4;
                dst_format = Format::NCHW;
            }
            src_desc.set(src, src_format);
83 84
            filter_desc.set(filter);
            if (z.ndim > 0) {
85
                z_desc.set(z, dst_format);
86
            }
87
            dst_desc.set(dst, dst_format);
88 89 90 91 92 93 94 95 96 97 98 99 100
            conv_desc.set_conv_bias(src.dtype, param, filter.group);

            // cudnn requires the bias to be float tensor.
            auto float_bias_layout = bias;
            float_bias_layout.dtype = dtype::Float32();
            if (param.format == param::ConvBias::Format::NCHW4 ||
                param.format == param::ConvBias::Format::NCHW32) {
                // cudnn require bias to be NCHW, not NCHW4.
                float_bias_layout = float_bias_layout.reshape(
                        {float_bias_layout[0],
                         float_bias_layout[1] * float_bias_layout[4],
                         float_bias_layout[2], float_bias_layout[3]});
                bias_desc.set(float_bias_layout);
101 102 103 104 105 106
            } else if (param.format == param::ConvBias::Format::NCHW4_NCHW) {
                megdnn_assert(float_bias_layout.ndim == 4,
                              "NCHW4_NCHW format assumes bias tensor is stored "
                              "in NCHW layout, ndim(expected:4,got:%zu)",
                              float_bias_layout.ndim);
                bias_desc.set(float_bias_layout);
107 108 109 110 111 112 113 114
            } else {
                bias_desc.set(float_bias_layout, param.format);
            }
        }

        void set_conv(const TensorLayout& src,
                      const CanonizedFilterMeta& filter,
                      const TensorLayout& dst, const param::ConvBias& param) {
115 116 117 118 119 120 121 122
            using Format = param::ConvBias::Format;
            Format src_format, dst_format;
            src_format = dst_format = param.format;
            if (param.format == Format::NCHW4_NCHW) {
                src_format = Format::NCHW4;
                dst_format = Format::NCHW;
            }
            src_desc.set(src, src_format);
123
            filter_desc.set(filter);
124
            dst_desc.set(dst, dst_format);
125 126 127 128 129 130 131 132 133
            conv_desc.set_conv(src.dtype, param, filter.group);
        }
    };

}  // namespace conv_bias
} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen