/**
 * \file dnn/test/common/mesh_indexing.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <algorithm>
#include <numeric>

#include "megdnn/basic_types.h"
#include "megdnn/oprs/general.h"
#include "rng.h"
#include "test/common/indexing_multi_axis_vec.h"
#include "test/common/opr_proxy.h"

namespace megdnn {
namespace test {

// Generates the OprProxy specialization for an indexing-style operator
// `__opr`, whose exec is called as (tensors[0], index_desc, tensors[1], ws).
//
// Tensor positions as used below:
//   tensors[0]    -> first exec argument (presumably the data being indexed)
//   tensors[1]    -> third exec argument (presumably the output); its layout
//                    also sizes the workspace
//   tensors[2...] -> the remaining tensors; their count (tensors.size() - 2)
//                    is passed as the axis count, and make_index_desc()
//                    builds the index descriptor from `tensors`
// `axes`, make_index_desc() and make_index_layout() are not declared here;
// presumably they come from OprProxyIndexingMultiAxisVecHelper — confirm.
#define MESH_INDEXING_LIKE_OPR_PROXY(__opr)                                    \
    template <>                                                                \
    struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper {       \
        using OprProxyIndexingMultiAxisVecHelper::                             \
                OprProxyIndexingMultiAxisVecHelper;                            \
        void exec(__opr* opr, const TensorNDArray& tensors) const {            \
            WorkspaceWrapper W(opr->handle(), opr->get_workspace_in_bytes(     \
                                                      tensors[1].layout, axes, \
                                                      tensors.size() - 2));    \
            opr->exec(tensors[0], make_index_desc(tensors), tensors[1],        \
                      W.workspace());                                          \
        }                                                                      \
        void deduce_layout(__opr* opr, TensorLayoutArray& layouts) {           \
            opr->deduce_layout(layouts[0], make_index_layout(layouts),         \
                               layouts[1]);                                    \
        }                                                                      \
    };

// Generates the OprProxy specialization for a modify-style operator `__opr`,
// whose exec is called as (tensors[0], tensors[1], index_desc, ws).
//
// Tensor positions as used below:
//   tensors[0]    -> first exec argument (presumably the tensor modified)
//   tensors[1]    -> second exec argument (presumably the value tensor); its
//                    layout also sizes the workspace
//   tensors[2...] -> the remaining tensors; their count (tensors.size() - 2)
//                    is passed as the axis count to the workspace query
// deduce_layout is intentionally a no-op: the layouts are not deduced here.
#define MESH_MODIFY_LIKE_OPR_PROXY(__opr)                                      \
    template <>                                                                \
    struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper {       \
        using OprProxyIndexingMultiAxisVecHelper::                             \
                OprProxyIndexingMultiAxisVecHelper;                            \
        void deduce_layout(__opr*, TensorLayoutArray&) {}                      \
        void exec(__opr* opr, const TensorNDArray& tensors) const {            \
            const size_t nr_index = tensors.size() - 2;                        \
            const auto ws_bytes = opr->get_workspace_in_bytes(                 \
                    tensors[1].layout, axes, nr_index);                        \
            WorkspaceWrapper ws_holder(opr->handle(), ws_bytes);               \
            opr->exec(tensors[0], tensors[1], make_index_desc(tensors),        \
                      ws_holder.workspace());                                  \
        }                                                                      \
    };

// Indexing-style operators: exec(tensors[0], index_desc, tensors[1], ws).
MESH_INDEXING_LIKE_OPR_PROXY(MeshIndexing);
MESH_INDEXING_LIKE_OPR_PROXY(BatchedMeshIndexing);

// Modify-style operators: exec(tensors[0], tensors[1], index_desc, ws).
MESH_MODIFY_LIKE_OPR_PROXY(SetMeshIndexing);
MESH_MODIFY_LIKE_OPR_PROXY(IncrMeshIndexing);
MESH_MODIFY_LIKE_OPR_PROXY(BatchedSetMeshIndexing);
MESH_MODIFY_LIKE_OPR_PROXY(BatchedIncrMeshIndexing);

// Drop the generator macros so they do not leak into files including this
// header. NOTE(review): MESH_PROXY_COMMON is not defined in this header; the
// #undef looks stale — confirm no earlier include defines it before removing.
#undef MESH_MODIFY_LIKE_OPR_PROXY
#undef MESH_INDEXING_LIKE_OPR_PROXY
#undef MESH_PROXY_COMMON

namespace mesh_indexing {
/*!
 * \brief RNG that fills an `int` index tensor with indices in [0, m_size)
 *      drawn without replacement.
 *
 * For a 1-D tensor the whole tensor is a single draw. Otherwise the tensor
 * is treated as layout[0] rows of stride[0] elements each (this assumes a
 * contiguous layout), and each row is an independent draw: indices are
 * distinct within a row but may repeat across rows.
 *
 * The upper bound is held by reference, so the caller may change it between
 * calls to gen(). The constructor's seed drives the shuffles.
 */
class NoReplacementIndexRNG final : public RNG {
    size_t& m_size;
    std::mt19937_64 m_rng;

public:
    NoReplacementIndexRNG(size_t& sz, size_t seed) : m_size{sz}, m_rng(seed) {}

    void gen(const TensorND& tensor) override {
        // seq is a permutation of [0, m_size). After shuffling, any prefix
        // of it consists of distinct indices, so no duplicate check is needed.
        std::vector<int> seq(m_size);
        std::iota(seq.begin(), seq.end(), 0);

        size_t stride = static_cast<size_t>(tensor.layout.stride[0]);
        size_t size = tensor.layout[0];
        if (tensor.layout.ndim == 1) {
            stride = tensor.layout[0];
            size = 1;
        }
        megdnn_assert(stride <= m_size);

        auto ptr = tensor.ptr<int>();
        for (size_t n = 0; n < size; ++n) {
            // Use the seeded member engine; previously m_rng was never used,
            // so the seed passed to the constructor had no effect.
            std::shuffle(seq.begin(), seq.end(), m_rng);
            std::copy_n(seq.begin(), stride, ptr + n * stride);
        }
    }
};
}  // namespace mesh_indexing

}  // namespace test
}  // namespace megdnn