/**
 * \file dnn/src/common/relayout_helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

#include "megdnn/oprs.h"
#include "src/common/utils.h"

#include "midout.h"

//! midout tag referenced by the MIDOUT_BEGIN in transpose_fallback::transpose
//! below (NOTE(review): presumably lets midout trim unused kernels from the
//! binary -- confirm against the midout documentation)
MIDOUT_DECL(transpose_fallback)

namespace megdnn {
namespace relayout {

//! true iff \p layout is one-dimensional with unit stride
static inline bool is_contig(const TensorLayout& layout) {
    if (layout.ndim != 1) {
        return false;
    }
    return layout.stride[0] == 1;
}

/*!
 * \brief extents of a batched transpose [b][m][n][c] to [b][n][m][c]
 *
 * The two middle axes are swapped; the outermost and innermost axes keep
 * their positions.
 */
struct TransposeParam {
    size_t batch;  //!< b: outermost axis, position unchanged
    size_t m;      //!< m: moves from the second to the third position
    size_t n;      //!< n: moves from the third to the second position
    size_t c;      //!< c: innermost axis, position unchanged
};

/*!
 * \brief check whether relayout \p src -> \p dst can be described by a
 *        TransposeParam, i.e. as [b][m][n][c] -> [b][n][m][c]
 *
 * Both layouts must already have been processed by
 * RelayoutForward::check_layout_and_canonize.
 *
 * \param src source layout
 * \param dst destination layout
 * \param[out] p the transpose extents (NOTE(review): presumably only
 *             meaningful when the function returns true -- confirm in the
 *             definition)
 * \return whether the relayout is such a transpose
 */
bool is_transpose(const TensorLayout& src,
                  const TensorLayout& dst,
                  TransposeParam& p);

namespace transpose_fallback {

//! Edge length, in bytes, of the square tiles used by transpose_fallback;
//! transpose_traits divides it by sizeof(T) to get the edge in elements.
//! (NOTE(review): the per-arch values are presumably chosen to match the
//! target's cache line size -- confirm before tuning.)
#if MEGDNN_X86
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64;
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 /*BEGIN-INLINE-INTERNAL*/ || \
        MEGDNN_MIPS /*END-INLINE-INTERNAL*/
constexpr size_t BLOCK_LINE_SIZE_BYTES = 32;
#elif MEGDNN_RISCV64
//! ref U54-MC arch
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64;
#else
#error "unknown megdnn arch"
#endif

/**
 * \brief transpose traits
 * \tparam T element type
 *
 * block_size is the edge length, in elements, of the square tiles used by
 * transpose_block and transpose: BLOCK_LINE_SIZE_BYTES / sizeof(T), clamped
 * to at least 1. Without the clamp, an element type wider than
 * BLOCK_LINE_SIZE_BYTES would yield 0, giving a zero-sized tile buffer in
 * transpose_block_fallback and tiling loops in transpose that never advance.
 * This template can be specialized to implement optimized block transpose.
 */
template <typename T>
struct transpose_traits {
    static constexpr size_t block_size =
            BLOCK_LINE_SIZE_BYTES / sizeof(T) > 0
                    ? BLOCK_LINE_SIZE_BYTES / sizeof(T)
                    : 1;
};

template <typename T>
void transpose_block_fallback(const T* src, T* dst, const size_t src_stride,
                              const size_t dst_stride, size_t block_h,
                              size_t block_w) {
    constexpr size_t block_size = transpose_traits<T>::block_size;
    T block[block_size][block_size];

    for (size_t i = 0; i < block_h; ++i) {
        auto src_ptr = src + i * src_stride;
        for (size_t j = 0; j < block_w; ++j) {
            block[j][i] = src_ptr[j];
        }
    }
    for (size_t i = 0; i < block_w; ++i) {
        auto dst_ptr = dst + i * dst_stride;
        for (size_t j = 0; j < block_h; ++j) {
            dst_ptr[j] = block[i][j];
        }
    }
}

template <typename T>
void transpose_block(const T* src, T* dst, const size_t src_stride,
                     const size_t dst_stride, size_t block_h, size_t block_w) {
    transpose_block_fallback(src, dst, src_stride, dst_stride, block_h,
                             block_w);
}

/*!
 * \brief transpose a single block whose size is transpose_traits<T>::block_size
 *
 * This function and transpose_traits can be specialized to implement optimized
 * block transpose
 */
template <typename T>
void transpose_block(const T* src, T* dst, const size_t src_stride,
                     const size_t dst_stride) {
    constexpr size_t block_size = transpose_traits<T>::block_size;
    transpose_block_fallback(src, dst, src_stride, dst_stride, block_size,
                             block_size);
}

/*!
 * \brief transpose contiguous (batch, m, n) to (batch, n, m)
 */
template <typename T>
void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) {
    auto batch_src = src;
    auto batch_dst = dst;
    constexpr size_t B = transpose_traits<T>::block_size;

    auto work_block = [m, n, &batch_src, &batch_dst](
                              const size_t i, const size_t j, const size_t h,
                              const size_t w) {
        auto src = batch_src + i * n + j, dst = batch_dst + j * m + i;
120 121 122 123 124 125
        MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) {
            if (h == B && w == B) {
                transpose_block(src, dst, n, m);
            } else {
                transpose_block(src, dst, n, m, h, w);
            }
126
        }
127
        MIDOUT_END();
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
    };
    auto work_row = [&work_block, n](size_t i, size_t h) {
        size_t j = 0;
        for (; j + B <= n; j += B) {
            work_block(i, j, h, B);
        }
        if (j < n) {
            work_block(i, j, h, n - j);
        }
    };

    for (size_t b = 0; b < batch; ++b) {
        size_t i = 0;
        for (; i + B <= m; i += B) {
            work_row(i, B);
        }
        if (i < m) {
            work_row(i, m - i);
        }
        batch_src += m * n;
        batch_dst += m * n;
    }
}
}  // namespace transpose_fallback

}  // namespace relayout
}  // namespace megdnn

// vim: syntax=cpp.doxygen