iconv_trait.cuh 11.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/**
 * \file dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
29
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "src/cuda/convolution_helper/block_tile_consumer/block_consumer.cuh"
#include "src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator.cuh"
#include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor.cuh"
#include "src/cuda/convolution_helper/global_memory_writer/global_memory_writer.cuh"
#include "src/cuda/convolution_helper/layout.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"

namespace megdnn {
namespace cuda {
namespace convolution {
#define COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(                   \
        _src_dtype, _filter_dtype, _smem_storage_dtype, _input_layout, \
        _kern_layout, _output_layout, _conv_param)                     \
    using src_dtype = _src_dtype;                                      \
    using filter_dtype = _filter_dtype;                                \
    using smem_storage_dtype = _smem_storage_dtype;                    \
    using InputLayout = _input_layout;                                 \
    using KernLayout = _kern_layout;                                   \
    using OutputLayout = _output_layout;                               \
    using Param = _conv_param;                                         \
56
    static constexpr bool check_bounds = check_bounds_;
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
#define MEGDNN_COMMA ,

template <bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
          typename ThreadConfig_>
struct IConvTrait {
    COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t,
                                                Layout<Format::CHWN4>,
                                                Layout<Format::CHWN4>,
                                                Layout<Format::CHWN4>,
                                                ConvParam);
    using RegBlockConfig = RegBlockConfig_;
    using ThreadConfig = ThreadConfig_;
    struct DataTileCount {
        using RegBlockConfig = RegBlockConfig;
        using ThreadConfig = ThreadConfig;
        using copy_t = ldg_dtype;
        using smem_storage_dtype = smem_storage_dtype;
        static int constexpr load_width =
                sizeof(copy_t) / sizeof(smem_storage_dtype);
        static int constexpr ldg_load_width =
                sizeof(copy_t) / sizeof(src_dtype);
        static int constexpr skew = load_width;
        static int constexpr block_tile_batch =
                RegBlockConfig::reg_n * ThreadConfig::nr_thread_x;
        static int constexpr block_tile_in_channel = RegBlockConfig::reg_k;

        static int constexpr smem_load_x = block_tile_batch / load_width;
        static int constexpr load_x = smem_load_x > 32 ? 32 : smem_load_x;
        static int constexpr load_y = ThreadConfig::nr_threads / load_x;

        static int constexpr smem_h = RegBlockConfig::reg_k_packed;
        static int constexpr smem_w = block_tile_batch;
        static int constexpr smem_stride =
                smem_w % 2 == 0 ? smem_w + skew : smem_w;
        static int constexpr smem_tot = smem_h * smem_stride;

        static int constexpr reg_h = (smem_h + load_y - 1) / load_y;
        static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x;

        static bool constexpr check_bounds_h = smem_h % load_y != 0;
        static bool constexpr check_bounds_w = smem_load_x % load_x != 0;
    };

    struct FilterTileCount {
        using RegBlockConfig = RegBlockConfig;
        using ThreadConfig = ThreadConfig;
        using copy_t = ldg_dtype;
        using smem_storage_dtype = smem_storage_dtype;
        static int constexpr load_width =
                sizeof(copy_t) / sizeof(smem_storage_dtype);
        static int constexpr ldg_load_width =
                sizeof(copy_t) / sizeof(filter_dtype);
        static int constexpr skew = load_width;
        static int constexpr block_tile_out_channel =
                RegBlockConfig::reg_m * ThreadConfig::nr_thread_y;
        static int constexpr block_tile_in_channel = RegBlockConfig::reg_k;

        static int constexpr smem_load_x = block_tile_out_channel / load_width;
        static int constexpr load_x = smem_load_x > 32 ? 32 : smem_load_x;
        static int constexpr load_y = ThreadConfig::nr_threads / load_x;

        static int constexpr smem_h = RegBlockConfig::reg_k_packed;
        static int constexpr smem_w = block_tile_out_channel;
        static int constexpr smem_stride =
                smem_w % 2 == 0 ? smem_w + skew : smem_w;
        static int constexpr smem_tot = smem_h * smem_stride;

        static int constexpr reg_h = (smem_h + load_y - 1) / load_y;
        static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x;

        static bool constexpr check_bounds_h = smem_h % load_y != 0;
        static bool constexpr check_bounds_w = smem_load_x % load_x != 0;
    };

    using BlockTileIterator =
            BlockTileIteratorBasic<DataTileCount, FilterTileCount>;
    using DataGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxN<check_bounds, DataTileCount, InputLayout>;
    using FilterGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxN<check_bounds, FilterTileCount, KernLayout>;
    static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1;
    using BlockConsumer =
            IConvBlockConsumer<RegBlockConfig, ThreadConfig, pipelined>;
    using GlobalMemoryWriter =
            IConvGlobalMemoryWriter<RegBlockConfig, ThreadConfig>;
};

template <bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
          typename ThreadConfig_>
struct IConvTraitUnrollWidth {
    COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t,
                                                Layout<Format::CHWN4>,
                                                Layout<Format::CHWN4>,
                                                Layout<Format::CHWN4>,
                                                ConvParam);
    using RegBlockConfig = RegBlockConfig_;
    using ThreadConfig = ThreadConfig_;
    struct DataTileCount {
        using RegBlockConfig = RegBlockConfig;
        using ThreadConfig = ThreadConfig;
        using copy_t = ldg_dtype;
        using smem_storage_dtype = smem_storage_dtype;
        static int constexpr load_width =
                sizeof(copy_t) / sizeof(smem_storage_dtype);
        static int constexpr ldg_load_width =
                sizeof(copy_t) / sizeof(src_dtype);
        static int constexpr skew = load_width;
        static int constexpr block_tile_batch =
                RegBlockConfig::reg_n * ThreadConfig::nr_thread_x;
        static int constexpr block_tile_out_width = RegBlockConfig::reg_width;
        static int constexpr block_tile_in_channel = RegBlockConfig::reg_k;

        static int constexpr smem_load_x = block_tile_batch / load_width;
        static int constexpr load_x = smem_load_x > 32 ? 32 : smem_load_x;
        static int constexpr load_y = ThreadConfig::nr_threads / load_x;

        static int constexpr smem_h = RegBlockConfig::reg_k_packed;
        static int constexpr smem_w = block_tile_batch;
        static int constexpr img_cache = RegBlockConfig::reg_width;
        static int constexpr smem_stride =
                smem_w % 2 == 0 ? smem_w + skew : smem_w;
        static int constexpr smem_tot = smem_h * img_cache * smem_stride;

        static int constexpr reg_h = (smem_h + load_y - 1) / load_y;
        static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x;

        static bool constexpr check_bounds_h = smem_h % load_y != 0;
        static bool constexpr check_bounds_w = smem_load_x % load_x != 0;
    };
    MEGDNN_STATIC_ASSERT(
            std::is_same<typename IConvTrait<
                            check_bounds MEGDNN_COMMA ldg_dtype MEGDNN_COMMA
                                    RegBlockConfig MEGDNN_COMMA
                                            ThreadConfig>::filter_dtype
                                 MEGDNN_COMMA filter_dtype>::value == true,
            "data type of filter tensor should be int8_t");
    using FilterTileCount =
            typename IConvTrait<check_bounds, ldg_dtype, RegBlockConfig,
                                ThreadConfig>::FilterTileCount;
    using BlockTileIterator =
            BlockTileIteratorUnrollWidth<DataTileCount, FilterTileCount>;
    using DataGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxWOxN<check_bounds, DataTileCount,
                                               InputLayout>;
    using FilterGlobal2ShareMemVisitor =
            Global2ShareMemVisitor_CIxN<check_bounds, FilterTileCount,
                                        KernLayout>;
    static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1;
    using BlockConsumer =
            IConvBlockConsumerUnrollWidth<RegBlockConfig, ThreadConfig,
                                          pipelined>;
    using GlobalMemoryWriter =
            IConvGlobalMemoryWriterUnrollWidth<RegBlockConfig, ThreadConfig>;
};

#undef COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM
#undef MEGDNN_COMMA

}  // namespace convolution
}  // namespace cuda
}  // namespace megdnn

// vim: ft=cpp syntax=cuda.doxygen foldmethod=marker foldmarker=f{{{,f}}}