relayout.cpp 3.8 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/common/relayout.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megdnn/oprs.h"
#include "src/common/relayout_helper.h"
#include "src/common/utils.h"

#include <algorithm>

using namespace megdnn;
using namespace megdnn::relayout;

namespace {

//! whether current shape is [b][n][m][c] and is a transpose of contig
//! [b][m][n][c]
26 27
bool is_transpose_single(
        const TensorLayout& layout, TransposeParam& p, bool allow_no_contig) {
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
    /*
     * assuming contig layout is:
     *  shape: b, m, n, c
     *  stride: mnc, nc, c, 1
     *
     * then given layout should be:
     *  shape: b, n, m, c
     *  stride: mnc, c, nc, 1
     *
     * if c == 1:
     *  shape: b, n, m
     *  stride: mn, 1, n
     * if b == 1:
     *  shape: n, m, c
     *  stride: c, nc, 1
     *
     * if b == 1 && c == 1:
     *  shape: n, m
46
     *  stride: 1, n(stride_m for no-contig)
47
     */
48
    p.stride_m = 0;
M
Megvii Engine Team 已提交
49
    auto strd = [&](size_t idx, ptrdiff_t v) { return layout.stride[idx] == v; };
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
    if (layout.ndim == 4) {
        p.batch = layout[0];
        p.n = layout[1];
        p.m = layout[2];
        p.c = layout[3];
        if (strd(3, 1) && strd(1, p.c)) {
            auto t = p.c * p.n;
            return strd(2, t) && strd(0, t * p.m);
        }
        return false;
    }
    if (layout.ndim == 3) {
        if (strd(1, 1)) {
            // c == 1
            p.batch = layout[0];
            p.n = layout[1];
            p.m = layout[2];
            p.c = 1;
            return strd(2, p.n) && strd(0, p.m * p.n);
        }
        if (strd(2, 1)) {
            // b == 1
            p.batch = 1;
            p.n = layout[0];
            p.m = layout[1];
            p.c = layout[2];
            return strd(0, p.c) && strd(1, p.n * p.c);
        }
        return false;
    }
    if (layout.ndim == 2) {
        p.batch = 1;
        p.n = layout.shape[0];
        p.m = layout.shape[1];
        p.c = 1;
85 86 87 88 89 90 91 92 93
        if (strd(0, 1) && strd(1, p.n)) {
            return true;
        } else if (
                strd(0, 1) && layout.stride[1] > 0 &&
                (size_t)(layout.stride[1]) >= p.n && allow_no_contig) {
            //! stride_m used in no-contig mode, stride_m >= p.n
            p.stride_m = layout.stride[1];
            return true;
        }
94 95 96 97 98 99
    }
    return false;
}

}  // anonymous namespace

M
Megvii Engine Team 已提交
100
void RelayoutForward::check_layout_and_canonize(TensorLayout& src, TensorLayout& dst) {
101 102 103
    megdnn_assert(dst.is_non_overlapping_strong());
    src = src.collapse_contiguous();
    dst = dst.collapse_contiguous();
M
Megvii Engine Team 已提交
104 105 106 107
    megdnn_assert(
            src.dtype == dst.dtype && src.total_nr_elems() == dst.total_nr_elems(),
            "check %s == %s and %zu == %zu", src.dtype.name(), dst.dtype.name(),
            src.total_nr_elems(), dst.total_nr_elems());
108 109
}

M
Megvii Engine Team 已提交
110
bool relayout::is_transpose(
111 112 113
        const TensorLayout& src, const TensorLayout& dst, TransposeParam& p,
        bool allow_non_contig) {
    if (is_contig(dst) && is_transpose_single(src, p, allow_non_contig)) {
114 115 116 117 118 119
        // if the original intention is to transpose (m, n) to (n, m),
        // then we should use (n, m) as the contig dst and use a corrsponding
        // non-contig src with the same (n, m) shape (remember relayout is
        // defined on element correspondence on the logical view)
        return true;
    }
120
    if (is_contig(src) && is_transpose_single(dst, p, allow_non_contig)) {
121 122 123 124 125 126 127
        std::swap(p.m, p.n);
        return true;
    }
    return false;
}

// vim: syntax=cpp.doxygen