/**
 * \file dnn/src/cuda/relayout/param_visitor.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/basic_types.h"
#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/utils.cuh"

#pragma once

namespace megdnn {
namespace cuda {
#define devfunc __device__ __forceinline__

/*!
 * \brief contiguous type
 * If the layout is fully contiguous, the type is CONTIG_FULL; otherwise it is
 * CONTIG_OTHER.
 */
enum ContigType { CONTIG_OTHER, CONTIG_FULL };

/* f{{{ ParamElemVisitor specialization */

/*!
 * \brief visitor to access an element in a tensor at given logic index
 * \tparam ctype plain element ctype (i.e. ctype in DTypeTrait)
 * \tparam contig_type contiguity classification of the tensor layout
 *
 * host interface:
 *      void host_init(
 *              const TensorND &tensor, int grid_size, int block_size)
 *
 * device interface:
 *      void thread_init(uint32_t idx)
 *          called on thread entrance, with logical indexing; the index may
 *          go beyond buffer range
 *
 *      ctype* ptr()
 *          return buffer pointer; can be used by specialized OpCaller
 *
 *      int offset(uint32_t idx)
 *          get physical offset from logical index
 *
 *      ctype& at(uint32_t idx)
 *          ptr()[offset(idx)]
 */
template <int ndim, typename ctype, ContigType contig_type>
class ParamElemVisitor;

#define PARAM_ELEM_VISITOR_COMMON_DEV                                  \
    devfunc ctype* ptr() { return m_ptr; }                             \
    devfunc ctype& at(uint32_t idx) { return m_ptr[offset(idx)]; }

//! specialization for CONTIG_OTHER
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, CONTIG_OTHER> {
    ctype* __restrict m_ptr;
    int m_stride[ndim];

    //! m_shape_highdim[i] = original_shape[i + 1]
#ifdef _MSC_VER
    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
#else
    Uint32Fastdiv m_shape_highdim[ndim - 1];
#endif

public:
    static const int NDIM = ndim;

    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    devfunc void thread_init(uint32_t) {}

    devfunc void next() {}

    devfunc int offset(uint32_t idx) {
        int offset = 0;
#pragma unroll
        for (int i = ndim - 1; i >= 1; --i) {
            Uint32Fastdiv& shp = m_shape_highdim[i - 1];
            uint32_t idx_div = idx / shp;
            offset += (idx - idx_div * shp.divisor()) * m_stride[i];
            idx = idx_div;
        }
        offset += idx * m_stride[0];
        return offset;
    }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};

//! specialization for CONTIG_FULL
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, CONTIG_FULL> {
    ctype* __restrict m_ptr;

public:
    static const int NDIM = ndim;

    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    devfunc void thread_init(uint32_t) {}

    devfunc void next() {}

    devfunc int offset(uint32_t idx) { return idx; }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};

#undef PARAM_ELEM_VISITOR_COMMON_DEV
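/*!
 * Illustrative sketch (not part of the original header): one way a
 * relayout-style copy kernel could drive the visitor interface documented
 * above. The kernel name `copy_by_visitor_example` and its parameter names
 * are hypothetical; only ParamElemVisitor, thread_init(), at() and
 * host_init() come from this file.
 *
 * \code
 * template <int ndim, typename ctype>
 * __global__ void copy_by_visitor_example(
 *         ParamElemVisitor<ndim, ctype, CONTIG_OTHER> src,
 *         ParamElemVisitor<ndim, ctype, CONTIG_FULL> dst, uint32_t size) {
 *     uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
 *     if (i < size) {
 *         src.thread_init(i);
 *         dst.thread_init(i);
 *         // at(i) maps the logical index to a physical offset per layout
 *         dst.at(i) = src.at(i);
 *     }
 * }
 * \endcode
 *
 * On the host side, each visitor would first be prepared with
 * host_init(tensor, grid_size, block_size) before the kernel launch.
 */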
//! specialization for 4-bit packed quantized data (dt_quint4), CONTIG_OTHER
template <int ndim>
class ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER> {
    using Storage = uint8_t;

    Storage* __restrict m_ptr;
    int m_stride[ndim];
    int m_shape[ndim];
    bool m_is_contiguous;
    bool m_is_physical_contiguous;

    //! m_shape_highdim[i] = original_shape[i + 1]
#ifdef _MSC_VER
    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
    Uint32Fastdiv m_align_shape_highdim[ndim > 1 ? ndim - 1 : 1];
#else
    Uint32Fastdiv m_shape_highdim[ndim];
    Uint32Fastdiv m_align_shape_highdim[ndim];
#endif

public:
    static const Storage kMask = 0xf;
    static const Storage kBits = 4;
    static const int NDIM = ndim;

    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    devfunc void thread_init(uint32_t) {}

    devfunc void next() {}

    devfunc void get_shape_from_access(uint32_t access_idx, int (&shape_idx)[ndim]) {
#pragma unroll
        for (int i = ndim - 1; i >= 1; --i) {
            Uint32Fastdiv& align_shp = m_align_shape_highdim[i - 1];
            uint32_t access_idx_div = access_idx / align_shp;
            shape_idx[i] = access_idx - access_idx_div * align_shp.divisor();
            access_idx = access_idx_div;
        }
        shape_idx[0] = access_idx;
    }

    devfunc int offset(uint32_t idx) {
        int offset = 0;
#pragma unroll
        for (int i = ndim - 1; i >= 1; --i) {
            Uint32Fastdiv& shp = m_shape_highdim[i - 1];
            uint32_t idx_div = idx / shp;
            offset += (idx - idx_div * shp.divisor()) * m_stride[i];
            idx = idx_div;
        }
        offset += idx * m_stride[0];
        return offset;
    }

    devfunc int offset_from_access(uint32_t access_idx) {
        int offset = 0;
        if (m_is_contiguous) {
            offset = access_idx;
        } else {
            int shape_idx[ndim];
            get_shape_from_access(access_idx, shape_idx);
#pragma unroll
            for (int i = ndim - 1; i >= 0; --i) {
                offset += shape_idx[i] * m_stride[i];
            }
        }
        return offset;
    }

    devfunc int idx(uint32_t access_idx) {
        int idx = 0;
        if (m_is_physical_contiguous) {
            idx = access_idx;
        } else {
            int shape_idx[ndim];
            bool valid = true;
            get_shape_from_access(access_idx, shape_idx);
#pragma unroll
            for (int i = 0; i < ndim; ++i) {
                valid &= (shape_idx[i] < m_shape[i]);
            }
            for (int i = 0; i < ndim - 1; ++i) {
                idx = (idx + shape_idx[i]) * m_shape[i + 1];
            }
            idx = valid ? idx + shape_idx[ndim - 1] : -1;
        }
        return idx;
    }

    devfunc Storage* ptr() { return m_ptr; }

    devfunc Storage at(uint32_t idx) {
        int offset_ = offset(idx);
        //! two 4-bit elements are packed per byte: pick the byte and the nibble
        int vec_idx = offset_ >> 1;
        int lane_idx = offset_ & 0x1;

        Storage item = Storage(integer_subbyte::unpack_integer_4bits<false>(
                *(Storage*)&m_ptr[vec_idx], lane_idx * 4));
        return item;
    }

    using rwtype = typename elemwise_intl::VectTypeTrait<dt_quint4>::vect_type;

    devfunc rwtype make_vector(Storage x, Storage y) {
        return elemwise_intl::VectTypeTrait<dt_quint4>::make_vector(x, y);
    }
#endif
};

/* f}}} */

}  // namespace cuda
}  // namespace megdnn

// vim: ft=cpp syntax=cpp.doxygen