#pragma once

#include "./opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/utils.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"

namespace megdnn {
namespace cuda {

//! Holder for the cuDNN convolution descriptor and activation descriptor
//! used by the conv_bias operator.
//! NOTE(review): ctor/dtor are defined elsewhere; presumably they create and
//! destroy the two cuDNN descriptors below — confirm in the .cpp.
class ConvBiasDesc {
public:
    ConvBiasDesc();

    //! configure descriptors for a fused conv + bias (+ activation) execution
    void set_conv_bias(
            DType data_type, const param::ConvBias& param, const size_t nr_group);

    //! configure descriptors for a plain (non-fused) convolution
    void set_conv(DType data_type, const param::ConvBias& param, const size_t nr_group);

    ~ConvBiasDesc();

    cudnnConvolutionDescriptor_t conv_desc;
    cudnnActivationDescriptor_t act_desc;
};

namespace conv_bias {
using CanonizedFilterMeta = ConvBiasForward::CanonizedFilterMeta;

//! conv size descriptor in the forward view
//!
//! Aggregates every layout involved in a conv-bias forward computation so
//! the size/workspace helpers below can take a single argument. All layout
//! pointers are non-owning; z_layout/bias_layout presumably may describe an
//! empty tensor when the optional z/bias input is absent — TODO confirm
//! against callers.
struct BiasForwardSizeArgs {
    HandleImpl* handle;                         //!< cuda handle executing the op
    const TensorLayout* src_layout;             //!< input tensor layout
    const TensorLayout* filter_layout;          //!< convolution filter layout
    const TensorLayout* bias_layout;            //!< bias tensor layout
    const TensorLayout* z_layout;               //!< optional residual-add input
    CanonizedFilterMeta filter_meta;            //!< canonized filter description
    const TensorLayout* dst_layout;             //!< output tensor layout
    param::ConvBias::NonlineMode nonlinear_mode;//!< fused activation mode
};

//! whether cudnn is supported for a filter meta
//! (i.e. whether the cudnn backend can execute the conv described by \p args)
bool is_cudnn_supported(const BiasForwardSizeArgs& args);

//! get workspace bundle for matmul algo
//! returns the sizes (in bytes) of each workspace segment the matmul-based
//! conv-bias algorithm needs for the problem described by \p args
SmallVector<size_t> matmul_get_workspace_bundle(const BiasForwardSizeArgs& args);

/*!
 * \brief flip conv filter
 *
 * Flip conv filter pointed by \p raw_ptr, store result in workspace, and
 * change \p raw_ptr to workspace.
 */
void flip_filter(
51
        const BiasForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr);

struct CUDNNForwardDescs {
    TensorDesc src_desc, dst_desc, bias_desc, z_desc;
    FilterDesc<param::ConvBias> filter_desc;
    ConvBiasDesc conv_desc;
57

M
Megvii Engine Team 已提交
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
    void set_conv_bias(
            const TensorLayout& src, const CanonizedFilterMeta& filter,
            const TensorLayout& dst, const TensorLayout& bias, const TensorLayout& z,
            const param::ConvBias& param) {
        using Format = param::ConvBias::Format;
        Format src_format, dst_format;
        src_format = dst_format = param.format;
        if (param.format == Format::NCHW4_NCHW) {
            src_format = Format::NCHW4;
            dst_format = Format::NCHW;
        }
        src_desc.set(src, src_format);
        filter_desc.set(filter);
        if (z.ndim > 0) {
            z_desc.set(z, dst_format);
        }
        dst_desc.set(dst, dst_format);
        conv_desc.set_conv_bias(src.dtype, param, filter.group);
76

M
Megvii Engine Team 已提交
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
        // cudnn requires the bias to be float tensor.
        auto float_bias_layout = bias;
        float_bias_layout.dtype = dtype::Float32();
        if (param.format == param::ConvBias::Format::NCHW4 ||
            param.format == param::ConvBias::Format::NCHW32) {
            // cudnn require bias to be NCHW, not NCHW4.
            float_bias_layout = float_bias_layout.reshape(
                    {float_bias_layout[0], float_bias_layout[1] * float_bias_layout[4],
                     float_bias_layout[2], float_bias_layout[3]});
            bias_desc.set(float_bias_layout);
        } else if (param.format == param::ConvBias::Format::NCHW4_NCHW) {
            megdnn_assert(
                    float_bias_layout.ndim == 4,
                    "NCHW4_NCHW format assumes bias tensor is stored "
                    "in NCHW layout, ndim(expected:4,got:%zu)",
                    float_bias_layout.ndim);
            bias_desc.set(float_bias_layout);
        } else {
            bias_desc.set(float_bias_layout, param.format);
96
        }
M
Megvii Engine Team 已提交
97
    }
98

M
Megvii Engine Team 已提交
99 100 101 102 103 104 105 106 107
    void set_conv(
            const TensorLayout& src, const CanonizedFilterMeta& filter,
            const TensorLayout& dst, const param::ConvBias& param) {
        using Format = param::ConvBias::Format;
        Format src_format, dst_format;
        src_format = dst_format = param.format;
        if (param.format == Format::NCHW4_NCHW) {
            src_format = Format::NCHW4;
            dst_format = Format::NCHW;
108
        }
M
Megvii Engine Team 已提交
109 110 111 112 113 114
        src_desc.set(src, src_format);
        filter_desc.set(filter);
        dst_desc.set(dst, dst_format);
        conv_desc.set_conv(src.dtype, param, filter.group);
    }
};

//! compute the two scaling factors passed to the cudnn fused
//! conv-bias-activation call for the given tensor layouts
//! NOTE(review): exact meaning of the pair (presumably alpha/beta) is defined
//! in the implementation — confirm before relying on it.
std::pair<float, float> cudnn_get_conv_bias_act_scale_param(
        const TensorLayout& x, const TensorLayout& y, const TensorLayout& w,
        const TensorLayout& b, const TensorLayout& z);

//! reorder filter and bias into the layout cudnn expects for NCHW32 tensor-op
//! convolutions, writing the results to the reordered_* output buffers
//! (name keeps the historical "filer" spelling; it is part of the public API)
void cudnn_reorder_filer_and_bias_nchw32(
        const cudnnHandle_t& handle, const void* filter_ptr,
        const CanonizedFilterMeta& fm, const void* bias_ptr, void* reordered_filter_ptr,
        void* reordered_bias_ptr);

}  // namespace conv_bias
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen