operation_table.h
/***************************************************************************************************
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <unordered_map>

#include "src/common/hash_ct.h"
#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/util.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

/////////////////////////////////////////////////////////////////////////////////////////////////

class Hash {
public:
    Hash() : m_val(0) {}

    Hash& update(const void* ptr, size_t len) {
        m_val += megdnn::XXHash64CT::hash((const char*)ptr, len, 123456);
        return *this;
    }

    uint64_t digest() const { return m_val; }

private:
    uint64_t m_val;
};
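
// Usage sketch (illustrative only, not part of this header): Hash sums one
// XXHash64CT digest per update() call, so a composite key can be hashed by
// chaining update() over its members, e.g.
//
//     int m = 128, n = 256;
//     uint64_t digest = Hash()
//                               .update(&m, sizeof(m))
//                               .update(&n, sizeof(n))
//                               .digest();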

/////////////////////////////////////////////////////////////////////////////////////////////////
//                          Data Structures for GemmOperationMap
/////////////////////////////////////////////////////////////////////////////////////////////////

struct GemmKey {
    NumericTypeID element_A;
    LayoutTypeID layout_A;
    NumericTypeID element_B;
    LayoutTypeID layout_B;
    NumericTypeID element_C;
    LayoutTypeID layout_C;
    NumericTypeID element_accumulator;

    int threadblock_shape_m;
    int threadblock_shape_n;
    int threadblock_shape_k;

    int warp_shape_m;
    int warp_shape_n;
    int warp_shape_k;

    int instruction_shape_m;
    int instruction_shape_n;
    int instruction_shape_k;

    int stages;
    int alignment_A;
    int alignment_B;
    SplitKMode split_k_mode;

    inline bool operator==(GemmKey const& rhs) const {
        return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) &&
               (element_B == rhs.element_B) && (layout_B == rhs.layout_B) &&
               (element_C == rhs.element_C) && (layout_C == rhs.layout_C) &&
               (element_accumulator == rhs.element_accumulator) &&
               (threadblock_shape_m == rhs.threadblock_shape_m) &&
               (threadblock_shape_n == rhs.threadblock_shape_n) &&
               (threadblock_shape_k == rhs.threadblock_shape_k) &&
               (warp_shape_m == rhs.warp_shape_m) &&
               (warp_shape_n == rhs.warp_shape_n) &&
               (warp_shape_k == rhs.warp_shape_k) &&
               (instruction_shape_m == rhs.instruction_shape_m) &&
               (instruction_shape_n == rhs.instruction_shape_n) &&
               (instruction_shape_k == rhs.instruction_shape_k) &&
               (stages == rhs.stages) && (alignment_A == rhs.alignment_A) &&
               (alignment_B == rhs.alignment_B) && (split_k_mode == rhs.split_k_mode);
    }

    inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); }

    inline std::string str() const {
        auto tuple_to_str = [](int m, int n, int k) -> std::string {
            return std::to_string(m) + " x " + std::to_string(n) + " x " +
                   std::to_string(k);
        };

        std::string threadblock_shape_str = tuple_to_str(
                threadblock_shape_m, threadblock_shape_n, threadblock_shape_k);
        std::string warp_shape_str =
                tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k);
        std::string instruction_shape_str = tuple_to_str(
                instruction_shape_m, instruction_shape_n, instruction_shape_k);

        return std::string("{") + "\n    element_A: " + to_string(element_A) +
               "\n    layout_A: " + to_string(layout_A) +
               "\n    element_B: " + to_string(element_B) +
               "\n    layout_B: " + to_string(layout_B) +
               "\n    element_C: " + to_string(element_C) +
               "\n    layout_C: " + to_string(layout_C) +
               "\n    element_accumulator: " + to_string(element_accumulator) +
               "\n    threadblock_shape: " + threadblock_shape_str +
               "\n    warp_shape: " + warp_shape_str +
               "\n    instruction_shape: " + instruction_shape_str +
               "\n    stages: " + std::to_string(stages) +
               "\n    alignment_A: " + std::to_string(alignment_A) +
               "\n    alignment_B: " + std::to_string(alignment_B) +
               "\n    split_k_mode: " + to_string(split_k_mode) + "\n}";
    }
};

struct GemmKeyHasher {
    inline size_t operator()(GemmKey const& key) const {
        return Hash()
                .update(&key.element_A, sizeof(key.element_A))
                .update(&key.layout_A, sizeof(key.layout_A))
                .update(&key.element_B, sizeof(key.element_B))
                .update(&key.layout_B, sizeof(key.layout_B))
                .update(&key.element_C, sizeof(key.element_C))
                .update(&key.layout_C, sizeof(key.layout_C))
                .update(&key.element_accumulator, sizeof(key.element_accumulator))
                .update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m))
                .update(&key.threadblock_shape_n, sizeof(key.threadblock_shape_n))
                .update(&key.threadblock_shape_k, sizeof(key.threadblock_shape_k))
                .update(&key.warp_shape_m, sizeof(key.warp_shape_m))
                .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
                .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
                .update(&key.stages, sizeof(key.stages))
                .update(&key.alignment_A, sizeof(key.alignment_A))
                .update(&key.alignment_B, sizeof(key.alignment_B))
                .update(&key.split_k_mode, sizeof(key.split_k_mode))
                .digest();
    }
};

using GemmOperationMap =
        std::unordered_map<GemmKey, std::vector<Operation const*>, GemmKeyHasher>;
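
// Usage sketch (hypothetical values): GemmKeyHasher lets GemmKey serve as the
// key of an unordered_map, so every GEMM kernel variant that shares the same
// type/layout/tile configuration is grouped under a single entry.
//
//     GemmOperationMap ops;
//     GemmKey key{};  // value-initialize, then fill in the relevant fields
//     key.element_A = NumericTypeID::kF32;
//     key.layout_A = LayoutTypeID::kRowMajor;
//     key.element_B = NumericTypeID::kF32;
//     key.layout_B = LayoutTypeID::kColumnMajor;
//     key.threadblock_shape_m = 128;
//     key.threadblock_shape_n = 128;
//     key.threadblock_shape_k = 32;
//     ops[key].push_back(op);   // op: Operation const*, e.g. taken from a Manifest
//     auto it = ops.find(key);  // lookup runs GemmKeyHasher, then operator==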

/////////////////////////////////////////////////////////////////////////////////////////////////
//                          Data Structures for ConvolutionOperationMap
/////////////////////////////////////////////////////////////////////////////////////////////////

struct ConvolutionKey {
    conv::Operator conv_op;

    library::NumericTypeID element_src;
    library::LayoutTypeID layout_src;
    library::NumericTypeID element_filter;
    library::LayoutTypeID layout_filter;
    library::NumericTypeID element_dst;
    library::LayoutTypeID layout_dst;
    library::NumericTypeID element_bias;
    library::LayoutTypeID layout_bias;
    NumericTypeID element_accumulator;

    conv::ConvType convolution_type;

    int threadblock_shape_m;
    int threadblock_shape_n;
    int threadblock_shape_k;

    int warp_shape_m;
    int warp_shape_n;
    int warp_shape_k;

    int instruction_shape_m;
    int instruction_shape_n;
    int instruction_shape_k;

    epilogue::EpilogueType epilogue_type;
    int stages;
    conv::SpecialOptimizeDesc special_optimization;

    int alignment_src;
    int alignment_filter;

    bool without_shared_load;

    // only used by rrconv
    library::NumericTypeID element_rin = library::NumericTypeID::kInvalid;
    library::LayoutTypeID layout_rin = library::LayoutTypeID::kInvalid;
    library::NumericTypeID element_rout = library::NumericTypeID::kInvalid;
    library::LayoutTypeID layout_rout = library::LayoutTypeID::kInvalid;

    inline bool operator==(ConvolutionKey const& rhs) const {
        return (conv_op == rhs.conv_op) && (element_src == rhs.element_src) &&
               (layout_src == rhs.layout_src) &&
               (element_filter == rhs.element_filter) &&
               (layout_filter == rhs.layout_filter) &&
               (element_dst == rhs.element_dst) && (layout_dst == rhs.layout_dst) &&
               (element_bias == rhs.element_bias) && (layout_bias == rhs.layout_bias) &&
               (element_accumulator == rhs.element_accumulator) &&
               (convolution_type == rhs.convolution_type) &&
               (threadblock_shape_m == rhs.threadblock_shape_m) &&
               (threadblock_shape_n == rhs.threadblock_shape_n) &&
               (threadblock_shape_k == rhs.threadblock_shape_k) &&
               (warp_shape_m == rhs.warp_shape_m) &&
               (warp_shape_n == rhs.warp_shape_n) &&
               (warp_shape_k == rhs.warp_shape_k) &&
               (instruction_shape_m == rhs.instruction_shape_m) &&
               (instruction_shape_n == rhs.instruction_shape_n) &&
               (instruction_shape_k == rhs.instruction_shape_k) &&
               (epilogue_type == rhs.epilogue_type) && (stages == rhs.stages) &&
               (special_optimization == rhs.special_optimization) &&
               (alignment_src == rhs.alignment_src) &&
               (alignment_filter == rhs.alignment_filter) &&
               (without_shared_load == rhs.without_shared_load) &&
               (element_rin == rhs.element_rin) && (layout_rin == rhs.layout_rin) &&
               (element_rout == rhs.element_rout) && (layout_rout == rhs.layout_rout);
    }

    inline bool operator!=(ConvolutionKey const& rhs) const { return !(*this == rhs); }

    inline std::string str() const {
        auto tuple_to_str = [](int m, int n, int k) -> std::string {
            return std::to_string(m) + " x " + std::to_string(n) + " x " +
                   std::to_string(k);
        };

        std::string threadblock_shape_str = tuple_to_str(
                threadblock_shape_m, threadblock_shape_n, threadblock_shape_k);
        std::string warp_shape_str =
                tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k);
        std::string instruction_shape_str = tuple_to_str(
                instruction_shape_m, instruction_shape_n, instruction_shape_k);

        return std::string("{") + "\n    conv_op: " + to_string(conv_op) +
               "\n    element_src: " + to_string(element_src) +
               "\n    layout_src: " + to_string(layout_src) +
               "\n    element_filter: " + to_string(element_filter) +
               "\n    layout_filter: " + to_string(layout_filter) +
               "\n    element_dst: " + to_string(element_dst) +
               "\n    layout_dst: " + to_string(layout_dst) +
               "\n    element_bias: " + to_string(element_bias) +
               "\n    layout_bias: " + to_string(layout_bias) +
               "\n    element_accumulator: " + to_string(element_accumulator) +
               "\n    convolution_type: " + to_string(convolution_type) +
               "\n    threadblock_shape: " + threadblock_shape_str +
               "\n    warp_shape: " + warp_shape_str +
               "\n    instruction_shape: " + instruction_shape_str +
               "\n    epilogue_type: " + to_string(epilogue_type) +
               "\n    stages: " + std::to_string(stages) +
               "\n    special_optimization: " + to_string(special_optimization) +
               "\n    alignment_src: " + std::to_string(alignment_src) +
               "\n    alignment_filter: " + std::to_string(alignment_filter) +
               "\n    without_shared_load: " + to_string(without_shared_load) +
               "\n    element_rin: " + to_string(element_rin) +
               "\n    layout_rin: " + to_string(layout_rin) +
               "\n    element_rout: " + to_string(element_rout) +
               "\n    layout_rout: " + to_string(layout_rout) + "\n}";
    }
};
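
// Usage sketch (field values are assumptions, not prescribed by this header):
// an ordinary convolution key can leave the rrconv-only fields at their
// kInvalid defaults and still compare equal to keys built the same way.
//
//     ConvolutionKey key{};
//     key.conv_op = conv::Operator::kFprop;
//     key.element_src = NumericTypeID::kS8;
//     key.layout_src = LayoutTypeID::kTensorNCHW;
//     key.convolution_type = conv::ConvType::kConvolution;
//     // element_rin / layout_rin / element_rout / layout_rout stay kInvalid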

struct ConvolutionKeyHasher {
    inline size_t operator()(ConvolutionKey const& key) const {
        return Hash()
                .update(&key.conv_op, sizeof(key.conv_op))
                .update(&key.element_src, sizeof(key.element_src))
                .update(&key.layout_src, sizeof(key.layout_src))
                .update(&key.element_filter, sizeof(key.element_filter))
                .update(&key.layout_filter, sizeof(key.layout_filter))
                .update(&key.element_dst, sizeof(key.element_dst))
                .update(&key.layout_dst, sizeof(key.layout_dst))
                .update(&key.element_bias, sizeof(key.element_bias))
                .update(&key.layout_bias, sizeof(key.layout_bias))
                .update(&key.element_accumulator, sizeof(key.element_accumulator))
                .update(&key.convolution_type, sizeof(key.convolution_type))
                .update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m))
                .update(&key.threadblock_shape_n, sizeof(key.threadblock_shape_n))
                .update(&key.threadblock_shape_k, sizeof(key.threadblock_shape_k))
                .update(&key.warp_shape_m, sizeof(key.warp_shape_m))
                .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
                .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
                .update(&key.instruction_shape_m, sizeof(key.instruction_shape_m))
                .update(&key.instruction_shape_n, sizeof(key.instruction_shape_n))
                .update(&key.instruction_shape_k, sizeof(key.instruction_shape_k))
                .update(&key.epilogue_type, sizeof(key.epilogue_type))
                .update(&key.stages, sizeof(key.stages))
                .update(&key.special_optimization, sizeof(key.special_optimization))
                .update(&key.alignment_src, sizeof(key.alignment_src))
                .update(&key.alignment_filter, sizeof(key.alignment_filter))
                .update(&key.without_shared_load, sizeof(key.without_shared_load))
                .update(&key.element_rin, sizeof(key.element_rin))
                .update(&key.layout_rin, sizeof(key.layout_rin))
                .update(&key.element_rout, sizeof(key.element_rout))
                .update(&key.layout_rout, sizeof(key.layout_rout))
                .digest();
    }
};

using ConvolutionOperationMap = std::unordered_map<
        ConvolutionKey, std::vector<Operation const*>, ConvolutionKeyHasher>;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Table of cutlass::library::Operation instances
class OperationTable {
public:
    /// Map of all operations of type kGemm
    GemmOperationMap gemm_operations;

    /// Map of all operations of type kConvolution
    ConvolutionOperationMap convolution_operations;

public:
    void append(Manifest const& manifest);

    Operation const* find_op(GemmKey const& key) const;

    Operation const* find_op(ConvolutionKey const& key) const;
};
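
// Usage sketch (assumes a Manifest named `manifest` and a fully specified key
// built elsewhere): the table is populated once from the manifest, then queried
// with either a GemmKey or a ConvolutionKey.
//
//     OperationTable table;
//     table.append(manifest);                    // bucket operations by key
//     Operation const* op = table.find_op(key);  // overload picked by key type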

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace library
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////