diff --git a/dnn/include/megdnn/oprs/imgproc.h b/dnn/include/megdnn/oprs/imgproc.h index 9f951753b999274e9c37e080407880399adf779d..0a8aea551358d864017cfef936089139cd6b149c 100644 --- a/dnn/include/megdnn/oprs/imgproc.h +++ b/dnn/include/megdnn/oprs/imgproc.h @@ -182,6 +182,48 @@ class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { size_t workspace_in_bytes); }; +class DctChannelSelectForward : public OperatorBase { + DEF_OPR_PARAM(DctChannelSelect); + DEF_OPR_IMPL(DctChannelSelectForward, OperatorBase, 3, 1); + +public: + /** + * \param[in] src DctChannelSelectForward input, must be uint8 nchw tensor + * \param[in] mask_offset input, int32 tensor, empty or one-dimensional + * \param[in] mask_val input, int32 tensor, empty or one-dimensional + * \param[out] dst DctChannelSelectForward output, default fp32 nchw tensor + * \param[out] workspace temporary workspace to perform forward + */ + virtual void exec(_megdnn_tensor_in src, + _megdnn_tensor_in mask_offset, + _megdnn_tensor_in mask_val, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + + void deduce_layout(const TensorLayout& src, + const TensorLayout& mask_offset, + const TensorLayout& mask_val, + TensorLayout& dst); + + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& mask_offset, + const TensorLayout& mask_val, + const TensorLayout& dst) = 0; + +protected: + void check_layout_fwd(const TensorLayout& src, + const TensorLayout& mask_offset, + const TensorLayout& mask_val, + const TensorLayout& dst); + + void deduce_layout_fwd(const TensorLayout& src, + const TensorLayout& mask_offset, + const TensorLayout& mask_val, + TensorLayout& dst); + + std::string param_msg() const; +}; + } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 1b16b5dbb76eb5a9faf8fdce54a949183cef4da6..6379eba1626891917ee898c600bc2369e822487c 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py 
@@ -411,6 +411,9 @@ pdef('ElemwiseMultiType').add_enum( pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) +(pdef('DctChannelSelect', '2d discrete cosine transform').add_enum_alias('Format', 'ConvolutionV0'). + add_enum('FastImpl', 'NONE', 'FIX_32_MASK').add_fields('int32', 'dct_block_size', 8)) + (pdef('MatrixMul', version=0, is_legacy=True). add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). add_enum('DataType', diff --git a/dnn/src/common/dct.cpp b/dnn/src/common/dct.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8441115b48cc552c657010cde56fe5aa2b4fae8c --- /dev/null +++ b/dnn/src/common/dct.cpp @@ -0,0 +1,95 @@ +/** + * \file dnn/src/common/dct.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "megdnn/oprs.h" + +#include "src/common/utils.h" + +namespace megdnn { + +/* + * Deduce dst from src and the masks: oc is mask_val.shape[0] when mask_offset + * holds at least 2 entries, otherwise ic * dct_block * dct_block; oh and ow + * are ih and iw divided by dct_block. NCHW4 stores oc as oc / 4 groups of 4. + */ +void DctChannelSelectForward::deduce_layout_fwd(const TensorLayout& src, + const TensorLayout& mask_offset, + const TensorLayout& mask_val, + TensorLayout& dst) { + const size_t dct_block = param().dct_block_size; + const size_t in = src.shape[0]; + const size_t ic = src.shape[1]; + const size_t ih = src.shape[2]; + const size_t iw = src.shape[3]; + check_layout_fwd(src, mask_offset, mask_val, dst); + const size_t oh = ih / dct_block; + const size_t ow = iw / dct_block; + //! mask will be empty or (ic + 1) elements + size_t oc = mask_offset.ndim > 0 && mask_offset[0] >= 2 + ? 
mask_val.shape[0] + : ic * dct_block * dct_block; + if (param().fastImpl == Param::FastImpl::FIX_32_MASK) { + megdnn_assert(oc == 32, + "Param::FastImpl::FIX_32_MASK oc must be 32, but %zu", + oc); + } + if (param().format == Param::Format::NCHW) { + dst = TensorLayout(TensorShape({in, oc, oh, ow}), dst.dtype); + } else { + megdnn_assert(param().format == Param::Format::NCHW4, + "dct format must be nchw or nchw4"); + megdnn_assert(oc % 4 == 0, "oc mod 4 == 0 in nchw4"); + dst = TensorLayout(TensorShape({in, oc / 4, oh, ow, 4}), dst.dtype); + } +} + +/* Public entry point; forwards to deduce_layout_fwd. */ +void DctChannelSelectForward::deduce_layout(const TensorLayout& src, + const TensorLayout& mask_offset, + const TensorLayout& mask_val, + TensorLayout& dst) { + deduce_layout_fwd(src, mask_offset, mask_val, dst); +} + +/* + * Assert the forward preconditions: positive dct_block_size, mask rank, + * src / dst dtype, and ih / iw being multiples of dct_block_size. + */ +void DctChannelSelectForward::check_layout_fwd(const TensorLayout& src, + const TensorLayout& mask_offset, + const TensorLayout& mask_val, + const TensorLayout& dst) { + /* dct_block_size is the divisor of the modulo checks below: reject zero, + * and negative values that would wrap when converted to size_t */ + megdnn_assert(param().dct_block_size > 0, "dct_block_size must be > 0"); + const size_t dct_block = param().dct_block_size; + const size_t ih = src.shape[2]; + const size_t iw = src.shape[3]; + + megdnn_assert(mask_offset.ndim == 0 || (mask_offset.ndim == 1 && + (mask_offset.shape[0] == 0 || + mask_offset.shape[0] >= 2) && + mask_val.ndim == 1), + "mask only support one valid dim"); + megdnn_assert(mask_val.ndim <= 1, "only support one dim"); + megdnn_assert(src.dtype.enumv() == DTypeEnum::Uint8, + "src.dtype == dtype::Uint8"); + megdnn_assert(dst.dtype.enumv() == DTypeEnum::Float32 || + dst.dtype.enumv() == DTypeEnum::QuantizedS8, + "dst.dtype == dtype::Float32 || dst.dtype.enumv() == " + "DTypeEnum::QuantizedS8"); + megdnn_assert(ih % dct_block == 0, "ih mod dctblock == 0"); + megdnn_assert(iw % dct_block == 0, "iw mod dctblock == 0"); +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 8446997282c7b42b7856c515521803ec2dc5de8a..84df453d74c127ba9fbdd35e9f304b2929d86cae 100644 --- a/dnn/src/common/handle_impl.h +++ 
b/dnn/src/common/handle_impl.h @@ -201,6 +201,7 @@ private: cb(RemapBackwardMat) \ cb(AdaptivePoolingForward) \ cb(AdaptivePoolingBackward) \ + cb(DctChannelSelectForward) /*! * \brief specialize HandleImpl::create_operator for a single opr type; diff --git a/dnn/src/cuda/dct/dct_channel_select.cu b/dnn/src/cuda/dct/dct_channel_select.cu new file mode 100644 index 0000000000000000000000000000000000000000..17e5d5b7d3b6b146aadbf583710103b8ddf4b8c0 --- /dev/null +++ b/dnn/src/cuda/dct/dct_channel_select.cu @@ -0,0 +1,437 @@ +/** + * \file dnn/src/cuda/dct/dct_channel_select.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megcore_cdefs.h" +#include "src/cuda/dct/dct_channel_select.cuh" +#include "src/cuda/error_info.cuh" + +namespace megdnn { +namespace cuda { + +template +struct CudaPostProcess; + +template <> +struct CudaPostProcess { + CudaPostProcess(float){}; + static inline __device__ float func(float val) { return val; } +}; + +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_type_cvt; + CudaPostProcess(float scale) { m_type_cvt.inv_scale = 1.f / scale; }; + inline __device__ int8_t func(float val) { + return m_type_cvt.quantize(val).as_int8(); + } +}; + +template +struct ChannelBlockHelper; +template <> +struct ChannelBlockHelper { + static constexpr int channel_block = 4; +}; +template <> +struct ChannelBlockHelper { + static constexpr int channel_block = 1; +}; + +namespace dct { +namespace { +/* widen 8 consecutive uint8 values, fetched with one int2 load, to float */ +inline __device__ void load_row(float (&row_cache)[8], const uint8_t* src) { + int2 row = *((int2*)src); + row_cache[0] = (float)(((uchar4*)&(row.x))->x); + row_cache[1] = (float)(((uchar4*)&(row.x))->y); + row_cache[2] 
= (float)(((uchar4*)&(row.x))->z); + row_cache[3] = (float)(((uchar4*)&(row.x))->w); + row_cache[4] = (float)(((uchar4*)&(row.y))->x); + row_cache[5] = (float)(((uchar4*)&(row.y))->y); + row_cache[6] = (float)(((uchar4*)&(row.y))->z); + row_cache[7] = (float)(((uchar4*)&(row.y))->w); +} + +/* in-place 8-point 1D DCT; every output carries the rsqrt_8 factor */ +inline __device__ void fast_dct_1d_internel(float& src0, float& src1, + float& src2, float& src3, + float& src4, float& src5, + float& src6, float& src7) { + constexpr float rsqrt_8 = 0.3535533905932737f; //!< rsqrt_8 = sqrt(1 / 8) + constexpr float a = 1.387039845322148f; //!< a = sqrt2 * cos(pi * 1 / 16) + constexpr float b = 1.306562964876377f; //!< b = sqrt2 * cos(pi * 2 / 16) + constexpr float c = 1.175875602419359f; //!< c = sqrt2 * cos(pi * 3 / 16) + constexpr float d = 0.785694958387102f; //!< d = sqrt2 * cos(pi * 5 / 16) + constexpr float e = 0.541196100146197f; //!< e = sqrt2 * cos(pi * 6 / 16) + constexpr float f = 0.275899379282943f; //!< f = sqrt2 * cos(pi * 7 / 16) + + const float add_0_7 = src0 + src7; + const float add_1_6 = src1 + src6; + const float add_2_5 = src2 + src5; + const float add_3_4 = src3 + src4; + const float sub_0_7 = src0 - src7; + const float sub_6_1 = src6 - src1; + const float sub_2_5 = src2 - src5; + const float sub_4_3 = src4 - src3; + + const float add_0_7_3_4 = add_0_7 + add_3_4; + const float add_1_6_2_5 = add_1_6 + add_2_5; + const float add_0_7_sub_3_4 = add_0_7 - add_3_4; + const float add_1_6_sub_2_5 = add_1_6 - add_2_5; + + src0 = rsqrt_8 * (add_0_7_3_4 + add_1_6_2_5); + src2 = rsqrt_8 * (b * add_0_7_sub_3_4 + e * add_1_6_sub_2_5); + src4 = rsqrt_8 * (add_0_7_3_4 - add_1_6_2_5); + src6 = rsqrt_8 * (e * add_0_7_sub_3_4 - b * add_1_6_sub_2_5); + + src1 = rsqrt_8 * (a * sub_0_7 - c * sub_6_1 + d * sub_2_5 - f * sub_4_3); + src3 = rsqrt_8 * (c * sub_0_7 + f * sub_6_1 - a * sub_2_5 + d * sub_4_3); + src5 = rsqrt_8 * (d * sub_0_7 + a * sub_6_1 + f * sub_2_5 - c * sub_4_3); + src7 = rsqrt_8 * (f * sub_0_7 + d * sub_6_1 + c * sub_2_5 + a * 
sub_4_3); +} + +inline __device__ void fast_dct_1d(float (&src)[8]) { + fast_dct_1d_internel(src[0], src[1], src[2], src[3], src[4], src[5], src[6], + src[7]); +} + +inline __device__ void fast_dct_1d_col(float (&src)[8][8], const int col) { + fast_dct_1d_internel(src[0][col], src[1][col], src[2][col], src[3][col], + src[4][col], src[5][col], src[6][col], src[7][col]); +} +enum class MaskType { + NO_MASK = 0, + USER_DEFINE_MASK = 1, + FIX_32_MASK = 2, + MASK_END +}; +template +struct StoreMask; + +template +struct StoreMask { + static inline __device__ void func( + const float (&thread_cache)[dct_block][dct_block], float* dst_tid, + const int oc_stride, int channel_idx, const int* mask_offset, + const int* mask_val, CudaPostProcess& quant_param, + megcore::AsyncErrorInfo* error_info, void* error_tracker) { + __shared__ float shared[dct_block][dct_block][block_oh][block_ow]; +#pragma unroll + for (int i = 0; i < dct_block; ++i) +#pragma unroll + for (int j = 0; j < dct_block; ++j) { + shared[i][j][threadIdx.y][threadIdx.x] = thread_cache[i][j]; + } + const int store_channel_offset = mask_offset[channel_idx]; + const int nr_store_channel = + mask_offset[channel_idx + 1] - store_channel_offset; + if (nr_store_channel < 0) { + set_async_error_info(error_info, error_tracker, + "nchw sub mask len must > 0"); + } + for (int store_channel_idx = 0; store_channel_idx < nr_store_channel; + ++store_channel_idx) { + const int index = + mask_val[store_channel_offset + store_channel_idx]; + dst_tid[store_channel_idx * oc_stride] = + shared[index / dct_block][index % dct_block][threadIdx.y] + [threadIdx.x]; + } + } +}; + +template +struct StoreMask { + static inline __device__ void func( + const float (&thread_cache)[dct_block][dct_block], int8_t* dst_tid, + const int oc_stride, int channel_idx, const int* mask_offset, + const int* mask_val, CudaPostProcess& quant_param, + megcore::AsyncErrorInfo* error_info, void* error_tracker) { + //! 
nchw4 channel_block is 4 + constexpr int channel_block = + ChannelBlockHelper::channel_block; + __shared__ float shared[dct_block][dct_block][block_oh][block_ow]; +#pragma unroll + for (int i = 0; i < dct_block; ++i) +#pragma unroll + for (int j = 0; j < dct_block; ++j) { + shared[i][j][threadIdx.y][threadIdx.x] = thread_cache[i][j]; + } + const int store_channel_offset = mask_offset[channel_idx]; + const int nr_store_channel = + mask_offset[channel_idx + 1] - store_channel_offset; + if (nr_store_channel % 4 != 0 || nr_store_channel < 0) { + set_async_error_info(error_info, error_tracker, + "nchw4 sub_mask_len mod 4 should be 0 and " + "sub_mask_len must > 0"); + /* the loop below reads mask_val four entries at a time; with a bad + * length it would run past this channel's mask, so stop here */ + return; + } + for (int store_channel_idx = 0; store_channel_idx < nr_store_channel; + store_channel_idx += channel_block) { + const int index0 = + mask_val[store_channel_offset + store_channel_idx]; + const int index1 = + mask_val[store_channel_offset + store_channel_idx + 1]; + const int index2 = + mask_val[store_channel_offset + store_channel_idx + 2]; + const int index3 = + mask_val[store_channel_offset + store_channel_idx + 3]; + const int store_c4_idx = store_channel_idx / channel_block; + *(char4*)(&dst_tid[store_c4_idx * channel_block * oc_stride]) = { + quant_param.func( + shared[index0 / dct_block][index0 % dct_block] + [threadIdx.y][threadIdx.x]), + quant_param.func( + shared[index1 / dct_block][index1 % dct_block] + [threadIdx.y][threadIdx.x]), + quant_param.func( + shared[index2 / dct_block][index2 % dct_block] + [threadIdx.y][threadIdx.x]), + quant_param.func( + shared[index3 / dct_block][index3 % dct_block] + [threadIdx.y][threadIdx.x])}; + } + } +}; + +template +struct StoreMask { + static inline __device__ void func( + const float (&thread_cache)[dct_block][dct_block], + DstDtype* dst_tid, const int oc_stride, int channel_idx, + const int* mask_offset, const int* mask_val, + CudaPostProcess& quant_param, + megcore::AsyncErrorInfo* error_info, void* error_tracker) { + constexpr int channel_block = 
ChannelBlockHelper::channel_block; +#pragma unroll + for (int i = 0; i < dct_block; i++) { +#pragma unroll + for (int j = 0; j < dct_block; j++) { + dst_tid[(i * dct_block + j) / channel_block * channel_block * + oc_stride + + (i * dct_block + j) % channel_block] = + quant_param.func(thread_cache[i][j]); + } + } + } +}; + +template +struct StoreMask { + static inline __device__ void func( + const float (&thread_cache)[dct_block][dct_block], float* dst_tid, + const int oc_stride, int channel_idx, const int* mask_offset, + const int* mask_val, CudaPostProcess& quant_param, + megcore::AsyncErrorInfo* error_info, void* error_tracker) { +#define STORE(store_index, index) \ + dst_tid[store_index * oc_stride] = \ + thread_cache[index / dct_block][index % dct_block] + + STORE(0, 0); + STORE(1, 1); + STORE(2, 8); + STORE(3, 16); + STORE(4, 9); + STORE(5, 2); + STORE(6, 3); + STORE(7, 10); + + if (channel_idx == 0) { + STORE(8, 17); + STORE(9, 24); + STORE(10, 32); + STORE(11, 25); + STORE(12, 18); + STORE(13, 11); + STORE(14, 4); + STORE(15, 5); + } +#undef STORE + } +}; + +template +struct StoreMask { + static inline __device__ void func( + const float (&thread_cache)[dct_block][dct_block], int8_t* dst_tid, + const int oc_stride, int channel_idx, const int* mask_offset, + const int* mask_val, CudaPostProcess& quant_param, + megcore::AsyncErrorInfo* error_info, void* error_tracker) { +#define STORE(store_index, index0, index1, index2, index3) \ + *(char4*)(&dst_tid[store_index * oc_stride]) = { \ + quant_param.func( \ + thread_cache[index0 / dct_block][index0 % dct_block]), \ + quant_param.func( \ + thread_cache[index1 / dct_block][index1 % dct_block]), \ + quant_param.func( \ + thread_cache[index2 / dct_block][index2 % dct_block]), \ + quant_param.func( \ + thread_cache[index3 / dct_block][index3 % dct_block])} + + STORE(0, 0, 1, 8, 16); + STORE(4, 9, 2, 3, 10); + if (channel_idx == 0) { + STORE(8, 17, 24, 32, 25); + STORE(12, 18, 11, 4, 5); + } +#undef STORE + } +}; + 
+/* one thread transforms one dct_block x dct_block tile of one input plane; + * blockIdx.z runs over the n * c planes */ +template +__global__ void kern_dct(const uint8_t* src, DstDtype* dst, const int n, + const int c, const int h, const int w, const int oh, + const int ow, const int oc_stride, const int oc, + const int* mask_offset, const int* mask_val, + CudaPostProcess quant_param, + megcore::AsyncErrorInfo* error_info, + void* error_tracker) { + constexpr int block_oh = ker_block_h / dct_block; + constexpr int block_ow = ker_block_w / dct_block; + const int channel_stride = h * w; + const int oc_idx = blockIdx.z % c; + const int oh_idx = blockIdx.y * block_oh + threadIdx.y; + const int ow_idx = blockIdx.x * block_ow + threadIdx.x; + float thread_cache[dct_block][dct_block]; + const uint8_t* src_tid = + src + blockIdx.z * channel_stride + + (blockIdx.y * ker_block_h + threadIdx.y * dct_block) * w + + (blockIdx.x * ker_block_w + threadIdx.x * dct_block); + const int inner_channel_offset = + (oh_idx * ow + ow_idx) * ChannelBlockHelper::channel_block; + + DstDtype* dst_tid = + dst + blockIdx.z * channel_stride + inner_channel_offset; + if (mask_type != MaskType::NO_MASK) { + const int batch_idx = blockIdx.z / c; + const int batch_stride = oc_stride * oc; + int out_channel_offset = 0; + if (mask_type == MaskType::FIX_32_MASK) { + //! trick out_channel_offset = {0, 16, 24}[oc_idx]; oc_idx = 0, 1, 2 + out_channel_offset = 16 * oc_idx - 8 * (oc_idx >> 1); + } else { + out_channel_offset = mask_offset[oc_idx]; + } + dst_tid = dst + batch_idx * batch_stride + + out_channel_offset * oc_stride + inner_channel_offset; + } + + if (oh_idx < oh && ow_idx < ow) { + load_row(thread_cache[0], src_tid + 0 * w); + load_row(thread_cache[1], src_tid + 1 * w); + load_row(thread_cache[2], src_tid + 2 * w); + load_row(thread_cache[3], src_tid + 3 * w); + load_row(thread_cache[4], src_tid + 4 * w); + load_row(thread_cache[5], src_tid + 5 * w); + load_row(thread_cache[6], src_tid + 6 * w); + load_row(thread_cache[7], src_tid + 7 * w); + + //! 
TMP = A @ C.T + fast_dct_1d(thread_cache[0]); + fast_dct_1d(thread_cache[1]); + fast_dct_1d(thread_cache[2]); + fast_dct_1d(thread_cache[3]); + fast_dct_1d(thread_cache[4]); + fast_dct_1d(thread_cache[5]); + fast_dct_1d(thread_cache[6]); + fast_dct_1d(thread_cache[7]); + + //! TMP = C @ TMP + fast_dct_1d_col(thread_cache, 0); + fast_dct_1d_col(thread_cache, 1); + fast_dct_1d_col(thread_cache, 2); + fast_dct_1d_col(thread_cache, 3); + fast_dct_1d_col(thread_cache, 4); + fast_dct_1d_col(thread_cache, 5); + fast_dct_1d_col(thread_cache, 6); + fast_dct_1d_col(thread_cache, 7); + + StoreMask::func(thread_cache, dst_tid, oc_stride, oc_idx, + mask_offset, mask_val, quant_param, error_info, + error_tracker); + } +} + +} // namespace + +/* host launcher: selects the FIX_32_MASK, user-mask or no-mask kernel */ +template +void call_kern_dct(const uint8_t* d_src, DstDtype* d_dst, const int n, + const int c, const int h, const int w, const int oc, + bool fix_32_mask, const int* mask_offset, + const int* mask_val, cudaStream_t stream, + megcore::AsyncErrorInfo* error_info, void* error_tracker, + float scale) { + constexpr int ker_block_h = 32; + constexpr int ker_block_w = 256; + const int oh = h / dct_block; + const int ow = w / dct_block; + const int oc_stride = oh * ow; + const dim3 block_dim(DIVUP(w, ker_block_w), DIVUP(h, ker_block_h), n * c); + const dim3 thread_dim(DIVUP(ker_block_w, dct_block), + DIVUP(ker_block_h, dct_block)); + auto cuda_dtype_param = CudaPostProcess(scale); + if (fix_32_mask) { + kern_dct<<>>( + d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset, + mask_val, cuda_dtype_param, error_info, error_tracker); + } else if (mask_offset && mask_val) { + kern_dct<<>>( + d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset, + mask_val, cuda_dtype_param, error_info, error_tracker); + } else { + kern_dct + <<>>( + d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, + mask_offset, mask_val, cuda_dtype_param, error_info, + error_tracker); + } +} + +template void call_kern_dct<8, DctLayoutFormat::NCHW, float>( + const 
uint8_t* d_src, float* d_dst, const int n, const int c, + const int h, const int w, const int oc, bool fix_32_mask, + const int* mask_offset, const int* mask_val, cudaStream_t stream, + megcore::AsyncErrorInfo* error_info, void* error_tracker, float scale); + +template void call_kern_dct<8, DctLayoutFormat::NCHW4, int8_t>( + const uint8_t* d_src, int8_t* d_dst, const int n, const int c, + const int h, const int w, const int oc, bool fix_32_mask, + const int* mask_offset, const int* mask_val, cudaStream_t stream, + megcore::AsyncErrorInfo* error_info, void* error_tracker, float scale); + +} // namespace dct + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/cuda/dct/dct_channel_select.cuh b/dnn/src/cuda/dct/dct_channel_select.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ae10374b1eaa540afbed16e65651b38b124d4660 --- /dev/null +++ b/dnn/src/cuda/dct/dct_channel_select.cuh @@ -0,0 +1,38 @@ +/** + * \file dnn/src/cuda/dct/dct_channel_select.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the + "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + or + * implied. 
+ */ +#pragma once +#include +#include +#include "src/common/opr_param_defs_enumv.cuh" +#include "src/cuda/utils.cuh" + +namespace megdnn { +namespace cuda { +namespace dct { + +using DctLayoutFormat = megdnn::param_enumv::DctChannelSelect::Format; +/* host entry point of the DCT kernels; defined and explicitly instantiated in dct_channel_select.cu */ +template +void call_kern_dct(const uint8_t* d_src, DstDtype* d_dst, const int n, + const int c, const int h, const int w, const int oc, + bool fix_32_mask, const int* mask_offset, + const int* mask_val, cudaStream_t stream, + megcore::AsyncErrorInfo* error_info, void* error_tracker, + float scale = 1.f); + +} // namespace dct +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/cuda/dct/opr_impl.cpp b/dnn/src/cuda/dct/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7865e62ce7ef917a61d0e3b70dfdd3d930a3a0d1 --- /dev/null +++ b/dnn/src/cuda/dct/opr_impl.cpp @@ -0,0 +1,73 @@ +/** + * \file dnn/src/cuda/dct/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/common/utils.h" +#include "src/cuda/dct/dct_channel_select.cuh" +#include "src/cuda/dct/opr_impl.h" +#include "src/cuda/handle.h" +#include "src/cuda/utils.h" +namespace megdnn { +namespace cuda { +/* checks format / dtype / block size, then dispatches to the fp32 NCHW or qint8 NCHW4 dct::call_kern_dct; the workspace argument is unused */ +void DctChannelSelectForwardImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_in mask_offset, + _megdnn_tensor_in mask_val, + _megdnn_tensor_out dst, + _megdnn_workspace /*workspace*/) { + auto stream = cuda_stream(this->handle()); + const int in = src.layout.shape[0]; + const int ic = src.layout.shape[1]; + const int ih = src.layout.shape[2]; + const int iw = src.layout.shape[3]; + int oc = dst.layout.shape[1]; + const bool with_fix_32_mask = + param().fastImpl == Param::FastImpl::FIX_32_MASK; + if (param().format == Param::Format::NCHW4) { + megdnn_assert(dst.layout.ndim == 5 && dst.layout.shape[4] == 4, + "dst must be nchw4"); + oc = oc * 4; + } + megdnn_assert(!with_fix_32_mask || oc == 32, + "only support specify mask"); + megdnn_assert(param().dct_block_size == 8, "only support dct block = 8"); + auto error_info = + concrete_handle(this->handle())->megcore_context().error_info; + constexpr int dct_block = 8; + const int* mask_offset_ptr = nullptr; + const int* mask_val_ptr = nullptr; + if (mask_offset.layout.ndim == 1 && mask_offset.layout.shape[0] >= 2) { + mask_offset_ptr = mask_offset.ptr(); + mask_val_ptr = mask_val.ptr(); + } + if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { + megdnn_assert(param().format == Param::Format::NCHW, + "fp32 only support nchw"); + dct::call_kern_dct( + src.ptr(), dst.ptr(), in, ic, ih, iw, oc, + with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream, + error_info, m_error_tracker); + } else { + megdnn_assert(dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8, + "only support fp32 and qs8"); + megdnn_assert(param().format == Param::Format::NCHW4, + "qint8 only support nchw4"); + dct::call_kern_dct( + src.ptr(), (int8_t*)dst.raw_ptr, in, ic, ih, iw, oc, + with_fix_32_mask, 
mask_offset_ptr, mask_val_ptr, stream, + error_info, m_error_tracker, + dst.layout.dtype.param<::megdnn::dtype::QuantizedS8>().scale); + } +} + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/dct/opr_impl.h b/dnn/src/cuda/dct/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..60c899b118c94f0831e05a5aadbe208050f90d1a --- /dev/null +++ b/dnn/src/cuda/dct/opr_impl.h @@ -0,0 +1,41 @@ +/** + * \file dnn/src/cuda/dct/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace cuda { + +/* CUDA backend of DctChannelSelectForward; reports a workspace size of 0 */ +class DctChannelSelectForwardImpl : public DctChannelSelectForward { +public: + using DctChannelSelectForward::DctChannelSelectForward; + void* m_error_tracker = nullptr; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in mask_offset, + _megdnn_tensor_in mask_val, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout& /*src*/, + const TensorLayout& /*mask_offset*/, + const TensorLayout& /*mask_val*/, + const TensorLayout& /*dst*/) override { + return 0; + } + void set_error_tracker(void* tracker) override { + m_error_tracker = tracker; + } +}; + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 59262b5d1bc36639b2b489563c992b2466b0385b..e3b8ccbe6cef1a83ad9c9221ca7f9d7f9000076a 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -26,6 +26,7 @@ #include "src/cuda/convpooling/opr_impl.h" #include "src/cuda/cumsum/opr_impl.h" 
#include "src/cuda/cvt_color/opr_impl.h"
+#include "src/cuda/dct/opr_impl.h"
 #include "src/cuda/deformable_conv/opr_impl.h"
 #include "src/cuda/deformable_ps_roi_pooling/opr_impl.h"
 #include "src/cuda/dot/opr_impl.h"
diff --git a/dnn/src/naive/dct/opr_impl.cpp b/dnn/src/naive/dct/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..eb986fd56bb254eea95563f4b7443744ee1956c6
--- /dev/null
+++ b/dnn/src/naive/dct/opr_impl.cpp
@@ -0,0 +1,242 @@
+/**
+ * \file dnn/src/naive/dct/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include <cmath>
+#include "megdnn/basic_types.h"
+#include "megdnn/dtype.h"
+#include "midout.h"
+#include "src/naive/dct/opr_impl.h"
+#include "src/naive/handle.h"
+#include "src/naive/matrix_mul/matrix_mul_helper.h"
+MIDOUT_DECL(megdnn_naive_dct_fwd)
+namespace megdnn {
+namespace naive {
+
+namespace {
+
+/*!
+ * \brief Fill \p result (block x block, row-major) with the DCT-II basis:
+ *      result[i][j] = alpha(i) * cos((2j + 1) * i * pi / (2 * block)),
+ *      alpha(0) = sqrt(1 / block), alpha(i > 0) = sqrt(2 / block).
+ */
+inline void generate_c_matrix(float* result, int block) {
+    constexpr float pi = M_PI;
+    for (int i = 0; i < block; ++i) {
+        //! the scale factor depends on the row only
+        const float alpha = i == 0 ? sqrt(1.f / static_cast<float>(block))
+                                   : sqrt(2.f / static_cast<float>(block));
+        for (int j = 0; j < block; ++j) {
+            result[i * block + j] = alpha * cos((2.f * j + 1.f) * i * pi /
+                                                static_cast<float>(2 * block));
+        }
+    }
+}
+
+/*!
+ * \brief Plain triple-loop c = op(a) * op(b) with float accumulation.
+ *
+ * \param m,n,k c is m x n, the reduction length is k
+ * \param lda,ldb,ldc row strides (in elements) of a, b and c
+ * \param trans_a,trans_b read the corresponding operand transposed
+ */
+template <typename T>
+void matmul(int m, int n, int k, int lda, int ldb, int ldc, const float* a,
+            const T* b, float* c, bool trans_a, bool trans_b) {
+    for (int m_idx = 0; m_idx < m; ++m_idx) {
+        for (int n_idx = 0; n_idx < n; ++n_idx) {
+            float res = 0.f;
+            for (int k_idx = 0; k_idx < k; ++k_idx) {
+                float av = trans_a ? a[k_idx * lda + m_idx]
+                                   : a[m_idx * lda + k_idx];
+                float bv = trans_b ? b[n_idx * ldb + k_idx]
+                                   : b[k_idx * ldb + n_idx];
+                res += av * bv;
+            }
+            c[m_idx * ldc + n_idx] = res;
+        }
+    }
+}
+
+/*!
+ * \brief Expand the (offset, value) mask encoding into one vector per entry.
+ *
+ * Sub-mask i is mask_val[mask_offset[i] .. mask_offset[i + 1]). An empty
+ * result means "no mask": it is returned when mask_offset has no dimension or
+ * fewer than two entries.
+ */
+std::vector<std::vector<int>> mask_offset_to_2dmask(
+        _megdnn_tensor_in mask_offset, _megdnn_tensor_in mask_val) {
+    std::vector<std::vector<int>> mask;
+    if (mask_offset.layout.ndim > 0 && mask_offset.layout[0] >= 2) {
+        const int offset_len = mask_offset.layout.shape[0];
+        const int32_t* mask_offset_ptr = mask_offset.ptr<int32_t>();
+        const int32_t* mask_val_ptr = mask_val.ptr<int32_t>();
+        megdnn_assert(
+                mask_val.layout.shape[0] ==
+                        static_cast<size_t>(mask_offset_ptr[offset_len - 1]),
+                "check mask offset %zu != %zu", mask_val.layout.shape[0],
+                static_cast<size_t>(mask_offset_ptr[offset_len - 1]));
+
+        mask.reserve(offset_len - 1);
+        for (int offset_idx = 1; offset_idx < offset_len; ++offset_idx) {
+            const int begin = mask_offset_ptr[offset_idx - 1];
+            const int end = mask_offset_ptr[offset_idx];
+            //! a decreasing offset would describe a negative-length range
+            megdnn_assert(begin >= 0 && begin <= end,
+                          "mask offset must be non-decreasing, got %d > %d",
+                          begin, end);
+            mask.emplace_back(mask_val_ptr + begin, mask_val_ptr + end);
+        }
+    }
+    return mask;
+}
+
+//! true when \p layout has 5 dims and the innermost one is 4
+inline bool is_layout_nchw4(const TensorLayout& layout) {
+    return layout.ndim == 5 && layout[4] == 4;
+}
+
+template <DTypeEnum DEnum>
+using QuantizedCType =
+        std::enable_if_t<DTypeTrait<DEnum>::category == DTypeCategory::QUANTIZED,
+                         typename DTypeTrait<DEnum>::ctype>;
+
+inline int8_t quant_float_2_int8(float val, DType dtype) {
+    return dtype.param<::megdnn::dtype::QuantizedS8>().quantize(val).as_int8();
+}
+
+//! store one coefficient, planar layout: channel stride is img_size
+template <typename Dtype>
+inline void dct_output(Dtype* dst_ptr, const int oc_idx, const int img_size,
+                       float val, DType) {
+    dst_ptr[oc_idx * img_size] = val;
+}
+//! store one quantized coefficient, channels packed in groups of 4
+template <>
+inline void dct_output<int8_t>(int8_t* dst_ptr, const int oc_idx,
+                               const int img_size, float val, DType dtype) {
+    dst_ptr[oc_idx / 4 * 4 * img_size + oc_idx % 4] =
+            quant_float_2_int8(val, dtype);
+}
+
+//! number of channels packed together in the output layout for \p T
+template <typename T>
+struct ChannelBlock {
+    static constexpr int block = 1;
+};
+
+template <>
+struct ChannelBlock<int8_t> {
+    static constexpr int block = 4;
+};
+
+/*!
+ * \brief Reference block-wise 2D DCT with optional coefficient selection.
+ *
+ * Every block x block tile of every input channel is transformed as
+ * C * tile * C^T. Without a mask all block * block coefficients become output
+ * channels; with a mask only the coefficients listed in mask[c] are written,
+ * in that order.
+ *
+ * \param src uint8 input, indexed as (n, c, h, w)
+ * \param dst output buffer; float is written planar, int8_t packed by 4
+ * \param mask empty, or one list of coefficient indices per input channel
+ * \param dtype destination dtype, used for quantization
+ */
+template <typename Dtype>
+void naive_dct(const uint8_t* src, Dtype* dst, int n, int c, int h, int w,
+               int block, const std::vector<std::vector<int>>& mask,
+               DType dtype) {
+    constexpr int block_channel = ChannelBlock<Dtype>::block;
+    const int block_h = block;
+    const int block_w = block;
+    const bool with_mask = !mask.empty();
+    std::vector<float> c_matrix(block * block);
+    std::vector<float> tmp(block * block);
+    std::vector<float> tmp_result(block * block);
+    generate_c_matrix(&c_matrix[0], block);
+    megdnn_assert(h % block_h == 0, "h mod block_h == 0");
+    megdnn_assert(w % block_w == 0, "w mod block_w == 0");
+    const int oh = h / block_h;
+    const int ow = w / block_w;
+    const int o_img_size = oh * ow;
+    //! first output channel produced by each input channel
+    std::vector<int> mask_offset;
+    int mask_len_sum = 0;
+    if (with_mask) {
+        //! mask[c_idx] and mask_offset[c_idx] are indexed for every channel
+        megdnn_assert(mask.size() >= static_cast<size_t>(c),
+                      "mask must cover every input channel");
+        for (auto& sub_mask : mask) {
+            for (auto mask_idx : sub_mask) {
+                //! mask_idx addresses tmp_result below
+                megdnn_assert(mask_idx >= 0 && mask_idx < block_h * block_w,
+                              "mask index %d out of range", mask_idx);
+            }
+            mask_offset.push_back(mask_len_sum);
+            mask_len_sum += sub_mask.size();
+        }
+    } else {
+        for (int c_idx = 0; c_idx < c; ++c_idx) {
+            mask_offset.push_back(mask_len_sum);
+            mask_len_sum += block_h * block_w;
+        }
+    }
+    const size_t o_batch_stride = mask_len_sum * oh * ow;
+
+    for (int n_idx = 0; n_idx < n; ++n_idx) {
+        for (int c_idx = 0; c_idx < c; ++c_idx) {
+            megdnn_assert(mask_offset[c_idx] % block_channel == 0,
+                          "%d mod %d == 0", mask_offset[c_idx], block_channel);
+            const size_t src_offset =
+                    (static_cast<size_t>(n_idx) * c + c_idx) * h * w;
+            const uint8_t* src_channel = src + src_offset;
+            const size_t dst_offset = n_idx * o_batch_stride +
+                                      mask_offset[c_idx] / block_channel * oh *
+                                              ow * block_channel;
+            Dtype* dst_channel = dst + dst_offset;
+            for (int oh_idx = 0; oh_idx < oh; ++oh_idx) {
+                for (int ow_idx = 0; ow_idx < ow; ++ow_idx) {
+                    //! tmp = C * tile, tmp_result = tmp * C^T
+                    matmul(block, block, block, block, w, block, &c_matrix[0],
+                           &src_channel[oh_idx * block_h * w +
+                                        ow_idx * block_w],
+                           &tmp[0], false, false);
+                    matmul(block, block, block, block, block, block, &tmp[0],
+                           &c_matrix[0], &tmp_result[0], false, true);
+                    Dtype* dst_start = dst_channel +
+                                       (oh_idx * ow + ow_idx) * block_channel;
+                    if (!with_mask) {
+                        for (int inner_h_idx = 0; inner_h_idx < block_h;
+                             ++inner_h_idx) {
+                            for (int inner_w_idx = 0; inner_w_idx < block_w;
+                                 ++inner_w_idx) {
+                                const int oc_idx =
+                                        inner_h_idx * block_w + inner_w_idx;
+                                dct_output(dst_start, oc_idx, o_img_size,
+                                           tmp_result[inner_h_idx * block +
+                                                      inner_w_idx],
+                                           dtype);
+                            }
+                        }
+                    } else {
+                        //! with mask
+                        auto& sub_mask = mask[c_idx];
+                        int oc_idx = 0;
+                        for (auto mask_idx : sub_mask) {
+                            dct_output(dst_start, oc_idx, o_img_size,
+                                       tmp_result[mask_idx], dtype);
+                            ++oc_idx;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+}  // namespace
+
+/*!
+ * \brief Naive forward: Float32 dst uses NCHW, QuantizedS8 dst uses NCHW4.
+ *
+ * The mask tensors are expanded on the calling thread; the transform itself
+ * is dispatched through MEGDNN_DISPATCH_CPU_KERN_OPR. The workspace is unused.
+ */
+void DctChannelSelectForwardImpl::exec(_megdnn_tensor_in src,
+                                       _megdnn_tensor_in mask_offset,
+                                       _megdnn_tensor_in mask_val,
+                                       _megdnn_tensor_out dst,
+                                       _megdnn_workspace /*workspace*/) {
+    MIDOUT_BEGIN(megdnn_naive_dct_fwd) {
+        int in = src.layout.shape[0];
+        int ic = src.layout.shape[1];
+        int ih = src.layout.shape[2];
+        int iw = src.layout.shape[3];
+        megdnn_assert(dst.raw_ptr, "dst can not be nullptr");
+        const int block = param().dct_block_size;
+        auto mask = mask_offset_to_2dmask(mask_offset, mask_val);
+        if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
+            megdnn_assert(!is_layout_nchw4(dst.layout) &&
+                                  param().format == Param::Format::NCHW,
+                          "dst must be nchw");
+            MEGDNN_DISPATCH_CPU_KERN_OPR(naive_dct<float>(
+                    src.ptr<uint8_t>(), dst.ptr<float>(), in, ic, ih, iw, block,
+                    mask, dst.layout.dtype));
+        } else {
+            megdnn_assert(dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8,
+                          "dst must be q8");
+            megdnn_assert(is_layout_nchw4(dst.layout) &&
+                                  param().format == Param::Format::NCHW4,
+                          "dst must be nchw4");
+            MEGDNN_DISPATCH_CPU_KERN_OPR(naive_dct<int8_t>(
+                    src.ptr<uint8_t>(), static_cast<int8_t*>(dst.raw_ptr), in,
+                    ic, ih, iw, block, mask, dst.layout.dtype));
+        }
+    }
+    MIDOUT_END();
+}
+
+}  // namespace naive
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/naive/dct/opr_impl.h b/dnn/src/naive/dct/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..e0d1405ee701812f579753cc95859352f4e78b08
--- /dev/null
+++ b/dnn/src/naive/dct/opr_impl.h
@@ -0,0 +1,34 @@
+/** + * \file dnn/src/naive/dct/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class DctChannelSelectForwardImpl : public DctChannelSelectForward { +public: + using DctChannelSelectForward::DctChannelSelectForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in mask_offset, + _megdnn_tensor_in mask_val, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& /*src*/, + const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + }; +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index fd4cb78035c90ee68c7853cb5064a5bbcb948dce..77f70009cef34ef5077d6ed3c1687eb9823f0ab0 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #include "src/naive/handle.h" @@ -29,6 +30,7 @@ #include "src/naive/convpooling/opr_impl.h" #include "src/naive/cumsum/opr_impl.h" #include "src/naive/cvt_color/opr_impl.h" +#include "src/naive/dct/opr_impl.h" #include "src/naive/deformable_conv/opr_impl.h" #include "src/naive/deformable_ps_roi_pooling/opr_impl.h" #include "src/naive/dot/opr_impl.h" @@ -56,6 +58,7 @@ #include "src/naive/reduce/opr_impl.h" #include "src/naive/relayout/opr_impl.h" #include "src/naive/relayout_format/opr_impl.h" +#include "src/naive/remap/opr_impl.h" #include "src/naive/repeat/opr_impl.h" #include "src/naive/resize/opr_impl.h" #include "src/naive/rng/opr_impl.h" @@ -76,7 +79,6 @@ #include "src/naive/warp_affine/opr_impl.h" #include "src/naive/warp_perspective/opr_impl.h" #include "src/naive/winograd_filter_preprocess/opr_impl.h" -#include "src/naive/remap/opr_impl.h" static size_t g_image2d_pitch_alignment = 1; diff --git a/dnn/test/common/benchmarker.h b/dnn/test/common/benchmarker.h index cfd9f6e5e2d2f8fc95556f2f10e3815fe739ac0e..da7809d57fb6ca09acc0377c83ee2b05f0cdd6fa 100644 --- a/dnn/test/common/benchmarker.h +++ b/dnn/test/common/benchmarker.h @@ -6,20 +6,21 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include #include -#include #include +#include #include "megdnn/basic_types.h" #include "megdnn/tensor_format.h" +#include "test/common/opr_algo_proxy.h" #include "test/common/opr_proxy.h" #include "test/common/rng.h" #include "test/common/timer.h" -#include "test/common/opr_algo_proxy.h" namespace megdnn { namespace test { @@ -31,6 +32,7 @@ public: using TensorValueArray = TensorNDArray; using BeforeExecCallback = std::function; + using TensorsConstriant = std::function; BenchmarkerBase(Handle* handle, T timer) : m_timer(timer), @@ -54,6 +56,8 @@ public: } float exec(TensorLayoutArray layouts); + float exect(const TensorValueArray& testcase_in); + //! disabiguate overloaded exec float execs(const TensorShapeArray& shapes) { return exec(shapes); } float execl(const TensorLayoutArray& layouts) { return exec(layouts); } @@ -73,6 +77,11 @@ public: m_fmt[idx] = fmt; return *this; } + BenchmarkerBase& set_tensors_constraint( + const TensorsConstriant& tensor_constraint) { + m_tensor_constraint = tensor_constraint; + return *this; + } TensorLayoutArray make_layouts(const TensorShapeArray& shapes) { TensorLayoutArray layouts(shapes.size()); for (size_t i = 0; i < shapes.size(); ++i) { @@ -142,6 +151,7 @@ private: std::unique_ptr> m_proxy; BeforeExecCallback m_before_exec_callback; std::unique_ptr m_opr; + TensorsConstriant m_tensor_constraint; }; template @@ -184,10 +194,16 @@ float BenchmarkerBase::exec(TensorLayoutArray layouts) { auto rng = m_rng[i]; if (!rng) rng = m_default_rng.get(); - auto size = tensor.layout.span().high_byte; rng->gen(tensor); + } + if (m_tensor_constraint) { + m_tensor_constraint(tensors_cur_host); + } + for (size_t i = 0; i < tensors_cur_host.size(); ++i) { + TensorND& tensor = tensors_cur_host[i]; if (tensor.layout.ndim == 0) continue; + auto size = tensor.layout.span().high_byte; megdnn_memcpy_H2D(m_handle, tensors_cur[i].raw_ptr, tensor.raw_ptr, size); } @@ -243,6 +259,105 @@ float BenchmarkerBase::exec(TensorLayoutArray 
layouts) { return time_in_ms; } +template +float BenchmarkerBase::exect(const TensorValueArray& testcase_in) { + auto opr = this->opr(); + opr->param() = m_param; + TensorLayoutArray layouts; + TensorNDArray tensors_cur_host; + for (auto& inp : testcase_in) { + layouts.push_back(inp.layout); + tensors_cur_host.emplace_back(inp); + } + auto user_layouts = layouts; + m_proxy->deduce_layout(opr, layouts); + for (size_t i = 0; i < layouts.size(); ++i) + if (user_layouts[i].ndim > 0) { + auto run = [&]() { + ASSERT_TRUE(layouts[i].eq_shape(user_layouts[i])) + << "User provided shape is " + << user_layouts[i].TensorShape::to_string() + << "\nExpected shape is " + << layouts[i].TensorShape::to_string(); + }; + run(); + } + auto allocate = [&layouts](Handle* handle) { + TensorNDArray tensors(layouts.size()); + auto trans_func = [handle](const TensorLayout& layout) { + auto span = layout.span(); + TensorND res; + res.raw_ptr = static_cast( + megdnn_malloc(handle, span.dist_byte())) + + span.low_byte; + res.layout = layout; + return res; + }; + std::transform(layouts.begin(), layouts.end(), tensors.begin(), + trans_func); + return tensors; + }; + auto tensors_cur = allocate(m_handle); + //! init + for (size_t i = 0; i < tensors_cur_host.size(); ++i) { + TensorND& tensor = tensors_cur_host[i]; + auto size = tensor.layout.span().high_byte; + if (tensor.layout.ndim == 0) + continue; + megdnn_memcpy_H2D(m_handle, tensors_cur[i].raw_ptr, tensor.raw_ptr, + size); + } + if (m_before_exec_callback) { + m_before_exec_callback(opr, tensors_cur); + } + //! run + //! warm up + m_proxy->exec(opr, tensors_cur); + megcoreSynchronize(m_handle->megcore_computing_handle()); + + if (m_adaptive_secs) { + //! 
find m_times for adaptive benchmarking + m_times = 0; + int cur_times = 1; + auto remain_time = m_adaptive_secs * 1e6; + while (remain_time > 0) { + m_timer.reset(); + m_timer.start(); + for (int i = 0; i < cur_times; ++i) + m_proxy->exec(opr, tensors_cur); + megcoreSynchronize(m_handle->megcore_computing_handle()); + m_timer.stop(); + m_times += cur_times; + auto this_run_time = m_timer.get_time_in_us(); + remain_time -= this_run_time; + cur_times = std::min( + cur_times * 2, + std::max(1, remain_time / this_run_time * cur_times)); + } + } + m_timer.reset(); + m_timer.start(); + for (size_t t = 0; t < m_times; ++t) + m_proxy->exec(opr, tensors_cur); + megcoreSynchronize(m_handle->megcore_computing_handle()); + m_timer.stop(); + auto time_in_ms = m_timer.get_time_in_us() / 1e3; + if (m_display) { + std::cout << "Total time is " << time_in_ms << "ms " + << "for " << m_times << " run(s)." << std::endl; + } + auto free = [](Handle* handle, TensorNDArray& tensors) { + std::for_each(tensors.begin(), tensors.end(), + [handle](const TensorND& tensor) { + megdnn_free(handle, tensor.raw_ptr); + }); + }; + free(m_handle, tensors_cur); + if (m_adaptive_secs) + time_in_ms /= m_times; + return time_in_ms; +} + template class Benchmarker; diff --git a/dnn/test/common/checker.cpp b/dnn/test/common/checker.cpp index 5057fafd86093fcead01eccc3f60979277a797d3..0cb45c6b7424c0bbcc101640e5b456a7d296c010 100644 --- a/dnn/test/common/checker.cpp +++ b/dnn/test/common/checker.cpp @@ -294,8 +294,6 @@ void CheckerHelper::do_exec_with_testcases(const TensorValueArray& testcase_in, ASSERT_TRUE(testcase_in[i].layout.ndim == 0 || testcase_out[i].layout.ndim == 0 || testcase_in[i].layout.eq_layout(testcase_out[i].layout)); - ASSERT_TRUE(testcase_in[i].layout.ndim != 0 || - testcase_out[i].layout.ndim != 0); layouts.emplace_back(testcase_in[i].layout.ndim > 0 ? 
testcase_in[i].layout : testcase_out[i].layout); diff --git a/dnn/test/common/checker.h b/dnn/test/common/checker.h index e6b355fbeb678ff0a974488c2364f15e8d427291..2b4c4717f1f79319513dafe7dc269b26f24d3b08 100644 --- a/dnn/test/common/checker.h +++ b/dnn/test/common/checker.h @@ -392,7 +392,8 @@ TensorND TensorValue(const TensorShape& shape, T dtype, tensor.layout = {shape, dtype}; tensor.raw_ptr = static_cast(malloc(tensor.layout.span().dist_byte())); - megdnn_assert(values.size() == tensor.layout.total_nr_elems()); + megdnn_assert(values.size() == tensor.layout.total_nr_elems(), "%zu == %zu", values.size(), + tensor.layout.total_nr_elems()); auto ptr = tensor.ptr::ctype>(); for (const auto& v : values) { *ptr++ = typename DTypeTrait::ctype(v); diff --git a/dnn/test/common/dct_ref.cpp b/dnn/test/common/dct_ref.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7923915f498928cf35e8d033415b31083ec24e57 --- /dev/null +++ b/dnn/test/common/dct_ref.cpp @@ -0,0 +1,198 @@ +/** + * \file + * dnn/test/common/dct_ref.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "test/common/dct_ref.h" +namespace megdnn { +namespace test { +struct FixCase { + std::vector mask_offset; + std::vector mask_val; +}; +using Param = DctChannelSelectForward::Param; + +static inline FixCase get_fix_mask(Param::FastImpl impl) { + std::vector fix_32_mask_offset{0, 16, 24, 32}; + std::vector fix_32_mask_val{0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, + 25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2, + 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}; + megdnn_assert(impl == Param::FastImpl::FIX_32_MASK, + "only support gen FIX_32_MASK"); + return {fix_32_mask_offset, fix_32_mask_val}; +} + +CheckerHelper::TensorsConstriant gen_dct_constriant( + const size_t /* n */, const size_t ic, const size_t ih, const size_t iw, + const size_t oc, Param param) { + auto constraint = [=](CheckerHelper::TensorValueArray& tensors_orig) { + const size_t block = param.dct_block_size; + const int block_c = param.format == Param::Format::NCHW4 ? 4 : 1; + megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0"); + std::shared_ptr test_case_ptr = DctTestcase::make(); + DctTestcase& test_case = *test_case_ptr.get(); + UniformIntRNG rng(0, 255); + UniformIntRNG mask_rng(0, 64 / block_c - 1); + const size_t no_mask_oc = ic * block * block; + megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block); + megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block); + + TensorND mask_offset; + TensorND mask_val; + std::vector& mask_offset_vec = test_case.mask_offset_vec; + std::vector& mask_val_vec = test_case.mask_val_vec; + UniformIntRNG rng_oc(0, oc); + if (param.fastImpl == Param::FastImpl::FIX_32_MASK) { + auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK); + mask_offset_vec = fix_32_mask.mask_offset; + mask_val_vec = fix_32_mask.mask_val; + megdnn_assert(oc == 32, "oc must eq 32"); + } else if (no_mask_oc > oc) { + size_t remain_oc = oc; + mask_offset_vec.resize(ic + 1); + mask_val_vec.resize(oc); + mask_offset_vec[0] = 0; + for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) { 
+ size_t random_len = (int)rng_oc.gen_single_val() * block_c; + size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0) + ? remain_oc + : random_len % remain_oc; + megdnn_assert(mask_len % block_c == 0, + "mask_len mod block_c == 0, but %zu mod %d ", + mask_len, block_c); + const size_t oc_idx = mask_offset_vec[ic_idx]; + remain_oc -= mask_len; + mask_offset_vec[ic_idx + 1] = oc_idx + mask_len; + for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) { + mask_val_vec[oc_idx + mask_idx] = + (int)mask_rng.gen_single_val(); + } + } + } + mask_offset = TensorND(mask_offset_vec.data(), + {{mask_offset_vec.size()}, dtype::Int32()}); + mask_val = TensorND(mask_val_vec.data(), + {{mask_val_vec.size()}, dtype::Int32()}); + if (tensors_orig.size() > 1) { + megdnn_assert(tensors_orig.size() == 4, "tensors_orig.size() == 4"); + megdnn_assert(mask_offset_vec.size() >= 2, + "mask_offset_vec.size() >= 2"); + megdnn_assert(tensors_orig[1].layout == mask_offset.layout, + "tensors_orig[1].layout == mask_offset.layout"); + megdnn_assert(tensors_orig[2].layout == mask_val.layout, + "tensors_orig[2].layout == mask_val.layout"); + auto naive_handle = create_cpu_handle(2, false); + megdnn_memcpy_D2D(naive_handle.get(), tensors_orig[1].raw_ptr, + mask_offset.raw_ptr, + mask_offset.layout.span().dist_byte()); + megdnn_memcpy_D2D(naive_handle.get(), tensors_orig[2].raw_ptr, + mask_val.raw_ptr, + mask_val.layout.span().dist_byte()); + } + }; + return constraint; +} + +std::shared_ptr gen_dct_case(const size_t n, const size_t ic, + const size_t ih, const size_t iw, + const size_t oc, Param param, + DType dst_dtype, + bool correct_result) { + const size_t block = param.dct_block_size; + const int block_c = param.format == Param::Format::NCHW4 ? 
4 : 1; + megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0"); + std::shared_ptr test_case_ptr = DctTestcase::make(); + DctTestcase& test_case = *test_case_ptr.get(); + UniformIntRNG rng(0, 255); + UniformIntRNG mask_rng(0, 64 / block_c - 1); + const size_t input_elements = n * ic * ih * iw; + const size_t no_mask_oc = ic * block * block; + megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block); + megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block); + std::vector& inp_vec = test_case.inp_vec; + inp_vec.resize(input_elements); + TensorShape input_shape{n, ic, ih, iw}; + for (auto& elm : inp_vec) { + elm = (uint8_t)rng.gen_single_val(); + } + auto src = TensorND(inp_vec.data(), {input_shape, dtype::Uint8()}); + TensorND mask_offset; + TensorND mask_val; + std::vector& mask_offset_vec = test_case.mask_offset_vec; + std::vector& mask_val_vec = test_case.mask_val_vec; + UniformIntRNG rng_oc(0, oc); + if (param.fastImpl == Param::FastImpl::FIX_32_MASK) { + auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK); + mask_offset_vec = fix_32_mask.mask_offset; + mask_val_vec = fix_32_mask.mask_val; + megdnn_assert(oc == 32, "oc must eq 32"); + } else if (no_mask_oc > oc) { + size_t remain_oc = oc; + mask_offset_vec.resize(ic + 1); + mask_val_vec.resize(oc); + mask_offset_vec[0] = 0; + for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) { + size_t random_len = (int)rng_oc.gen_single_val() * block_c; + size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0) + ? 
remain_oc + : random_len % remain_oc; + megdnn_assert(mask_len % block_c == 0, + "mask_len mod block_c == 0, but %zu mod %d ", + mask_len, block_c); + const size_t oc_idx = mask_offset_vec[ic_idx]; + remain_oc -= mask_len; + mask_offset_vec[ic_idx + 1] = oc_idx + mask_len; + for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) { + mask_val_vec[oc_idx + mask_idx] = + (int)mask_rng.gen_single_val(); + } + } + } + mask_offset = TensorND(mask_offset_vec.data(), + {{mask_offset_vec.size()}, dtype::Int32()}); + mask_val = TensorND(mask_val_vec.data(), + {{mask_val_vec.size()}, dtype::Int32()}); + if (mask_offset_vec.size() >= 2) { + test_case.testcase_in = { + src, mask_offset, mask_val, {nullptr, {{}, dst_dtype}}}; + } else { + test_case.testcase_in = {src, {}, {}, {nullptr, {{}, dst_dtype}}}; + } + + auto naive_handle = create_cpu_handle(2, false); + auto opr_naive = naive_handle->create_operator(); + opr_naive->param() = param; + using Proxy = OprProxy; + Proxy naive_proxy; + TensorLayout temp_dst_layout; + temp_dst_layout.dtype = dst_dtype; + + TensorLayoutArray layouts{src.layout, mask_offset.layout, mask_val.layout, + temp_dst_layout}; + naive_proxy.deduce_layout(opr_naive.get(), layouts); + const size_t output_elements = layouts[3].total_nr_elems(); + std::vector& output_vec = test_case.output_vec; + output_vec.resize(output_elements); + auto dst = TensorND(output_vec.data(), layouts[3]); + DctTestcase::TensorValueArray testcase_naive; + testcase_naive.emplace_back(test_case.testcase_in[0]); + testcase_naive.emplace_back(test_case.testcase_in[1]); + testcase_naive.emplace_back(test_case.testcase_in[2]); + testcase_naive.emplace_back(dst); + if (correct_result) { + naive_proxy.exec(opr_naive.get(), testcase_naive); + } + test_case.testcase_out = {{}, {}, {}, dst}; + return test_case_ptr; +} + +} // namespace test +} // namespace megdnn + // vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/test/common/dct_ref.h b/dnn/test/common/dct_ref.h 
new file mode 100644 index 0000000000000000000000000000000000000000..8045c43e9652e20fee7c4b58cb521f4d89590a2f --- /dev/null +++ b/dnn/test/common/dct_ref.h @@ -0,0 +1,52 @@ +/** + * \file + * dnn/test/common/dct_ref.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include +#include +#include "megdnn/dtype.h" +#include "megdnn/oprs/nn.h" +#include "test/common/checker.h" +#include "test/common/opr_proxy.h" +#include "test/common/rng.h" + +namespace megdnn { +namespace test { + +using Param = DctChannelSelectForward::Param; + +struct DctTestcase { + using TensorValueArray = TensorNDArray; + TensorValueArray testcase_in; + TensorValueArray testcase_out; + std::vector inp_vec; + std::vector mask_offset_vec; + std::vector mask_val_vec; + std::vector output_vec; + static std::shared_ptr make() { + return std::make_shared(); + } +}; + +CheckerHelper::TensorsConstriant gen_dct_constriant( + const size_t n, const size_t ic, const size_t ih, const size_t iw, + const size_t oc, Param param); + +std::shared_ptr gen_dct_case(const size_t n, const size_t ic, + const size_t ih, const size_t iw, + const size_t oc, Param param, + DType dst_dtype = dtype::Float32(), + bool correct_result = true); + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/test/common/opr_trait.h b/dnn/test/common/opr_trait.h index 66af7db3c1cc0eadd187d9c3803a912919481e5f..0d4d66422bc9879261209a9709f8601726f1637b 100644 --- a/dnn/test/common/opr_trait.h +++ b/dnn/test/common/opr_trait.h @@ -110,6 +110,7 @@ DEF(BatchConvBiasForward, 5, true, true); DEF(Remap, 3, true, true); DEF(RemapBackwardData, 3, true, false); 
DEF(RemapBackwardMat, 4, true, false); +DEF(DctChannelSelectForward, 4, true, true); } // namespace test } // namespace megdnn diff --git a/dnn/test/cuda/dct.cpp b/dnn/test/cuda/dct.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0045cb86c3336ea5b52fae7a12554f528bf4f9b4 --- /dev/null +++ b/dnn/test/cuda/dct.cpp @@ -0,0 +1,360 @@ +/** + * \file dnn/test/cuda/dct.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "megdnn/oprs/nn.h" +#include "test/common/benchmarker.h" +#include "test/common/checker.h" +#include "test/common/dct_ref.h" +#include "test/common/rng.h" +#include "test/cuda/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, DCT) { + DctChannelSelectForward::Param param; + Checker checker(handle_cuda()); + for (size_t n : {1, 3}) { + for (size_t ic : {1, 3}) { + for (size_t ih : {8, 16, 32, 512, 1024}) { + for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) { + checker.set_param(param) + .set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::Int32()) + .set_dtype(2, dtype::Int32()) + .execs({TensorShape{n, ic, ih, iw}, {}, {}, {}}); + } + } + } + } +} + +TEST_F(CUDA, DCT_QINT8) { + DctChannelSelectForward::Param param; + Checker checker(handle_cuda()); + param.format = Param::Format::NCHW4; + for (size_t n : {1, 3}) { + for (size_t ic : {1, 3}) { + for (size_t ih : {8, 16, 32, 512, 1024}) { + for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) { + checker.set_param(param) + .set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::Int32()) + .set_dtype(2, dtype::Int32()) + .set_dtype(3, dtype::QuantizedS8(10.f)) + .set_epsilon(1) + .execs({TensorShape{n, ic, ih, iw}, {}, {}, {}}); + } 
+            }
+        }
+    }
+}
+
+TEST_F(CUDA, DCT_WITH_FIX_32_MASK) {  // FIX_32_MASK fast path
+    using Param = DctChannelSelectForward::Param;
+    Param param;
+    Checker<DctChannelSelectForward> checker(handle_cuda(), false);
+    param.fastImpl = Param::FastImpl::FIX_32_MASK;
+    auto test_case = gen_dct_case(3, 3, 1024, 768, 32, param);
+    checker.set_param(param).exect(test_case->testcase_in,
+                                   test_case->testcase_out);
+}
+
+TEST_F(CUDA, DCT_WITH_FIX_32_MASK_QINT8) {  // fast path, NCHW4 + QuantizedS8
+    using Param = DctChannelSelectForward::Param;
+    Param param;
+    Checker<DctChannelSelectForward> checker(handle_cuda(), false);
+    param.fastImpl = Param::FastImpl::FIX_32_MASK;
+    param.format = Param::Format::NCHW4;
+    auto test_case =
+            gen_dct_case(3, 3, 1024, 768, 32, param, dtype::QuantizedS8(10.f));
+    checker.set_param(param).set_epsilon(1).exect(test_case->testcase_in,
+                                                  test_case->testcase_out);
+}
+
+TEST_F(CUDA, DCT_WITH_MASK) {  // hand-written 30-channel mask, golden output
+    Checker<DctChannelSelectForward> checker(handle_cuda(), false);
+    DctChannelSelectForward::Param param;
+    checker.set_param(param).exect(  // src, mask_offset, mask_val -> dst
+            Testcase{TensorValue(
+                             {1, 3, 8, 16}, dtype::Uint8(),
+                             {109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
+                              18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
+                              185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
+                              196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
+                              148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
+                              141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
+                              78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
+                              49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
+                              181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
+                              14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
+                              65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
+                              44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
+                              238, 181, 232, 191, 161, 57, 23, 204,
+
+                              109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
+                              18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
+                              185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
+                              196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
+                              148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
+                              141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
+                              78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
+                              49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
+                              181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
+                              14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
+                              65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
+                              44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
+                              238, 181, 232, 191, 161, 57, 23, 204,
+
+                              109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
+                              18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
+                              185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
+                              196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
+                              148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
+                              141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
+                              78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
+                              49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
+                              181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
+                              14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
+                              65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
+                              44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
+                              238, 181, 232, 191, 161, 57, 23, 204}),
+                     TensorValue({4}, dtype::Int32(), {0, 14, 22, 30}),
+                     TensorValue({30}, dtype::Int32(),  // mask_val
+                                 {8, 16, 9, 2, 3, 10, 17, 24, 32, 25,
+                                  18, 11, 4, 5, 0, 1, 8, 16, 9, 2,
+                                  3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
+                     {}},
+            Testcase{{},
+                     {},
+                     {},
+                     TensorValue({1, 30, 1, 2}, dtype::Float32(),
+                                 {-22.850792, -97.862236, -101.043236,
+                                  -4.727012, 28.275675, -157.96654,
+                                  42.1377, 45.06531, -149.77373,
+                                  24.487143, -8.054966, -13.990831,
+                                  -6.9395194, -3.9211385, 64.79172,
+                                  -12.363858, -47.875, 59.,
+                                  56.271786, -62.725567, 120.522675,
+                                  16.559765, 85.74334, 112.904495,
+                                  99.375, 29.499973, 2.0220923,
+                                  -19.681704, 890.12494, 941.25,
+                                  -7.0498576, 99.47632, -22.850792,
+                                  -97.862236, -101.043236, -4.727012,
+                                  28.275675, -157.96654, 42.1377,
+                                  45.06531, -149.77373, 24.487143,
+                                  -8.054966, -13.990831, 890.12494,
+                                  941.25, -7.0498576, 99.47632,
+                                  -22.850792, -97.862236, -101.043236,
+                                  -4.727012, 28.275675, -157.96654,
+                                  42.1377, 45.06531, -149.77373,
+                                  24.487143, -8.054966, -13.990831})});
+}
+
+TEST_F(CUDA, DCT_WITH_MASK2) {  // shape sweep with a random mask width
+    Checker<DctChannelSelectForward> checker(handle_cuda(), false);
+    DctChannelSelectForward::Param param;
+    UniformIntRNG rng_oc(0, 3 * 64);
+    for (size_t n : {1, 3}) {
+        for (size_t ic : {1, 3}) {
+            for (size_t ih : {8, 16, 32, 512, 1024}) {
+                for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
+                    int random_oc = static_cast<int>(rng_oc.gen_single_val());
+                    int max_oc = ic * 64;
+                    int mask_oc = (random_oc % max_oc) + 1;  // 1..max_oc
+                    auto test_case =
+                            gen_dct_case(n, ic, ih, iw, mask_oc, param);
+                    checker.set_param(param).exect(test_case->testcase_in,
+                                                   test_case->testcase_out);
+                }
+            }
+        }
+    }
+}
+
+TEST_F(CUDA, DCT_WITH_MASK2_QINT8) {  // same sweep, NCHW4 + QuantizedS8
+    Checker<DctChannelSelectForward> checker(handle_cuda(), false);
+    DctChannelSelectForward::Param param;
+    param.format = DctChannelSelectForward::Param::Format::NCHW4;
+
+    UniformIntRNG rng_oc(0, 3 * 64);
+    for (size_t n : {1, 3}) {
+        for (size_t ic : {1, 3}) {
+            for (size_t ih : {8, 16, 32, 512, 1024}) {
+                for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
+                    int random_oc = static_cast<int>(rng_oc.gen_single_val());
+                    int max_oc = ic * 64;
+                    int mask_oc = (random_oc % max_oc) + 1;  // 1..max_oc
+                    mask_oc = (mask_oc + 3) / 4 * 4;  // next multiple of 4
+                    auto test_case = gen_dct_case(n, ic, ih, iw, mask_oc, param,
+                                                  dtype::QuantizedS8(10.f));
+                    checker.set_param(param).set_epsilon(1).exect(
+                            test_case->testcase_in, test_case->testcase_out);
+                }
+            }
+        }
+    }
+}
+TEST_F(CUDA, DCT_WITH_MASK2_QINT8_CONSTRAINT) {  // via tensors constraint
+    DctChannelSelectForward::Param param;
+    param.format = DctChannelSelectForward::Param::Format::NCHW4;
+
+    Checker<DctChannelSelectForward> checker(handle_cuda(), false);
+    checker.set_param(param)
+            .set_dtype(0, dtype::Uint8())
+            .set_dtype(1, dtype::Int32())
+            .set_dtype(2, dtype::Int32())
+            .set_dtype(3, dtype::QuantizedS8(10.f))
+            .set_epsilon(1);
+
+    UniformIntRNG rng_oc(0, 3 * 64);
+    for (size_t n : {1, 3}) {
+        for (size_t ic : {1, 3}) {
+            for (size_t ih : {8, 16, 32, 512, 1024}) {
+                for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
+                    int random_oc = static_cast<int>(rng_oc.gen_single_val());
+                    int max_oc = ic * 64;
+                    int mask_oc = (random_oc % max_oc) + 1;  // 1..max_oc
+                    mask_oc = (mask_oc + 3) / 4 * 4;  // next multiple of 4
+                    if (mask_oc < max_oc) {  // partial mask: pass mask shapes
+                        checker
+                                .set_tensors_constraint(gen_dct_constriant(
+                                        n, ic, ih, iw, mask_oc, param))
+                                .exec({TensorShape{n, ic, ih, iw},
+                                       TensorShape{ic + 1},
+                                       TensorShape{static_cast<size_t>(mask_oc)},
+                                       {}});
+                    } else {  // mask_oc == max_oc: run without mask tensors
+                        checker.set_tensors_constraint({}).exec(
+                                {TensorShape{n, ic, ih, iw}, {}, {}, {}});
+                    }
+                }
+            }
+        }
+    }
+}
+
+#if MEGDNN_WITH_BENCHMARK
+
+TEST_F(CUDA, BENCHMARK_DCT) {  // prints throughput for several configurations
+    using Param = DctChannelSelectForward::Param;
+    // run: benchmark each source shape, leaving the mask tensors empty.
+    auto run = [&](const TensorShapeArray& shapes, Param param) {
+        Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
+        benchmarker.set_param(param);
+        benchmarker.set_dtype(0, dtype::Uint8())
+                .set_dtype(1, dtype::Int32())
+                .set_dtype(2, dtype::Int32());
+        for (auto&& shape : shapes) {
+            double computation = double(shape[0]) * shape[1] * shape[2] *
+                                 shape[3] * 32.0 * 1e-6;
+            auto time_ms = benchmarker.execs({shape, {}, {}, {}});
+            printf("execute %s, %.4f Gops\n", shape.to_string().c_str(),
+                   computation / time_ms);
+        }
+    };
+    // run_case: benchmark the input tensors of a prebuilt DctTestcase.
+    auto run_case = [&](const DctTestcase& testcase, Param param,
+                        std::string comment = "") {
+        Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
+        benchmarker.set_param(param);
+        benchmarker.set_dtype(0, dtype::Uint8())
+                .set_dtype(1, dtype::Int32())
+                .set_dtype(2, dtype::Int32())
+                .set_dtype(3, testcase.testcase_out[3].layout.dtype);
+
+        auto src_shape = testcase.testcase_in[0].layout;
+        double computation = double(src_shape[0]) * src_shape[1] *
+                             src_shape[2] * src_shape[3] * 32.0 * 1e-6;
+        auto time_ms = benchmarker.exect(testcase.testcase_in);
+        printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
+               src_shape.to_string().c_str(), computation / time_ms);
+    };
+    // run_case_constraint: benchmark `shapes` with a tensors constraint set.
+    auto run_case_constraint =
+            [&](const Benchmarker<DctChannelSelectForward>::TensorsConstriant&
+                        constraint,
+                Param param, const TensorShapeArray& shapes,
+                std::string comment, DType output_dtype) {
+                Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
+                benchmarker.set_param(param)
+                        .set_dtype(0, dtype::Uint8())
+                        .set_dtype(1, dtype::Int32())
+                        .set_dtype(2, dtype::Int32())
+                        .set_dtype(3, output_dtype)
+                        .set_tensors_constraint(constraint);
+
+                auto src_shape = shapes[0];
+                double computation = double(src_shape[0]) * src_shape[1] *
+                                     src_shape[2] * src_shape[3] * 32.0 * 1e-6;
+                auto time_ms = benchmarker.exec(shapes);
+                printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
+                       src_shape.to_string().c_str(), computation / time_ms);
+            };
+
+    TensorShapeArray shapes = {
+            {1, 3, 512, 512},
+            {8, 3, 2176, 3840},
+    };
+    {
+        Param param;
+        run(shapes, param);
+    }
+
+    Param fix_32_param;
+    fix_32_param.fastImpl = Param::FastImpl::FIX_32_MASK;
+    {
+        auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
+        run_case(*test_case, fix_32_param, "FIX_32_MASK");
+    }
+
+    {
+        Param param;  // same generated case, run with a default Param
+        auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
+        run_case(*test_case, param, "MASK 32");
+    }
+
+    {
+        Param fix_32_nchw4_param;
+        fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
+        fix_32_nchw4_param.format = Param::Format::NCHW4;
+        auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_nchw4_param,
+                                      dtype::QuantizedS8(10.f));
+        run_case(*test_case, fix_32_nchw4_param, "FIX_32_MASK QINT8");
+    }
+
+    {
+        Param fix_32_nchw4_param;
+        fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
+        fix_32_nchw4_param.format = Param::Format::NCHW4;
+        auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_nchw4_param,
+                                      dtype::QuantizedS8(10.f));
+        fix_32_nchw4_param.fastImpl = Param::FastImpl::NONE;
+        run_case(*test_case, fix_32_nchw4_param, "MASK 32 QINT8");
+    }
+
+    {
+        Param fix_32_nchw4_param;
+        fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
+        fix_32_nchw4_param.format = Param::Format::NCHW4;
+        TensorShapeArray shapes = {{8, 3, 2176, 3840}, {4}, {32}, {}};
+        auto constraint =
+                gen_dct_constriant(8, 3, 2176, 3840, 32, fix_32_nchw4_param);
+        run_case_constraint(constraint, fix_32_nchw4_param, shapes,
+                            "FIX_32_MASK QINT8 Constraint",
+                            dtype::QuantizedS8(10.f));
+    }
+}
+#endif
+
+} // namespace test
+} // namespace megdnn
+// vim: syntax=cpp.doxygen
diff --git a/dnn/test/naive/dct.cpp b/dnn/test/naive/dct.cpp new
file mode 100644 index 0000000000000000000000000000000000000000..620e0a8281cc3f4945d8a79e7f394c8559cabee7 --- /dev/null +++ b/dnn/test/naive/dct.cpp @@ -0,0 +1,679 @@ +/** + * \file dnn/test/naive/dct.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "megdnn/oprs/nn.h" +#include "test/common/checker.h" +#include "test/common/dct_ref.h" +#include "test/common/rng.h" +#include "test/common/tensor.h" +#include "test/naive/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(NAIVE, DCT) { + Checker checker(handle(), + /* check_dispatch */ false); + DctChannelSelectForward::Param param; + + checker.set_param(param).exect( + Testcase{TensorValue( + {1, 1, 16, 16}, dtype::Uint8(), + {87, 155, 59, 161, 24, 200, 58, 3, 40, 43, + 156, 7, 176, 232, 226, 78, 73, 236, 185, 109, + 196, 169, 62, 32, 167, 180, 96, 157, 101, 53, + 150, 47, 26, 238, 218, 210, 204, 236, 249, 111, + 16, 35, 169, 204, 117, 16, 3, 147, 12, 233, + 135, 162, 58, 118, 184, 237, 90, 105, 156, 195, + 196, 104, 138, 19, 82, 62, 126, 140, 220, 171, + 206, 232, 105, 123, 2, 135, 137, 41, 26, 219, + 167, 245, 104, 103, 24, 144, 141, 210, 208, 114, + 169, 170, 22, 11, 69, 106, 236, 150, 57, 184, + 75, 241, 28, 175, 178, 186, 190, 124, 187, 116, + 112, 162, 214, 154, 207, 31, 43, 40, 15, 188, + 81, 197, 20, 199, 246, 132, 159, 111, 79, 95, + 148, 184, 171, 173, 203, 146, 150, 33, 178, 9, + 141, 49, 237, 222, 72, 5, 23, 38, 248, 82, + 93, 229, 70, 180, 149, 232, 245, 72, 196, 138, + 4, 31, 160, 30, 8, 109, 153, 252, 204, 126, + 15, 182, 145, 130, 179, 234, 21, 240, 144, 105, + 77, 116, 155, 232, 168, 99, 159, 92, 251, 223, + 119, 173, 166, 39, 228, 91, 34, 5, 62, 172, +
131, 164, 143, 10, 161, 165, 221, 214, 178, 110, + 185, 254, 152, 149, 46, 144, 173, 237, 76, 210, + 221, 45, 200, 113, 58, 20, 47, 135, 228, 80, + 91, 51, 238, 194, 222, 231, 174, 244, 139, 96, + 71, 25, 25, 62, 172, 181, 71, 27, 86, 0, + 121, 38, 199, 236, 93, 158}), + {}, + {}, + {}}, + Testcase{{}, + {}, + {}, + TensorValue( + {1, 64, 2, 2}, dtype::Float32(), + {1.10687500e+03, 9.59500000e+02, 8.98125000e+02, + 1.21912500e+03, 1.38846378e+01, 3.91629181e+01, + -1.50343018e+02, -1.02085358e+02, 2.34341068e+01, + -8.40960388e+01, -4.23510742e+01, 1.72630596e+01, + -4.66624413e+01, -4.87857285e+01, -7.06332016e+01, + 6.31493912e+01, -9.96249924e+01, 7.72499924e+01, + 7.46250153e+01, 5.81250114e+01, -9.07061768e+01, + -7.68266630e+00, -3.15778809e+01, -3.35406876e+01, + 8.55864143e+00, -7.36760712e+01, 6.20557327e+01, + -2.92043419e+01, -1.39985870e+02, 2.56675129e+01, + 5.21866226e+01, 1.07624054e+02, -6.16851950e+00, + -8.56008530e+01, 7.35654449e+01, -2.56767311e+01, + -2.09981880e+01, -6.22950821e+01, -1.31617493e+02, + -6.30962448e+01, -2.21552780e+02, -4.79528542e+01, + 1.04179153e+02, 7.45253448e+01, 3.19730816e+01, + 1.24306192e+01, -9.93905945e+01, -8.95680237e+01, + -1.44870041e+02, -9.44738235e+01, -4.09417763e+01, + 4.50356903e+01, -3.65339231e+00, 5.79474449e+01, + -2.46253452e+01, 3.29394951e+01, -1.09065903e+02, + 5.23808861e+01, -1.00386992e+01, -7.92311325e+01, + -1.44292374e+01, 5.74285736e+01, 2.28798485e+01, + 6.84826508e+01, -1.49241837e+02, 9.35751495e+01, + -4.02763329e+01, -6.63586197e+01, 2.15622040e+02, + -7.83887939e+01, -8.06824951e+01, -2.51097183e+01, + 1.58941059e+01, -5.66967869e+00, -1.53566467e+02, + -4.33494377e+01, 8.12108078e+01, 1.21169144e+02, + 2.14673615e+02, -3.72018318e+01, 2.45811577e+01, + -1.27189613e+02, 4.98553581e+01, -5.83694696e+00, + -4.80477619e+00, -2.24601650e+01, -5.02191353e+00, + 5.16259460e+01, 1.07266571e+02, -3.41748886e+01, + -5.44621315e+01, 6.25573196e+01, -4.24649086e+01, + 4.42625465e+01, 
2.71147366e+01, 4.83264275e+01, + -6.99711227e+01, -1.00299120e+01, 1.33173111e+02, + 2.48003254e+01, -1.74687519e+01, 9.44530487e-01, + 1.35930038e+02, 6.72219162e+01, 4.53297043e+01, + 1.37072708e+02, -7.73253784e+01, 6.12967606e+01, + 9.78184891e+01, 3.63894577e+01, -1.64039135e+01, + -6.67858887e+01, 5.27859840e+01, -4.99117432e+01, + 8.77927475e+01, -5.86666260e+01, 3.86430244e+01, + 2.17759323e+01, 8.34562683e+01, 3.06256886e+01, + 1.61030369e+01, 8.11268158e+01, 1.36932516e+01, + -1.06112595e+02, -9.31621475e+01, 3.13674717e+01, + -4.90609503e+00, 7.96453857e+01, -1.02625000e+02, + 1.40000076e+01, 3.18749981e+01, -1.08375000e+02, + -5.44420319e+01, -1.50944397e+02, 5.29974670e+01, + -1.44041641e+02, 4.86086197e+01, -7.13610382e+01, + 3.06417294e+01, 7.20477829e+01, -6.95384140e+01, + 1.25441925e+02, -1.54897385e+01, 3.78566666e+01, + 4.23749886e+01, -3.37500000e+01, -9.96250000e+01, + -6.73750076e+01, 3.34241295e+01, -6.24825974e+01, + 1.76387348e+01, -6.45708389e+01, 1.70728874e+01, + -5.73032570e+01, -1.71570969e+01, 1.84064590e+02, + 4.17566071e+01, 7.08248520e+00, -2.59306641e+01, + 1.37766739e+02, -2.16669798e+00, 6.03565750e+01, + 6.84421844e+01, 6.19825096e+01, -1.44220114e+01, + -3.12404213e+01, -2.50061111e+01, 6.73021851e+01, + 2.52050266e+01, -8.35850677e+01, -4.70746574e+01, + 1.73889160e+01, 1.18955564e+01, 6.16792488e+00, + -3.29667168e+01, 4.55779572e+01, -4.17868996e+00, + -9.40233841e+01, -9.77727051e+01, 1.74934635e+01, + 5.25992851e+01, 1.23662634e+01, 5.26129305e-01, + 4.69518929e+01, -1.52657738e+01, 9.96897888e+01, + -9.51726151e+01, 9.99432602e+01, -1.75949844e+02, + 1.00472336e+02, -5.89417953e+01, -1.72231483e+01, + 1.89282093e+01, -8.17851868e+01, 7.22908936e+01, + -9.06294174e+01, 2.46093607e+00, -4.03946457e+01, + 2.17710762e+01, -5.62999649e+01, 4.77665749e+01, + -4.04248848e+01, 4.78787374e+00, 1.05557320e+02, + -4.60584450e+01, -7.33774490e+01, -4.25107193e+01, + 1.71907139e+01, -8.01314316e+01, 1.69647141e+01, + 
-8.24824219e+01, 8.29206543e+01, 3.72900200e+01, + 3.77470016e+01, 6.70151443e+01, 1.79784470e+01, + -4.01441078e+01, 6.29196739e+01, 7.60664597e+01, + -5.59005699e+01, 8.81600475e+00, -6.89491081e+00, + -8.03825378e+01, -5.33856511e-01, 7.26196136e+01, + -3.76809120e+01, -1.08401566e+02, 6.35455990e+00, + -8.66767120e+01, -1.02679443e+02, -9.54313660e+00, + -3.55650787e+01, -1.21355652e+02, 2.32628040e+01, + 3.94072838e+01, 1.24754738e+02, 9.51344986e+01, + -5.84752541e+01, -4.65028038e+01, 6.00556993e+00, + 4.94889374e+01, 7.64868622e+01, -1.49546280e+01, + -3.70648766e+01, 5.55572205e+01, -1.17196434e+02, + 9.20216217e+01, 3.29843826e+01, 3.25113411e+01, + 5.62059135e+01, 6.30202141e+01, 4.99030991e+01, + 2.85804024e+01, -1.44606361e+01, 7.64952774e+01, + -2.95697536e+01})}); +} + +TEST_F(NAIVE, DCT_INT8) { + Checker checker(handle(), + /* check_dispatch */ false); + DctChannelSelectForward::Param param; + param.format = DctChannelSelectForward::Param::Format::NCHW4; + checker.set_param(param).exect( + Testcase{TensorValue( + {1, 1, 16, 16}, dtype::Uint8(), + {113, 223, 229, 159, 249, 252, 89, 84, 45, 16, + 41, 72, 184, 236, 70, 184, 86, 172, 218, 211, + 47, 177, 18, 85, 174, 226, 37, 109, 38, 135, + 228, 195, 133, 238, 47, 246, 244, 118, 175, 143, + 34, 10, 28, 4, 82, 103, 89, 55, 235, 78, + 151, 178, 249, 62, 183, 84, 105, 0, 121, 98, + 249, 90, 161, 114, 121, 241, 21, 199, 196, 119, + 231, 209, 250, 180, 192, 213, 116, 105, 114, 169, + 1, 142, 3, 30, 140, 245, 201, 109, 19, 26, + 224, 68, 123, 228, 64, 150, 184, 212, 136, 172, + 241, 152, 222, 233, 15, 72, 130, 144, 107, 130, + 242, 79, 195, 46, 226, 57, 183, 36, 88, 161, + 121, 170, 2, 215, 109, 212, 35, 18, 76, 197, + 117, 81, 208, 8, 237, 75, 15, 20, 16, 192, + 61, 113, 96, 126, 211, 57, 49, 62, 185, 211, + 155, 87, 233, 163, 164, 84, 61, 28, 1, 11, + 190, 253, 145, 30, 38, 98, 153, 56, 231, 152, + 12, 204, 96, 8, 47, 87, 25, 237, 21, 150, + 173, 19, 41, 175, 164, 231, 39, 145, 39, 187, + 210, 123, 165, 
98, 87, 242, 38, 136, 182, 145, + 41, 47, 147, 171, 172, 35, 170, 148, 26, 89, + 107, 151, 130, 232, 65, 217, 27, 206, 68, 219, + 60, 106, 3, 209, 175, 189, 191, 32, 119, 141, + 56, 48, 105, 58, 94, 163, 185, 60, 83, 249, + 112, 245, 137, 60, 178, 51, 177, 106, 199, 209, + 4, 247, 3, 127, 88, 46}), + {}, + {}, + {}}, + Testcase{{}, + {}, + {}, + TensorValue( + {1, 16, 2, 2, 4}, dtype::QuantizedS8(10.f), + {122, -1, -8, 4, 92, -13, -5, 7, 99, 4, + 5, 3, 89, 7, 2, -6, 3, -8, -10, 2, + -1, 0, 4, -3, -5, -8, -11, 1, 14, 4, + -10, -18, 3, 12, -14, -2, -4, -9, 12, 4, + -2, -2, 2, 6, -9, 6, 1, 5, -5, -1, + 2, -12, 4, -5, -0, 4, 1, 5, -8, 5, + -3, 4, 2, 6, -0, 9, -4, -7, -4, -5, + -2, 8, 2, 4, 0, 7, -8, 4, -2, 3, + -6, -5, 19, 5, -4, -4, -5, -16, -8, -3, + -5, 19, 4, 3, 4, -6, 1, -12, -1, 7, + 11, -5, -1, -8, 2, -12, -9, -2, -4, -20, + -11, -15, -15, -9, -2, -9, -2, -3, 13, 2, + 5, 6, 7, -4, 1, -7, 6, 4, 2, 6, + 0, -0, 8, 8, -6, 5, 1, -2, -2, -12, + 2, -12, -2, 6, 7, 3, 4, 14, 14, -3, + 1, -3, 6, 0, -20, 2, -10, 10, -5, -5, + 13, 0, -3, 7, -12, -17, -13, 1, -6, 10, + -1, -9, 4, -16, 3, 2, 5, 1, -4, 9, + -0, 1, 3, 15, -4, -13, -6, 4, 3, -2, + -1, -4, -7, -7, -2, 8, -16, -4, -10, 5, + 1, -3, 2, -9, -4, 1, -1, -1, -4, -6, + -4, 1, 0, -9, 15, -1, -7, -3, -5, -0, + 3, -0, -6, -17, 16, -3, 3, -2, -3, 5, + 3, -2, 3, 13, 8, 1, -3, -8, -7, -4, + 6, -6, -15, -7, 0, 4, -3, -3, -10, 14, + 1, 3, 14, 4, -1, 14})}); +} + +TEST_F(NAIVE, DCT_INT8_MASK) { + Checker checker(handle(), + /* check_dispatch */ false); + DctChannelSelectForward::Param param; + param.format = DctChannelSelectForward::Param::Format::NCHW4; + auto src_tensor = TensorValue( + {1, 3, 8, 16}, dtype::Uint8(), + {195, 165, 82, 30, 154, 60, 175, 195, 179, 165, 132, 37, 250, + 107, 36, 80, 5, 54, 247, 218, 191, 211, 239, 76, 140, 33, + 253, 85, 132, 101, 105, 177, 46, 183, 102, 99, 19, 175, 108, + 252, 42, 238, 48, 251, 108, 90, 176, 2, 35, 46, 161, 252, + 38, 225, 195, 174, 58, 165, 198, 249, 162, 118, 198, 41, 154, + 10, 
87, 24, 201, 12, 188, 1, 93, 179, 246, 134, 18, 178, + 173, 36, 122, 89, 115, 46, 43, 205, 232, 55, 149, 30, 206, + 97, 186, 125, 35, 209, 51, 48, 222, 222, 130, 173, 63, 0, + 223, 19, 5, 162, 154, 143, 134, 63, 123, 102, 102, 212, 145, + 80, 87, 212, 42, 26, 219, 225, 120, 94, 213, 238, + + 25, 172, 141, 45, 182, 203, 50, 94, 44, 88, 74, 76, 151, + 105, 138, 87, 125, 55, 60, 211, 15, 158, 198, 37, 54, 203, + 239, 79, 56, 6, 53, 201, 97, 233, 178, 74, 193, 46, 249, + 65, 5, 208, 130, 67, 191, 168, 152, 129, 253, 195, 231, 3, + 109, 229, 254, 193, 229, 202, 108, 22, 89, 251, 13, 53, 47, + 192, 12, 81, 19, 53, 93, 104, 41, 217, 215, 184, 136, 249, + 14, 244, 4, 220, 33, 53, 142, 219, 43, 28, 68, 198, 202, + 88, 235, 7, 233, 47, 84, 127, 28, 17, 189, 135, 183, 192, + 239, 116, 31, 118, 186, 49, 251, 233, 220, 27, 97, 30, 43, + 193, 217, 48, 24, 225, 15, 3, 26, 71, 82, 104, + + 175, 125, 79, 195, 50, 236, 114, 179, 180, 177, 230, 173, 43, + 195, 123, 111, 106, 5, 91, 254, 34, 76, 52, 82, 193, 179, + 185, 71, 57, 215, 18, 5, 151, 13, 59, 206, 154, 95, 149, + 40, 229, 16, 116, 144, 249, 67, 97, 223, 208, 144, 92, 174, + 246, 77, 196, 211, 20, 123, 239, 250, 235, 65, 184, 54, 239, + 168, 135, 17, 79, 117, 171, 173, 109, 39, 57, 13, 129, 79, + 236, 117, 134, 123, 149, 113, 198, 160, 249, 242, 220, 226, 44, + 113, 164, 217, 46, 249, 182, 22, 98, 228, 49, 78, 101, 236, + 181, 5, 245, 72, 62, 182, 151, 210, 254, 190, 35, 73, 190, + 247, 50, 81, 49, 217, 86, 229, 139, 203, 57, 194}); + checker.set_param(param).exect( + Testcase{src_tensor, + TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}), + TensorValue({32}, dtype::Int32(), + {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, + 25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2, + 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}), + {}}, + Testcase{{}, + {}, + {}, + TensorValue( + {1, 8, 1, 2, 4}, dtype::QuantizedS8(10.f), + {100, -12, 7, 7, 104, 2, -2, -2, -7, -7, -3, + 8, 12, -12, -5, -1, 5, -7, -1, 7, -7, -3, + 6, 7, -0, -2, -7, 11, 6, 3, -1, 7, 94, + -5, 6, -5, 
98, 0, -3, -16, 5, 7, 13, -8, + 1, 5, -5, -8, 108, -3, -8, -7, 110, 1, -2, + 5, -0, 7, 8, -9, 14, -0, 1, -4})}); + + checker.set_param(param).exect( + Testcase{TensorValue( + {1, 3, 8, 16}, dtype::Uint8(), + {195, 165, 82, 30, 154, 60, 175, 195, 179, 165, + 132, 37, 250, 107, 36, 80, 5, 54, 247, 218, + 191, 211, 239, 76, 140, 33, 253, 85, 132, 101, + 105, 177, 46, 183, 102, 99, 19, 175, 108, 252, + 42, 238, 48, 251, 108, 90, 176, 2, 35, 46, + 161, 252, 38, 225, 195, 174, 58, 165, 198, 249, + 162, 118, 198, 41, 154, 10, 87, 24, 201, 12, + 188, 1, 93, 179, 246, 134, 18, 178, 173, 36, + 122, 89, 115, 46, 43, 205, 232, 55, 149, 30, + 206, 97, 186, 125, 35, 209, 51, 48, 222, 222, + 130, 173, 63, 0, 223, 19, 5, 162, 154, 143, + 134, 63, 123, 102, 102, 212, 145, 80, 87, 212, + 42, 26, 219, 225, 120, 94, 213, 238, + + 25, 172, 141, 45, 182, 203, 50, 94, 44, 88, + 74, 76, 151, 105, 138, 87, 125, 55, 60, 211, + 15, 158, 198, 37, 54, 203, 239, 79, 56, 6, + 53, 201, 97, 233, 178, 74, 193, 46, 249, 65, + 5, 208, 130, 67, 191, 168, 152, 129, 253, 195, + 231, 3, 109, 229, 254, 193, 229, 202, 108, 22, + 89, 251, 13, 53, 47, 192, 12, 81, 19, 53, + 93, 104, 41, 217, 215, 184, 136, 249, 14, 244, + 4, 220, 33, 53, 142, 219, 43, 28, 68, 198, + 202, 88, 235, 7, 233, 47, 84, 127, 28, 17, + 189, 135, 183, 192, 239, 116, 31, 118, 186, 49, + 251, 233, 220, 27, 97, 30, 43, 193, 217, 48, + 24, 225, 15, 3, 26, 71, 82, 104, + + 175, 125, 79, 195, 50, 236, 114, 179, 180, 177, + 230, 173, 43, 195, 123, 111, 106, 5, 91, 254, + 34, 76, 52, 82, 193, 179, 185, 71, 57, 215, + 18, 5, 151, 13, 59, 206, 154, 95, 149, 40, + 229, 16, 116, 144, 249, 67, 97, 223, 208, 144, + 92, 174, 246, 77, 196, 211, 20, 123, 239, 250, + 235, 65, 184, 54, 239, 168, 135, 17, 79, 117, + 171, 173, 109, 39, 57, 13, 129, 79, 236, 117, + 134, 123, 149, 113, 198, 160, 249, 242, 220, 226, + 44, 113, 164, 217, 46, 249, 182, 22, 98, 228, + 49, 78, 101, 236, 181, 5, 245, 72, 62, 182, + 151, 210, 254, 190, 35, 73, 190, 247, 50, 81, + 
49, 217, 86, 229, 139, 203, 57, 194}), + TensorValue({4}, dtype::Int32(), {0, 12, 20, 28}), + TensorValue({28}, dtype::Int32(), + {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, + 32, 25, 0, 1, 8, 16, 9, 2, 3, 10, + 0, 1, 8, 16, 9, 2, 3, 10}), + {}}, + Testcase{{}, + {}, + {}, + TensorValue( + {1, 7, 1, 2, 4}, dtype::QuantizedS8(10.f), + {100, -12, 7, 7, 104, 2, -2, -2, -7, -7, -3, + 8, 12, -12, -5, -1, 5, -7, -1, 7, -7, -3, + 6, 7, + + 94, -5, 6, -5, 98, 0, -3, -16, 5, 7, 13, + -8, 1, 5, -5, -8, 108, -3, -8, -7, 110, 1, + -2, 5, -0, 7, 8, -9, 14, -0, 1, -4})}); +} + +TEST_F(NAIVE, DCT_4x4) { + Checker checker(handle(), + /* check_dispatch */ false); + DctChannelSelectForward::Param param; + param.dct_block_size = 4; + checker.set_param(param).exect( + Testcase{TensorValue( + {1, 1, 8, 8}, dtype::Uint8(), + {186, 120, 112, 220, 69, 80, 201, 127, 246, 254, + 175, 50, 240, 251, 76, 37, 34, 166, 250, 195, + 231, 139, 128, 233, 75, 80, 3, 2, 19, 140, + 193, 203, 115, 107, 250, 209, 14, 243, 199, 60, + 234, 107, 174, 156, 81, 87, 13, 116, 96, 140, + 197, 253, 113, 223, 229, 159, 249, 252, 89, 84, + 45, 16, 41, 72}), + {}, + {}, + {}}, + Testcase{{}, + {}, + {}, + TensorValue( + {1, 16, 2, 2}, dtype::Float32(), + {5.42000000e+02, 5.91750000e+02, 6.78000000e+02, + 4.27750000e+02, 3.49953423e+01, -1.17686939e+01, + -1.66842098e+01, -3.85316620e+01, -3.80000000e+01, + -1.22500000e+01, 2.00000000e+01, -9.77500000e+01, + -1.61191311e+01, -9.46695328e+00, 3.28882408e+01, + -4.92537880e+01, 1.66958221e+02, -4.26609573e+01, + 2.56999969e-01, 5.39384537e+01, 1.71819706e+01, + 9.00009003e+01, -1.23818558e+02, 1.18912420e+01, + 6.61014938e+01, -2.49261990e+01, 4.95798302e+00, + -1.02324417e+02, 7.85859919e+00, 3.73140755e+01, + 1.03783745e+02, -4.61430321e+01, -1.43000000e+02, + -7.57500000e+01, -5.00000000e-01, -8.27500000e+01, + 1.34834738e+01, -1.93409515e+02, 6.84791718e+01, + -4.01652241e+00, 1.22000000e+02, -8.57500000e+01, + -4.05000000e+01, -5.62500000e+01, -2.88564739e+01, + 
5.76532059e+01, -2.67414131e+01, 1.70877876e+01, + 3.85416756e+01, 3.09300461e+01, 5.84670639e+00, + 1.85747864e+02, -2.05141403e+02, -9.91859360e+01, + -1.66716263e+02, -1.71430378e+01, 6.71520996e+00, + 8.41980438e+01, -3.50666313e+01, -1.48387482e+02, + 1.08180256e+01, 5.49991112e+01, -1.06814528e+01, + 1.86087704e+01})}); + + checker.set_param(param).exect( + Testcase{TensorValue( + {1, 1, 8, 8}, dtype::Uint8(), + {186, 120, 112, 220, 69, 80, 201, 127, 246, 254, + 175, 50, 240, 251, 76, 37, 34, 166, 250, 195, + 231, 139, 128, 233, 75, 80, 3, 2, 19, 140, + 193, 203, 115, 107, 250, 209, 14, 243, 199, 60, + 234, 107, 174, 156, 81, 87, 13, 116, 96, 140, + 197, 253, 113, 223, 229, 159, 249, 252, 89, 84, + 45, 16, 41, 72}), + TensorValue({2}, dtype::Int32(), {0, 6}), + TensorValue({6}, dtype::Int32(), {0, 1, 8, 4, 2, 3}), + {}}, + Testcase{ + {}, + {}, + {}, + TensorValue( + {1, 6, 2, 2}, dtype::Float32(), + {5.4200000e+02, 5.9175000e+02, 6.7800000e+02, + 4.2775000e+02, 3.4995342e+01, -1.1768694e+01, + -1.6684210e+01, -3.8531662e+01, -1.4300000e+02, + -7.5750000e+01, -5.0000000e-01, -8.2750000e+01, + 1.6695822e+02, -4.2660957e+01, 2.5699997e-01, + 5.3938454e+01, -3.8000000e+01, -1.2250000e+01, + 2.0000000e+01, -9.7750000e+01, -1.6119131e+01, + -9.4669533e+00, 3.2888241e+01, -4.9253788e+01})}); +} + +TEST_F(NAIVE, DCT_WITH_MASK) { + Checker checker(handle(), + /* check_dispatch */ false); + DctChannelSelectForward::Param param; + checker.set_param(param).exect( + Testcase{TensorValue( + {1, 3, 8, 16}, dtype::Uint8(), + {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, + 18, 16, 93, 185, 99, 102, 205, 172, 191, 29, + 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, + 196, 83, 3, 211, 32, 181, 49, 111, 114, 83, + 148, 232, 77, 17, 35, 2, 154, 100, 41, 135, + 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, + 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, + 49, 145, 87, 210, 97, 190, 179, 93, 125, 105, + 181, 207, 148, 178, 133, 53, 25, 198, 238, 151, + 14, 120, 213, 195, 145, 20, 122, 
107, 217, 185, + 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, + 44, 125, 50, 38, 41, 106, 30, 5, 151, 243, + 238, 181, 232, 191, 161, 57, 23, 204, + + 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, + 18, 16, 93, 185, 99, 102, 205, 172, 191, 29, + 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, + 196, 83, 3, 211, 32, 181, 49, 111, 114, 83, + 148, 232, 77, 17, 35, 2, 154, 100, 41, 135, + 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, + 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, + 49, 145, 87, 210, 97, 190, 179, 93, 125, 105, + 181, 207, 148, 178, 133, 53, 25, 198, 238, 151, + 14, 120, 213, 195, 145, 20, 122, 107, 217, 185, + 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, + 44, 125, 50, 38, 41, 106, 30, 5, 151, 243, + 238, 181, 232, 191, 161, 57, 23, 204, + + 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, + 18, 16, 93, 185, 99, 102, 205, 172, 191, 29, + 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, + 196, 83, 3, 211, 32, 181, 49, 111, 114, 83, + 148, 232, 77, 17, 35, 2, 154, 100, 41, 135, + 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, + 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, + 49, 145, 87, 210, 97, 190, 179, 93, 125, 105, + 181, 207, 148, 178, 133, 53, 25, 198, 238, 151, + 14, 120, 213, 195, 145, 20, 122, 107, 217, 185, + 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, + 44, 125, 50, 38, 41, 106, 30, 5, 151, 243, + 238, 181, 232, 191, 161, 57, 23, 204}), + TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}), + TensorValue({32}, dtype::Int32(), + {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, + 25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2, + 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}), + {}}, + Testcase{{}, + {}, + {}, + TensorValue({1, 32, 1, 2}, dtype::Float32(), + {890.12494, 941.25, -7.0498576, + 99.47632, -22.850792, -97.862236, + -101.043236, -4.727012, 28.275675, + -157.96654, 42.1377, 45.06531, + -149.77373, 24.487143, -8.054966, + -13.990831, -6.9395194, -3.9211385, + 64.79172, -12.363858, -47.875, + 59., 56.271786, -62.725567, + 120.522675, 16.559765, 85.74334, + 112.904495, 99.375, 29.499973, + 2.0220923, 
-19.681704, 890.12494, + 941.25, -7.0498576, 99.47632, + -22.850792, -97.862236, -101.043236, + -4.727012, 28.275675, -157.96654, + 42.1377, 45.06531, -149.77373, + 24.487143, -8.054966, -13.990831, + 890.12494, 941.25, -7.0498576, + 99.47632, -22.850792, -97.862236, + -101.043236, -4.727012, 28.275675, + -157.96654, 42.1377, 45.06531, + -149.77373, 24.487143, -8.054966, + -13.990831})}); + checker.set_param(param).exect( + Testcase{TensorValue( + {1, 3, 8, 16}, dtype::Uint8(), + {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, + 18, 16, 93, 185, 99, 102, 205, 172, 191, 29, + 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, + 196, 83, 3, 211, 32, 181, 49, 111, 114, 83, + 148, 232, 77, 17, 35, 2, 154, 100, 41, 135, + 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, + 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, + 49, 145, 87, 210, 97, 190, 179, 93, 125, 105, + 181, 207, 148, 178, 133, 53, 25, 198, 238, 151, + 14, 120, 213, 195, 145, 20, 122, 107, 217, 185, + 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, + 44, 125, 50, 38, 41, 106, 30, 5, 151, 243, + 238, 181, 232, 191, 161, 57, 23, 204, + + 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, + 18, 16, 93, 185, 99, 102, 205, 172, 191, 29, + 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, + 196, 83, 3, 211, 32, 181, 49, 111, 114, 83, + 148, 232, 77, 17, 35, 2, 154, 100, 41, 135, + 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, + 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, + 49, 145, 87, 210, 97, 190, 179, 93, 125, 105, + 181, 207, 148, 178, 133, 53, 25, 198, 238, 151, + 14, 120, 213, 195, 145, 20, 122, 107, 217, 185, + 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, + 44, 125, 50, 38, 41, 106, 30, 5, 151, 243, + 238, 181, 232, 191, 161, 57, 23, 204, + + 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, + 18, 16, 93, 185, 99, 102, 205, 172, 191, 29, + 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, + 196, 83, 3, 211, 32, 181, 49, 111, 114, 83, + 148, 232, 77, 17, 35, 2, 154, 100, 41, 135, + 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, + 78, 65, 184, 69, 91, 82, 2, 172, 
194, 240, + 49, 145, 87, 210, 97, 190, 179, 93, 125, 105, + 181, 207, 148, 178, 133, 53, 25, 198, 238, 151, + 14, 120, 213, 195, 145, 20, 122, 107, 217, 185, + 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, + 44, 125, 50, 38, 41, 106, 30, 5, 151, 243, + 238, 181, 232, 191, 161, 57, 23, 204}), + TensorValue({4}, dtype::Int32(), {0, 8, 16, 24}), + TensorValue({24}, dtype::Int32(), + {17, 24, 32, 25, 18, 11, 4, 5, 0, 1, 8, 16, + 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}), + {}}, + Testcase{{}, + {}, + {}, + TensorValue({1, 24, 1, 2}, dtype::Float32(), + {-6.9395194, -3.9211385, 64.79172, + -12.363858, -47.875, 59., + 56.271786, -62.725567, 120.522675, + 16.559765, 85.74334, 112.904495, + 99.375, 29.499973, 2.0220923, + -19.681704, 890.12494, 941.25, + -7.0498576, 99.47632, -22.850792, + -97.862236, -101.043236, -4.727012, + 28.275675, -157.96654, 42.1377, + 45.06531, -149.77373, 24.487143, + -8.054966, -13.990831, 890.12494, + 941.25, -7.0498576, 99.47632, + -22.850792, -97.862236, -101.043236, + -4.727012, 28.275675, -157.96654, + 42.1377, 45.06531, -149.77373, + 24.487143, -8.054966, -13.990831})}); +} + +TEST_F(NAIVE, DCT_WITH_FIX_32_MASK) { + Checker checker(handle(), + /* check_dispatch */ false); + using Param = DctChannelSelectForward::Param; + Param param; + param.fastImpl = Param::FastImpl::FIX_32_MASK; + checker.set_param(param).exect( + Testcase{TensorValue( + {1, 3, 8, 16}, dtype::Uint8(), + {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, + 18, 16, 93, 185, 99, 102, 205, 172, 191, 29, + 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, + 196, 83, 3, 211, 32, 181, 49, 111, 114, 83, + 148, 232, 77, 17, 35, 2, 154, 100, 41, 135, + 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, + 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, + 49, 145, 87, 210, 97, 190, 179, 93, 125, 105, + 181, 207, 148, 178, 133, 53, 25, 198, 238, 151, + 14, 120, 213, 195, 145, 20, 122, 107, 217, 185, + 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, + 44, 125, 50, 38, 41, 106, 30, 5, 151, 243, + 238, 181, 232, 191, 161, 
57, 23, 204, + + 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, + 18, 16, 93, 185, 99, 102, 205, 172, 191, 29, + 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, + 196, 83, 3, 211, 32, 181, 49, 111, 114, 83, + 148, 232, 77, 17, 35, 2, 154, 100, 41, 135, + 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, + 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, + 49, 145, 87, 210, 97, 190, 179, 93, 125, 105, + 181, 207, 148, 178, 133, 53, 25, 198, 238, 151, + 14, 120, 213, 195, 145, 20, 122, 107, 217, 185, + 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, + 44, 125, 50, 38, 41, 106, 30, 5, 151, 243, + 238, 181, 232, 191, 161, 57, 23, 204, + + 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, + 18, 16, 93, 185, 99, 102, 205, 172, 191, 29, + 185, 6, 47, 84, 0, 47, 105, 203, 251, 73, + 196, 83, 3, 211, 32, 181, 49, 111, 114, 83, + 148, 232, 77, 17, 35, 2, 154, 100, 41, 135, + 141, 206, 56, 91, 137, 199, 104, 192, 75, 122, + 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, + 49, 145, 87, 210, 97, 190, 179, 93, 125, 105, + 181, 207, 148, 178, 133, 53, 25, 198, 238, 151, + 14, 120, 213, 195, 145, 20, 122, 107, 217, 185, + 65, 5, 115, 110, 82, 206, 163, 86, 2, 2, + 44, 125, 50, 38, 41, 106, 30, 5, 151, 243, + 238, 181, 232, 191, 161, 57, 23, 204}), + TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}), + TensorValue({32}, dtype::Int32(), + {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, + 25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2, + 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}), + {}}, + Testcase{{}, + {}, + {}, + TensorValue({1, 32, 1, 2}, dtype::Float32(), + {890.12494, 941.25, -7.0498576, + 99.47632, -22.850792, -97.862236, + -101.043236, -4.727012, 28.275675, + -157.96654, 42.1377, 45.06531, + -149.77373, 24.487143, -8.054966, + -13.990831, -6.9395194, -3.9211385, + 64.79172, -12.363858, -47.875, + 59., 56.271786, -62.725567, + 120.522675, 16.559765, 85.74334, + 112.904495, 99.375, 29.499973, + 2.0220923, -19.681704, 890.12494, + 941.25, -7.0498576, 99.47632, + -22.850792, -97.862236, -101.043236, + -4.727012, 28.275675, -157.96654, + 
42.1377, 45.06531, -149.77373,
+                                  24.487143, -8.054966, -13.990831,
+                                  890.12494, 941.25, -7.0498576,
+                                  99.47632, -22.850792, -97.862236,
+                                  -101.043236, -4.727012, 28.275675,
+                                  -157.96654, 42.1377, 45.06531,
+                                  -149.77373, 24.487143, -8.054966,
+                                  -13.990831})});
+}
+
+TEST_F(NAIVE, DCT_WITH_MASK2) {  // shape sweep with a random mask width
+    Checker<DctChannelSelectForward> checker(handle(), false);
+    DctChannelSelectForward::Param param;
+    UniformIntRNG rng_oc(0, 3 * 64);
+    for (size_t n : {1, 3}) {
+        for (size_t ic : {1, 3}) {
+            for (size_t ih : {8, 16, 32, 512, 1024}) {
+                for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
+                    int random_oc = static_cast<int>(rng_oc.gen_single_val());
+                    int max_oc = ic * 64;
+                    int mask_oc = (random_oc % max_oc) + 1;  // 1..max_oc
+                    auto test_case =
+                            gen_dct_case(n, ic, ih, iw, mask_oc, param);
+                    checker.set_param(param).exect(test_case->testcase_in,
+                                                   test_case->testcase_out);
+                }
+            }
+        }
+    }
+}
+
+} // namespace test
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
\ No newline at end of file