relayout_helper.h 4.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
/**
 * \file dnn/src/common/relayout_helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
namespace relayout {

//! true iff \p layout has exactly one dimension and that dimension has stride 1
static inline bool is_contig(const TensorLayout& layout) {
    // stride[0] is only inspected once the single-dimension check has passed
    if (layout.ndim != 1) {
        return false;
    }
    return layout.stride[0] == 1;
}

//! describes a relayout from [b][m][n][c] to [b][n][m][c]
struct TransposeParam {
    size_t batch;  //!< extent of the [b] axis
    size_t m;      //!< extent of the [m] axis
    size_t n;      //!< extent of the [n] axis
    size_t c;      //!< extent of the [c] axis
};

/**
 * \brief whether the relayout can be formulated as TransposeParam
 *
 * Note that \p src and \p dst should have been processed by
 * RelayoutForward::check_layout_and_canonize
 *
 * \param src layout of the relayout source
 * \param dst layout of the relayout destination
 * \param p taken by non-const reference; presumably receives the matching
 *      batch/m/n/c values when the result is true (the definition is not
 *      visible in this header -- TODO confirm)
 * \return true if the relayout can be expressed as a TransposeParam
 */
bool is_transpose(const TensorLayout& src, const TensorLayout& dst,
                  TransposeParam& p);

namespace transpose_fallback {

//! byte width from which transpose_traits<T>::block_size is derived
//! (presumably the cache line size of the target -- TODO confirm).
//! Only x86 builds are handled here; any other arch fails at compile time.
#if MEGDNN_X86
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64;
#else
#error "unknown megdnn arch"
#endif

/**
 * \brief transpose traits
 * \tparam T element type
 *
 * The generic definition only exposes the edge length of the square block
 * handled by the block transpose routines.
 */
template <typename T>
struct transpose_traits {
    //! block edge length in elements: BLOCK_LINE_SIZE_BYTES / sizeof(T)
    //! (integer division)
    static constexpr size_t block_size = BLOCK_LINE_SIZE_BYTES / sizeof(T);
};

template <typename T>
void transpose_block_fallback(const T* src, T* dst, const size_t src_stride,
                              const size_t dst_stride, size_t block_h,
                              size_t block_w) {
    constexpr size_t block_size = transpose_traits<T>::block_size;
    T block[block_size][block_size];

    for (size_t i = 0; i < block_h; ++i) {
        auto src_ptr = src + i * src_stride;
        for (size_t j = 0; j < block_w; ++j) {
            block[j][i] = src_ptr[j];
        }
    }
    for (size_t i = 0; i < block_w; ++i) {
        auto dst_ptr = dst + i * dst_stride;
        for (size_t j = 0; j < block_h; ++j) {
            dst_ptr[j] = block[i][j];
        }
    }
}

template <typename T>
void transpose_block(const T* src, T* dst, const size_t src_stride,
                     const size_t dst_stride, size_t block_h, size_t block_w) {
    transpose_block_fallback(src, dst, src_stride, dst_stride, block_h,
                             block_w);
}

/*!
 * \brief transpose a single block whose size is transpose_traits<T>::block_size
 *
 * This function and transpose_traits can be specialized to implement optimized
 * block transpose
 */
template <typename T>
void transpose_block(const T* src, T* dst, const size_t src_stride,
                     const size_t dst_stride) {
    constexpr size_t block_size = transpose_traits<T>::block_size;
    transpose_block_fallback(src, dst, src_stride, dst_stride, block_size,
                             block_size);
}

/*!
 * \brief transpose contiguous (batch, m, n) to (batch, n, m)
 *
 * Every batch is walked in tiles of at most block_size x block_size elements,
 * row stripes first and column tiles within each stripe. Full tiles use the
 * four-argument transpose_block; tiles clipped by the matrix edge use the
 * overload that takes an explicit height and width.
 */
template <typename T>
void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) {
    constexpr size_t TILE = transpose_traits<T>::block_size;
    const size_t batch_elems = m * n;

    for (size_t b = 0; b < batch; ++b) {
        T* cur_src = src + b * batch_elems;
        T* cur_dst = dst + b * batch_elems;

        for (size_t row = 0; row < m; row += TILE) {
            // the last stripe may be shorter than a full tile
            const size_t tile_h = (m - row < TILE) ? (m - row) : TILE;

            for (size_t col = 0; col < n; col += TILE) {
                const size_t tile_w = (n - col < TILE) ? (n - col) : TILE;

                // source rows are n apart, destination rows are m apart
                T* tile_src = cur_src + row * n + col;
                T* tile_dst = cur_dst + col * m + row;

                if (tile_h != TILE || tile_w != TILE) {
                    transpose_block(tile_src, tile_dst, n, m, tile_h, tile_w);
                } else {
                    transpose_block(tile_src, tile_dst, n, m);
                }
            }
        }
    }
}
}  // namespace transpose_fallback

}  // namespace relayout
}  // namespace megdnn

// vim: syntax=cpp.doxygen