/**
 * \file dnn/src/common/param_pack.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megdnn/oprs/general.h"
#include "src/common/utils.h"

#include <limits>

using namespace megdnn;

void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated,
18
                                          const TensorLayout& offsets,
19
                                          const TensorLayout& parts) {
20 21 22 23
    megdnn_assert(offsets.dtype == dtype::Int32{}, "bad dtype: %s",
                  offsets.dtype.name());
    megdnn_assert(concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 &&
                          concated.stride[0] == 1 && offsets.stride[0] == 1 &&
24
                          parts.stride[0] == 1,
25 26
                  "bad layout: concated=%s offsets=%s parts=%s",
                  concated.to_string().c_str(), offsets.to_string().c_str(),
27 28 29
                  parts.to_string().c_str());
}

30
std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
        const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) {
    megdnn_assert(alignment && (alignment & (alignment - 1)) == 0,
                  "alignment must be power of 2: %zu", alignment);
    if (alignment < dtype_size)
        alignment = dtype_size;

    megdnn_assert(alignment % dtype_size == 0,
                  "alignment must be multiple of dtype size: %zu vs %zu",
                  alignment, dtype_size);
    alignment /= dtype_size;

    auto get_aligned = [alignment](size_t v) {
        auto mod = v & (alignment - 1);
        return v + ((alignment - mod) & (alignment - 1));
    };

47
    std::vector<dt_int32> offsets(shapes.size() << 1);
48
    size_t offset = 0;
49
    for (size_t i = 0; i < shapes.size(); i++) {
50
        offset = get_aligned(offset);
51
        offsets[i << 1] = offset;
52
        offset += shapes[i].total_nr_elems();
53
        offsets[(i << 1) + 1] = offset;
54
    }
55
    return offsets;
56 57 58
}

// vim: syntax=cpp.doxygen