/**
 * \file dnn/src/common/relayout.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megdnn/oprs.h"
#include "src/common/relayout_helper.h"
#include "src/common/utils.h"

#include <algorithm>

using namespace megdnn;
using namespace megdnn::relayout;

namespace {

/*!
 * \brief Test whether \p layout is the [b][n][m][c] transposed view of a
 *      contiguous [b][m][n][c] tensor, and record b/n/m/c into \p p.
 *
 * With the contiguous tensor having shape (b, m, n, c) and strides
 * (m*n*c, n*c, c, 1), the accepted layouts are:
 *
 *   ndim | layout shape | required strides     | implied
 *   -----+--------------+----------------------+------------------
 *     4  | (b, n, m, c) | (m*n*c, c, n*c, 1)   |
 *     3  | (b, n, m)    | (m*n, 1, n)          | c == 1
 *     3  | (n, m, c)    | (c, n*c, 1)          | b == 1
 *     2  | (n, m)       | (1, n)               | b == 1, c == 1
 *
 * \param layout the layout to inspect
 * \param[out] p receives batch, n, m and c. NOTE: fields of \p p may be
 *      overwritten even when false is returned; treat them as garbage then.
 * \return true iff the strides of \p layout match the row for its ndim
 */
bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) {
    auto stride_is = [&layout](size_t axis, ptrdiff_t expected) {
        return layout.stride[axis] == expected;
    };

    switch (layout.ndim) {
        case 4: {
            p.batch = layout[0];
            p.n = layout[1];
            p.m = layout[2];
            p.c = layout[3];
            // innermost axis must be dense and the n axis must step by c
            if (!stride_is(3, 1) || !stride_is(1, p.c))
                return false;
            auto cn = p.c * p.n;
            return stride_is(2, cn) && stride_is(0, cn * p.m);
        }
        case 3: {
            if (stride_is(1, 1)) {
                // channel axis dropped (c == 1): shape is (b, n, m)
                p.batch = layout[0];
                p.n = layout[1];
                p.m = layout[2];
                p.c = 1;
                return stride_is(2, p.n) && stride_is(0, p.m * p.n);
            }
            if (!stride_is(2, 1))
                return false;
            // batch axis dropped (b == 1): shape is (n, m, c)
            p.batch = 1;
            p.n = layout[0];
            p.m = layout[1];
            p.c = layout[2];
            return stride_is(0, p.c) && stride_is(1, p.n * p.c);
        }
        case 2: {
            // both batch and channel dropped: plain (n, m) matrix transpose
            p.batch = 1;
            p.n = layout.shape[0];
            p.m = layout.shape[1];
            p.c = 1;
            return stride_is(0, 1) && stride_is(1, p.n);
        }
        default:
            return false;
    }
}

}  // anonymous namespace

/*!
 * \brief Validate a relayout (src, dst) pair and replace both layouts with
 *      their collapsed (canonical) form.
 *
 * \param[in,out] src source layout; overwritten with
 *      src.collapse_contiguous()
 * \param[in,out] dst destination layout; must satisfy
 *      is_non_overlapping_strong() on entry, then overwritten with
 *      dst.collapse_contiguous()
 *
 * After collapsing, megdnn_assert checks that both layouts have the same
 * dtype and the same total number of elements.
 */
void RelayoutForward::check_layout_and_canonize(TensorLayout& src,
                                                TensorLayout& dst) {
    // Only dst is required to be non-overlapping, and it is checked before
    // canonization. Presumably overlapping writes would be ill-defined while
    // src may legitimately overlap (e.g. broadcast) -- TODO confirm.
    megdnn_assert(dst.is_non_overlapping_strong());
    src = src.collapse_contiguous();
    dst = dst.collapse_contiguous();
    megdnn_assert(src.dtype == dst.dtype &&
                          src.total_nr_elems() == dst.total_nr_elems(),
                  "check %s == %s and %zu == %zu", src.dtype.name(),
                  dst.dtype.name(), src.total_nr_elems(), dst.total_nr_elems());
}

/*!
 * \brief Decide whether relayouting \p src into \p dst is a transpose that
 *      is_transpose_single() can describe, and fill \p p accordingly.
 *
 * Two orientations are tried, in this order:
 *  1. \p dst is contiguous and \p src is a transposed view: \p p is taken
 *     from \p src unchanged.
 *  2. \p src is contiguous and \p dst is a transposed view: \p p is taken
 *     from \p dst and then its m and n are exchanged. Presumably this puts
 *     \p p in the same convention as case 1 -- TODO confirm against
 *     TransposeParam consumers.
 *
 * \param[out] p may be written even when false is returned.
 */
bool relayout::is_transpose(const TensorLayout& src, const TensorLayout& dst,
                            TransposeParam& p) {
    // To express "transpose (m, n) into (n, m)", use the contiguous (n, m)
    // layout as dst and a matching non-contiguous src with the same logical
    // (n, m) shape. This works because relayout is defined by element
    // correspondence on the logical view.
    const bool src_is_view_of_contig_dst =
            is_contig(dst) && is_transpose_single(src, p);
    if (src_is_view_of_contig_dst)
        return true;

    if (!is_contig(src) || !is_transpose_single(dst, p))
        return false;

    std::swap(p.m, p.n);
    return true;
}

// vim: syntax=cpp.doxygen