/** * \file dnn/test/common/mesh_indexing.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "megdnn/basic_types.h" #include "megdnn/oprs/general.h" #include "rng.h" #include "test/common/indexing_multi_axis_vec.h" #include "test/common/opr_proxy.h" namespace megdnn { namespace test { #define MESH_INDEXING_LIKE_OPR_PROXY(__opr) \ template <> \ struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \ using OprProxyIndexingMultiAxisVecHelper:: \ OprProxyIndexingMultiAxisVecHelper; \ void exec(__opr* opr, const TensorNDArray& tensors) const { \ WorkspaceWrapper W(opr->handle(), opr->get_workspace_in_bytes( \ tensors[1].layout, axes, \ tensors.size() - 2)); \ opr->exec(tensors[0], make_index_desc(tensors), tensors[1], \ W.workspace()); \ } \ void deduce_layout(__opr* opr, TensorLayoutArray& layouts) { \ MEGDNN_MARK_USED_VAR(opr); \ MEGDNN_MARK_USED_VAR(layouts); \ opr->deduce_layout(layouts[0], make_index_layout(layouts), \ layouts[1]); \ } \ }; #define MESH_MODIFY_LIKE_OPR_PROXY(__opr) \ template <> \ struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \ using OprProxyIndexingMultiAxisVecHelper:: \ OprProxyIndexingMultiAxisVecHelper; \ void exec(__opr* opr, const TensorNDArray& tensors) const { \ WorkspaceWrapper W(opr->handle(), opr->get_workspace_in_bytes( \ tensors[1].layout, axes, \ tensors.size() - 2)); \ opr->exec(tensors[0], tensors[1], make_index_desc(tensors), \ W.workspace()); \ } \ void deduce_layout(__opr*, TensorLayoutArray&) {} \ }; MESH_INDEXING_LIKE_OPR_PROXY(MeshIndexing); MESH_INDEXING_LIKE_OPR_PROXY(BatchedMeshIndexing); MESH_MODIFY_LIKE_OPR_PROXY(IncrMeshIndexing); MESH_MODIFY_LIKE_OPR_PROXY(BatchedIncrMeshIndexing); MESH_MODIFY_LIKE_OPR_PROXY(SetMeshIndexing); MESH_MODIFY_LIKE_OPR_PROXY(BatchedSetMeshIndexing); #undef MESH_PROXY_COMMON #undef MESH_INDEXING_LIKE_OPR_PROXY #undef MESH_MODIFY_LIKE_OPR_PROXY namespace mesh_indexing { class NoReplacementIndexRNG final : public RNG { size_t& m_size; std::mt19937_64 m_rng; public: NoReplacementIndexRNG(size_t& sz, size_t seed) : m_size{sz}, m_rng(seed) {} void gen(const TensorND& tensor) override { std::vector seq; for (size_t i = 0; i < m_size; ++i) { seq.push_back(i); } size_t stride = static_cast(tensor.layout.stride[0]); size_t size = tensor.layout[0]; if (tensor.layout.ndim == 1) { stride = tensor.layout[0]; size = 1; } megdnn_assert(stride <= m_size); auto ptr = tensor.ptr(); for (size_t n = 0; n < size; ++n) { std::set used; COMPAT_RANDOM(seq.begin(), seq.end()); for (size_t step = 0; step < stride; ++step) { megdnn_assert(used.size() < m_size); ptr[n * stride + step] = seq[step]; used.insert(seq[step]); } } } }; } // namespace mesh_indexing } // namespace test } // namespace megdnn