relayout_helper.h 4.4 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/common/relayout_helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
6 7 8
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
9 10
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
11 12 13 14 15 16
 */
#pragma once

#include "megdnn/oprs.h"
#include "src/common/utils.h"

17 18 19 20
#include "midout.h"

//! midout tag referenced by MIDOUT_BEGIN(transpose_fallback, ...) in
//! transpose_fallback::transpose() below
MIDOUT_DECL(transpose_fallback)

21 22 23 24 25 26 27 28 29
namespace megdnn {
namespace relayout {

//! true iff \p layout has exactly one dimension and that dimension has unit
//! stride
static inline bool is_contig(const TensorLayout& layout) {
    if (layout.ndim != 1) {
        return false;
    }
    return layout.stride[0] == 1;
}

//! [b][m][n][c] to [b][n][m][c]
struct TransposeParam {
    //! number of [m][n][c] sub-tensors (the b extent)
    size_t batch;
    //! extent of the m axis (rows before the transpose)
    size_t m;
    //! extent of the n axis (columns before the transpose)
    size_t n;
    //! extent of the innermost c axis, which keeps its position
    size_t c;
    //! NOTE(review): presumably the distance between consecutive m-rows of the
    //! source, as consumed by transpose_fallback::transpose -- confirm the
    //! unit against is_transpose()
    size_t stride_m;
};

/**
 * \brief whether the relayout can be formulated as TransposeParam
 *
 * Note that \p src and \p dst should have been processed by
 * RelayoutForward::check_layout_and_canonize
 */
39 40 41
// Declaration only; no definition is visible in this header.
// NOTE(review): \p p is presumably filled in when the result is true --
// confirm against the definition.
// \p allow_non_contig defaults to false. NOTE(review): it presumably admits a
// source whose row stride differs from n (see TransposeParam::stride_m) --
// confirm against the definition.
bool is_transpose(
        const TensorLayout& src, const TensorLayout& dst, TransposeParam& p,
        bool allow_non_contig = false);
42 43 44

namespace transpose_fallback {

M
Megvii Engine Team 已提交
45
//! Per-architecture byte size of one block line; transpose_traits divides it
//! by sizeof(T) to obtain the edge length of a transpose block.
//! Exactly one branch must be selected by the build, otherwise compilation
//! stops with an error.
#if MEGDNN_X86 || MEGDNN_NAIVE
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64;
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
constexpr size_t BLOCK_LINE_SIZE_BYTES = 32;
#elif MEGDNN_RISCV64
//! ref U54-MC arch
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64;
#else
#error "unknown megdnn arch"
#endif

/**
 * \brief transpose traits
 * \tparam T element type
 *
 * block_size is the edge length, in elements, of the square block handled by
 * one transpose_block call. It equals BLOCK_LINE_SIZE_BYTES / sizeof(T),
 * clamped to at least 1: an element type wider than one block line would
 * otherwise give 0, i.e. a zero-length staging array and a tiling loop that
 * never advances.
 */
template <typename T>
struct transpose_traits {
    static constexpr size_t block_size =
            (BLOCK_LINE_SIZE_BYTES / sizeof(T)) > 0
                    ? (BLOCK_LINE_SIZE_BYTES / sizeof(T))
                    : 1;
};

template <typename T>
M
Megvii Engine Team 已提交
66 67 68
void transpose_block_fallback(
        const T* src, T* dst, const size_t src_stride, const size_t dst_stride,
        size_t block_h, size_t block_w) {
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
    constexpr size_t block_size = transpose_traits<T>::block_size;
    T block[block_size][block_size];

    for (size_t i = 0; i < block_h; ++i) {
        auto src_ptr = src + i * src_stride;
        for (size_t j = 0; j < block_w; ++j) {
            block[j][i] = src_ptr[j];
        }
    }
    for (size_t i = 0; i < block_w; ++i) {
        auto dst_ptr = dst + i * dst_stride;
        for (size_t j = 0; j < block_h; ++j) {
            dst_ptr[j] = block[i][j];
        }
    }
}

/*!
 * \brief transpose a partial block of \p block_h rows by \p block_w columns
 *
 * Forwards unchanged to transpose_block_fallback; strides are in elements.
 */
template <typename T>
void transpose_block(
        const T* src, T* dst, const size_t src_stride, const size_t dst_stride,
        size_t block_h, size_t block_w) {
    transpose_block_fallback(src, dst, src_stride, dst_stride, block_h, block_w);
}

/*!
 * \brief transpose a single block whose size is transpose_traits<T>::block_size
 *
 * This function and transpose_traits can be specialized to implement optimized
 * block transpose
 */
template <typename T>
M
Megvii Engine Team 已提交
100 101
void transpose_block(
        const T* src, T* dst, const size_t src_stride, const size_t dst_stride) {
102
    constexpr size_t block_size = transpose_traits<T>::block_size;
M
Megvii Engine Team 已提交
103
    transpose_block_fallback(src, dst, src_stride, dst_stride, block_size, block_size);
104 105 106 107 108 109
}

/*!
 * \brief transpose contiguous (batch, m, n) to (batch, n, m)
 */
template <typename T>
110 111 112 113
void transpose(size_t batch, size_t m, size_t n, T* src, T* dst, size_t stride_m = 0) {
    if (stride_m == 0) {
        stride_m = n;
    }
114 115 116 117
    auto batch_src = src;
    auto batch_dst = dst;
    constexpr size_t B = transpose_traits<T>::block_size;

118
    auto work_block = [m, stride_m, &batch_src, &batch_dst](
119 120
                              const size_t i, const size_t j, const size_t h,
                              const size_t w) {
121
        auto src = batch_src + i * stride_m + j, dst = batch_dst + j * m + i;
122 123
        MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) {
            if (h == B && w == B) {
124
                transpose_block(src, dst, stride_m, m);
125
            } else {
126
                transpose_block(src, dst, stride_m, m, h, w);
127
            }
128
        }
129
        MIDOUT_END();
130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
    };
    auto work_row = [&work_block, n](size_t i, size_t h) {
        size_t j = 0;
        for (; j + B <= n; j += B) {
            work_block(i, j, h, B);
        }
        if (j < n) {
            work_block(i, j, h, n - j);
        }
    };

    for (size_t b = 0; b < batch; ++b) {
        size_t i = 0;
        for (; i + B <= m; i += B) {
            work_row(i, B);
        }
        if (i < m) {
            work_row(i, m - i);
        }
149
        batch_src += m * stride_m;
150 151 152 153 154 155 156 157 158
        batch_dst += m * n;
    }
}
}  // namespace transpose_fallback

}  // namespace relayout
}  // namespace megdnn

// vim: syntax=cpp.doxygen