/** * \file dnn/src/cuda/relayout/param_visitor.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "src/common/utils.h" #include "src/cuda/query_blocksize.cuh" #include "src/cuda/utils.h" #include "src/cuda/relayout/kern.cuh" #include "src/cuda/relayout/kern_contiguous.cuh" namespace megdnn { namespace cuda { #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Warray-bounds" template void ParamElemVisitor::host_init( const TensorND& rv, int /*grid_size*/, int /*block_size*/) { megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim); m_ptr = rv.ptr(); for (size_t i = 0; i < rv.layout.ndim; ++i) { m_stride[i] = rv.layout.stride[i]; if (i + 1 < rv.layout.ndim) m_shape_highdim[i] = rv.layout.shape[i + 1]; } for (int i = rv.layout.ndim - 1; i < ndim - 1; ++i) { m_shape_highdim[i] = 1; } for (int i = rv.layout.ndim; i < ndim; ++i) { m_stride[i] = 0; } } #pragma GCC diagnostic pop template void ParamElemVisitor::host_init(const TensorND &rv, int /*grid_size*/, int /*block_size*/) { megdnn_assert_contiguous(rv.layout); m_ptr = rv.ptr(); } #define INST(ndim, ctype, ctg) template class ParamElemVisitor #define INST_FOR_CTYPE MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) #define ndim_cb(_ndim) \ INST(_ndim, ct, CONTIG_OTHER); \ INST(_ndim, ct, CONTIG_FULL); #define ct dt_byte INST_FOR_CTYPE #undef ct #define ct dt_int32 INST_FOR_CTYPE #undef ct #define ct dt_float16 INST_FOR_CTYPE #undef ct #undef ndim_cb #undef INST_FOR_CTYPE #undef INST template void ParamElemVisitor::host_init( const TensorND& rv, int /*grid_size*/, int /*block_size*/) { megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim); m_ptr = reinterpret_cast(rv.raw_ptr); auto min_stride = rv.layout.stride[0]; for (size_t i = 0; i < rv.layout.ndim; ++i) { m_stride[i] = rv.layout.stride[i]; m_shape[i] = rv.layout.shape[i]; if (i + 1 < rv.layout.ndim) { m_shape_highdim[i] = rv.layout.shape[i + 1]; if (rv.layout.stride[i + 1] == 1) m_align_shape_highdim[i] = (uint32_t)round_up((int)rv.layout.shape[i + 1], 2); else m_align_shape_highdim[i] = rv.layout.shape[i + 1]; } if (min_stride > rv.layout.stride[i]) { min_stride = rv.layout.stride[i]; } } megdnn_assert(min_stride == 1 || min_stride == 2); m_is_min_stride_2 = (min_stride == 2); for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) { m_shape_highdim[i] = 1; m_align_shape_highdim[i] = 1; } for (size_t i = rv.layout.ndim; i < ndim; ++i) { m_stride[i] = 0; m_shape[i] = 1; } m_is_physical_contiguous = rv.layout.is_physical_contiguous(); m_is_contiguous = rv.layout.is_contiguous(); } #define INST(ndim, ctg) template class ParamElemVisitor #define ndim_cb(_ndim) INST(_ndim, CONTIG_OTHER); MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) #undef ndim_cb #undef INST } // namespace cuda } // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen