/**
 * \file dnn/test/common/mesh_indexing.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megdnn/basic_types.h"
#include "megdnn/oprs/general.h"
#include "rng.h"
#include "test/common/indexing_multi_axis_vec.h"
#include "test/common/opr_proxy.h"

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

namespace megdnn {
namespace test {

/*
 * Defines the OprProxy specialization for an indexing-like mesh operator OPR,
 * whose exec() is called as exec(tensors[0], index_desc, tensors[1], workspace).
 *
 * - tensors[0] and tensors[1] are the two data tensors (presumably src and
 *   dst -- confirm against the operator definition); the workspace size is
 *   queried from tensors[1].layout.
 * - get_workspace_in_bytes() is told there are tensors.size() - 2 indexed
 *   axes, i.e. every tensor after the first two is presumably an index tensor
 *   consumed by make_index_desc().
 * - deduce_layout() forwards layouts[0] / layouts[1] together with the index
 *   layout built by make_index_layout().
 *
 * Comments are kept outside the #define: a `//` comment on a
 * backslash-continued line would swallow the following line.
 */
#define MESH_INDEXING_LIKE_OPR_PROXY(OPR)                                             \
    template <>                                                                       \
    struct OprProxy<OPR> : public OprProxyIndexingMultiAxisVecHelper {                \
        using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; \
        void exec(OPR* opr, const TensorNDArray& tensors) const {                     \
            WorkspaceWrapper W(                                                       \
                    opr->handle(),                                                    \
                    opr->get_workspace_in_bytes(                                      \
                            tensors[1].layout, axes, tensors.size() - 2, 1));         \
            opr->exec(                                                                \
                    tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); \
        }                                                                             \
        void deduce_layout(OPR* opr, TensorLayoutArray& layouts) {                    \
            opr->deduce_layout(layouts[0], make_index_layout(layouts), layouts[1]);   \
        }                                                                             \
    };

/*
 * Defines the OprProxy specialization for a modify-like mesh operator OPR
 * (e.g. IncrMeshIndexing / SetMeshIndexing), whose exec() is called as
 * exec(tensors[0], tensors[1], index_desc, workspace).
 *
 * - The workspace size is queried from tensors[1].layout with
 *   tensors.size() - 2 indexed axes; every tensor after the first two is
 *   presumably an index tensor consumed by make_index_desc().
 * - deduce_layout() is intentionally a no-op, so callers presumably have to
 *   supply fully specified layouts -- confirm with the checker usage.
 *
 * Comments are kept outside the #define: a `//` comment on a
 * backslash-continued line would swallow the following line.
 */
#define MESH_MODIFY_LIKE_OPR_PROXY(OPR)                                               \
    template <>                                                                       \
    struct OprProxy<OPR> : public OprProxyIndexingMultiAxisVecHelper {                \
        using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; \
        void exec(OPR* opr, const TensorNDArray& tensors) const {                     \
            WorkspaceWrapper W(                                                       \
                    opr->handle(),                                                    \
                    opr->get_workspace_in_bytes(                                      \
                            tensors[1].layout, axes, tensors.size() - 2, 1));         \
            opr->exec(                                                                \
                    tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); \
        }                                                                             \
        void deduce_layout(OPR*, TensorLayoutArray&) {}                               \
    };

// OprProxy specializations for the plain and batched mesh indexing operators:
// the indexing variants read through an index descriptor, the Incr/Set
// variants modify tensors[0] using tensors[1] and the index descriptor.
MESH_INDEXING_LIKE_OPR_PROXY(MeshIndexing);
MESH_INDEXING_LIKE_OPR_PROXY(BatchedMeshIndexing);
MESH_MODIFY_LIKE_OPR_PROXY(IncrMeshIndexing);
MESH_MODIFY_LIKE_OPR_PROXY(BatchedIncrMeshIndexing);
MESH_MODIFY_LIKE_OPR_PROXY(SetMeshIndexing);
MESH_MODIFY_LIKE_OPR_PROXY(BatchedSetMeshIndexing);

// The generator macros are local to this header; undefine them so they do
// not leak into files that include it.
#undef MESH_INDEXING_LIKE_OPR_PROXY
#undef MESH_MODIFY_LIKE_OPR_PROXY

namespace mesh_indexing {
/**
 * RNG that fills an index tensor with indices drawn without replacement.
 *
 * gen() writes nr_rows rows of row_len consecutive ints, where
 * nr_rows = layout[0] and row_len = layout.stride[0]; a 1-D tensor is
 * treated as one row of layout[0] ints. Each row is a prefix of a fresh
 * random permutation of [0, m_size), so the values within a row are distinct
 * and in range. Values are written through tensor.ptr<int>(), so the tensor
 * is presumably Int32 -- confirm at the call site.
 *
 * m_size is held by reference: the referenced variable must outlive this RNG,
 * and changes to it affect subsequent gen() calls.
 */
class NoReplacementIndexRNG final : public RNG {
    size_t& m_size;
    std::mt19937_64 m_rng;

public:
    /**
     * \param sz   variable holding the exclusive upper bound of generated
     *             indices; captured by reference
     * \param seed seed of the engine used to shuffle the candidate indices
     */
    NoReplacementIndexRNG(size_t& sz, size_t seed) : m_size{sz}, m_rng(seed) {}

    void gen(const TensorND& tensor) override {
        // Candidate indices 0, 1, ..., m_size - 1.
        std::vector<int> seq(m_size);
        std::iota(seq.begin(), seq.end(), 0);

        size_t row_len = static_cast<size_t>(tensor.layout.stride[0]);
        size_t nr_rows = tensor.layout[0];
        if (tensor.layout.ndim == 1) {
            row_len = tensor.layout[0];
            nr_rows = 1;
        }
        // A row cannot hold more distinct indices than there are candidates.
        megdnn_assert(row_len <= m_size);

        auto ptr = tensor.ptr<int>();
        for (size_t n = 0; n < nr_rows; ++n) {
            // Shuffle with the seeded member engine so the output is
            // reproducible from `seed` and every row gets its own permutation;
            // distinctness within a row follows from seq being a permutation.
            std::shuffle(seq.begin(), seq.end(), m_rng);
            std::copy_n(seq.begin(), row_len, ptr + n * row_len);
        }
    }
};
}  // namespace mesh_indexing

}  // namespace test
}  // namespace megdnn