提交 ba66e1d0 编写于 作者: M Megvii Engine Team

feat(dnn): add nchw_fp32 nchw44_qint8 cuda dct

GitOrigin-RevId: 581e31fc209008418f9821f32f7c71db76f84ddf
上级 b3278229
...@@ -182,6 +182,48 @@ class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { ...@@ -182,6 +182,48 @@ class WarpPerspectiveBackwardMat: public WarpPerspectiveBase {
size_t workspace_in_bytes); size_t workspace_in_bytes);
}; };
class DctChannelSelectForward : public OperatorBase {
DEF_OPR_PARAM(DctChannelSelect);
DEF_OPR_IMPL(DctChannelSelectForward, OperatorBase, 3, 1);
public:
/**
* \param[in] DctChannelSelectForward input, must be uint8 nchw tensor
* \param[in] mask_offset input, must be int32 nchw tensor
* \param[in] mask_val input, must be int32 nchw tensor
* \param[dst] DctChannelSelectForward output, default fp32 nchw tensor
* \param[out] workspace temporary workspace to perform forward
*/
virtual void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
const TensorLayout& dst) = 0;
protected:
void check_layout_fwd(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
const TensorLayout& dst);
void deduce_layout_fwd(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
TensorLayout& dst);
std::string param_msg() const;
};
} // namespace megdnn } // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h" #include "megdnn/internal/opr_header_epilogue.h"
......
...@@ -411,6 +411,9 @@ pdef('ElemwiseMultiType').add_enum( ...@@ -411,6 +411,9 @@ pdef('ElemwiseMultiType').add_enum(
pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
(pdef('DctChannelSelect', '2d discrete cosine transform').add_enum_alias('Format', 'ConvolutionV0').
add_enum('FastImpl', 'NONE', 'FIX_32_MASK').add_fields('int32', 'dct_block_size', 8))
(pdef('MatrixMul', version=0, is_legacy=True). (pdef('MatrixMul', version=0, is_legacy=True).
add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). add_fields('bool', 'transposeA', 'false', 'transposeB', 'false').
add_enum('DataType', add_enum('DataType',
......
/**
* \file dnn/src/common/dct.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void DctChannelSelectForward::deduce_layout_fwd(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
TensorLayout& dst) {
const size_t dct_block = param().dct_block_size;
const size_t in = src.shape[0];
const size_t ic = src.shape[1];
const size_t ih = src.shape[2];
const size_t iw = src.shape[3];
check_layout_fwd(src, mask_offset, mask_val, dst);
const size_t oh = ih / dct_block;
const size_t ow = iw / dct_block;
//! mask will be empty or (ic + 1) elements
size_t oc = mask_offset.ndim > 0 && mask_offset[0] >= 2
? mask_val.shape[0]
: ic * dct_block * dct_block;
if (param().fastImpl == Param::FastImpl::FIX_32_MASK) {
megdnn_assert(oc == 32,
"Param::FastImpl::FIX_32_MASK oc must be 32, but %zu",
oc);
}
if (param().format == Param::Format::NCHW) {
dst = TensorLayout(TensorShape({in, oc, oh, ow}), dst.dtype);
} else {
megdnn_assert(param().format == Param::Format::NCHW4,
"dct format must be nchw or nchw4");
megdnn_assert(oc % 4 == 0, "oc mod 4 == 0 in nchw4");
dst = TensorLayout(TensorShape({in, oc / 4, oh, ow, 4}), dst.dtype);
}
}
void DctChannelSelectForward::deduce_layout(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
TensorLayout& dst) {
deduce_layout_fwd(src, mask_offset, mask_val, dst);
}
void DctChannelSelectForward::check_layout_fwd(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
const TensorLayout& dst) {
const size_t dct_block = param().dct_block_size;
const size_t ih = src.shape[2];
const size_t iw = src.shape[3];
megdnn_assert(mask_offset.ndim == 0 || (mask_offset.ndim == 1 &&
(mask_offset.shape[0] == 0 ||
mask_offset.shape[0] >= 2) &&
mask_val.ndim == 1),
"mask only support one valid dim");
megdnn_assert(mask_val.ndim <= 1, "only support one dim");
megdnn_assert(src.dtype.enumv() == DTypeEnum::Uint8,
"src.dtype == dtype::Uint8");
megdnn_assert(dst.dtype.enumv() == DTypeEnum::Float32 ||
dst.dtype.enumv() == DTypeEnum::QuantizedS8,
"dst.dtype == dtype::Float32 || dst.dtype.enumv() == "
"DTypeEnum::QuantizedS8");
megdnn_assert(ih % dct_block == 0, "ih mod dctblock == 0");
megdnn_assert(iw % dct_block == 0, "iw mod dctblock == 0");
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -201,6 +201,7 @@ private: ...@@ -201,6 +201,7 @@ private:
cb(RemapBackwardMat) \ cb(RemapBackwardMat) \
cb(AdaptivePoolingForward) \ cb(AdaptivePoolingForward) \
cb(AdaptivePoolingBackward) \ cb(AdaptivePoolingBackward) \
cb(DctChannelSelectForward)
/*! /*!
* \brief specialize HandleImpl::create_operator for a single opr type; * \brief specialize HandleImpl::create_operator for a single opr type;
......
/**
* \file dnn/src/cuda/dct/dct_channel_select.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megcore_cdefs.h"
#include "src/cuda/dct/dct_channel_select.cuh"
#include "src/cuda/error_info.cuh"
namespace megdnn {
namespace cuda {
template <typename T>
struct CudaPostProcess;
template <>
struct CudaPostProcess<float> {
CudaPostProcess(float){};
static inline __device__ float func(float val) { return val; }
};
template <>
struct CudaPostProcess<int8_t> {
CudaDTypeParamImpl<dt_qint8> m_type_cvt;
CudaPostProcess(float scale) { m_type_cvt.inv_scale = 1.f / scale; };
inline __device__ int8_t func(float val) {
return m_type_cvt.quantize(val).as_int8();
}
};
template <uint32_t format>
struct ChannelBlockHelper;
template <>
struct ChannelBlockHelper<dct::DctLayoutFormat::NCHW4> {
static constexpr int channel_block = 4;
};
template <>
struct ChannelBlockHelper<dct::DctLayoutFormat::NCHW> {
static constexpr int channel_block = 1;
};
namespace dct {
namespace {
inline __device__ void load_row(float (&row_cache)[8], const uint8_t* src) {
int2 row = *((int2*)src);
row_cache[0] = (float)(((uchar4*)&(row.x))->x);
row_cache[1] = (float)(((uchar4*)&(row.x))->y);
row_cache[2] = (float)(((uchar4*)&(row.x))->z);
row_cache[3] = (float)(((uchar4*)&(row.x))->w);
row_cache[4] = (float)(((uchar4*)&(row.y))->x);
row_cache[5] = (float)(((uchar4*)&(row.y))->y);
row_cache[6] = (float)(((uchar4*)&(row.y))->z);
row_cache[7] = (float)(((uchar4*)&(row.y))->w);
}
inline __device__ void fast_dct_1d_internel(float& src0, float& src1,
float& src2, float& src3,
float& src4, float& src5,
float& src6, float& src7) {
constexpr float rsqrt_8 = 0.3535533905932737f; //!< rsqrt_8 = sqrt(1 / 8)
constexpr float a = 1.387039845322148f; //!< a = sqrt2 * cos(pi * 1 / 16)
constexpr float b = 1.306562964876377f; //!< b = sqrt2 * cos(pi * 2 / 16)
constexpr float c = 1.175875602419359f; //!< c = sqrt2 * cos(pi * 3 / 16)
constexpr float d = 0.785694958387102f; //!< d = sqrt2 * cos(pi * 5 / 16)
constexpr float e = 0.541196100146197f; //!< e = sqrt2 * cos(pi * 6 / 16)
constexpr float f = 0.275899379282943f; //!< f = sqrt2 * cos(pi * 7 / 16)
const float add_0_7 = src0 + src7;
const float add_1_6 = src1 + src6;
const float add_2_5 = src2 + src5;
const float add_3_4 = src3 + src4;
const float sub_0_7 = src0 - src7;
const float sub_6_1 = src6 - src1;
const float sub_2_5 = src2 - src5;
const float sub_4_3 = src4 - src3;
const float add_0_7_3_4 = add_0_7 + add_3_4;
const float add_1_6_2_5 = add_1_6 + add_2_5;
const float add_0_7_sub_3_4 = add_0_7 - add_3_4;
const float add_1_6_sub_2_5 = add_1_6 - add_2_5;
src0 = rsqrt_8 * (add_0_7_3_4 + add_1_6_2_5);
src2 = rsqrt_8 * (b * add_0_7_sub_3_4 + e * add_1_6_sub_2_5);
src4 = rsqrt_8 * (add_0_7_3_4 - add_1_6_2_5);
src6 = rsqrt_8 * (e * add_0_7_sub_3_4 - b * add_1_6_sub_2_5);
src1 = rsqrt_8 * (a * sub_0_7 - c * sub_6_1 + d * sub_2_5 - f * sub_4_3);
src3 = rsqrt_8 * (c * sub_0_7 + f * sub_6_1 - a * sub_2_5 + d * sub_4_3);
src5 = rsqrt_8 * (d * sub_0_7 + a * sub_6_1 + f * sub_2_5 - c * sub_4_3);
src7 = rsqrt_8 * (f * sub_0_7 + d * sub_6_1 + c * sub_2_5 + a * sub_4_3);
}
inline __device__ void fast_dct_1d(float (&src)[8]) {
fast_dct_1d_internel(src[0], src[1], src[2], src[3], src[4], src[5], src[6],
src[7]);
}
inline __device__ void fast_dct_1d_col(float (&src)[8][8], const int col) {
fast_dct_1d_internel(src[0][col], src[1][col], src[2][col], src[3][col],
src[4][col], src[5][col], src[6][col], src[7][col]);
}
enum class MaskType {
NO_MASK = 0,
USER_DEFINE_MASK = 1,
FIX_32_MASK = 2,
MASK_END
};
template <const int dct_block, const int block_oh, const int block_ow,
uint32_t format, MaskType mask_type, typename DstDtype, typename T2>
struct StoreMask;
template <const int dct_block, const int block_oh, const int block_ow,
typename T2>
struct StoreMask<dct_block, block_oh, block_ow, DctLayoutFormat::NCHW,
MaskType::USER_DEFINE_MASK, float, T2> {
static inline __device__ void func(
const float (&thread_cache)[dct_block][dct_block], float* dst_tid,
const int oc_stride, int channel_idx, const int* mask_offset,
const int* mask_val, CudaPostProcess<T2>& quant_param,
megcore::AsyncErrorInfo* error_info, void* error_tracker) {
__shared__ float shared[dct_block][dct_block][block_oh][block_ow];
#pragma unroll
for (int i = 0; i < dct_block; ++i)
#pragma unroll
for (int j = 0; j < dct_block; ++j) {
shared[i][j][threadIdx.y][threadIdx.x] = thread_cache[i][j];
}
const int store_channel_offset = mask_offset[channel_idx];
const int nr_store_channel =
mask_offset[channel_idx + 1] - store_channel_offset;
if (nr_store_channel < 0) {
set_async_error_info(error_info, error_tracker,
"nchw sub mask len must > 0");
}
for (int store_channel_idx = 0; store_channel_idx < nr_store_channel;
++store_channel_idx) {
const int index =
mask_val[store_channel_offset + store_channel_idx];
dst_tid[store_channel_idx * oc_stride] =
shared[index / dct_block][index % dct_block][threadIdx.y]
[threadIdx.x];
}
}
};
template <const int dct_block, const int block_oh, const int block_ow,
typename T2>
struct StoreMask<dct_block, block_oh, block_ow, DctLayoutFormat::NCHW4,
MaskType::USER_DEFINE_MASK, int8_t, T2> {
static inline __device__ void func(
const float (&thread_cache)[dct_block][dct_block], int8_t* dst_tid,
const int oc_stride, int channel_idx, const int* mask_offset,
const int* mask_val, CudaPostProcess<T2>& quant_param,
megcore::AsyncErrorInfo* error_info, void* error_tracker) {
//! nchw4 channel_block is 4
constexpr int channel_block =
ChannelBlockHelper<DctLayoutFormat::NCHW4>::channel_block;
__shared__ float shared[dct_block][dct_block][block_oh][block_ow];
#pragma unroll
for (int i = 0; i < dct_block; ++i)
#pragma unroll
for (int j = 0; j < dct_block; ++j) {
shared[i][j][threadIdx.y][threadIdx.x] = thread_cache[i][j];
}
const int store_channel_offset = mask_offset[channel_idx];
const int nr_store_channel =
mask_offset[channel_idx + 1] - store_channel_offset;
if (nr_store_channel % 4 != 0 || nr_store_channel < 0) {
set_async_error_info(error_info, error_tracker,
"nchw4 sub_mask_len mod 4 should be 0 and "
"sub_mask_len must > 0");
}
for (int store_channel_idx = 0; store_channel_idx < nr_store_channel;
store_channel_idx += channel_block) {
const int index0 =
mask_val[store_channel_offset + store_channel_idx];
const int index1 =
mask_val[store_channel_offset + store_channel_idx + 1];
const int index2 =
mask_val[store_channel_offset + store_channel_idx + 2];
const int index3 =
mask_val[store_channel_offset + store_channel_idx + 3];
const int store_c4_idx = store_channel_idx / channel_block;
*(char4*)(&dst_tid[store_c4_idx * channel_block * oc_stride]) = {
quant_param.func(
shared[index0 / dct_block][index0 % dct_block]
[threadIdx.y][threadIdx.x]),
quant_param.func(
shared[index1 / dct_block][index1 % dct_block]
[threadIdx.y][threadIdx.x]),
quant_param.func(
shared[index2 / dct_block][index2 % dct_block]
[threadIdx.y][threadIdx.x]),
quant_param.func(
shared[index3 / dct_block][index3 % dct_block]
[threadIdx.y][threadIdx.x])};
}
}
};
template <const int dct_block, const int block_oh, const int block_ow,
uint32_t format, typename DstDtype, typename T2>
struct StoreMask<dct_block, block_oh, block_ow, format, MaskType::NO_MASK,
DstDtype, T2> {
static inline __device__ void func(
const float (&thread_cache)[dct_block][dct_block],
DstDtype* dst_tid, const int oc_stride, int channel_idx,
const int* mask_offset, const int* mask_val,
CudaPostProcess<T2>& quant_param,
megcore::AsyncErrorInfo* error_info, void* error_tracker) {
constexpr int channel_block = ChannelBlockHelper<format>::channel_block;
#pragma unroll
for (int i = 0; i < dct_block; i++) {
#pragma unroll
for (int j = 0; j < dct_block; j++) {
dst_tid[(i * dct_block + j) / channel_block * channel_block *
oc_stride +
(i * dct_block + j) % channel_block] =
quant_param.func(thread_cache[i][j]);
}
}
}
};
template <const int dct_block, const int block_oh, const int block_ow,
typename T2>
struct StoreMask<dct_block, block_oh, block_ow, DctLayoutFormat::NCHW,
MaskType::FIX_32_MASK, float, T2> {
static inline __device__ void func(
const float (&thread_cache)[dct_block][dct_block], float* dst_tid,
const int oc_stride, int channel_idx, const int* mask_offset,
const int* mask_val, CudaPostProcess<T2>& quant_param,
megcore::AsyncErrorInfo* error_info, void* error_tracker) {
#define STORE(store_index, index) \
dst_tid[store_index * oc_stride] = \
thread_cache[index / dct_block][index % dct_block]
STORE(0, 0);
STORE(1, 1);
STORE(2, 8);
STORE(3, 16);
STORE(4, 9);
STORE(5, 2);
STORE(6, 3);
STORE(7, 10);
if (channel_idx == 0) {
STORE(8, 17);
STORE(9, 24);
STORE(10, 32);
STORE(11, 25);
STORE(12, 18);
STORE(13, 11);
STORE(14, 4);
STORE(15, 5);
}
#undef STORE
}
};
template <const int dct_block, const int block_oh, const int block_ow,
typename T2>
struct StoreMask<dct_block, block_oh, block_ow, DctLayoutFormat::NCHW4,
MaskType::FIX_32_MASK, int8_t, T2> {
static inline __device__ void func(
const float (&thread_cache)[dct_block][dct_block], int8_t* dst_tid,
const int oc_stride, int channel_idx, const int* mask_offset,
const int* mask_val, CudaPostProcess<T2>& quant_param,
megcore::AsyncErrorInfo* error_info, void* error_tracker) {
#define STORE(store_index, index0, index1, index2, index3) \
*(char4*)(&dst_tid[store_index * oc_stride]) = { \
quant_param.func( \
thread_cache[index0 / dct_block][index0 % dct_block]), \
quant_param.func( \
thread_cache[index1 / dct_block][index1 % dct_block]), \
quant_param.func( \
thread_cache[index2 / dct_block][index2 % dct_block]), \
quant_param.func( \
thread_cache[index3 / dct_block][index3 % dct_block])}
STORE(0, 0, 1, 8, 16);
STORE(4, 9, 2, 3, 10);
if (channel_idx == 0) {
STORE(8, 17, 24, 32, 25);
STORE(12, 18, 11, 4, 5);
}
#undef STORE
}
};
template <const int dct_block, MaskType mask_type, const int ker_block_h,
const int ker_block_w, uint32_t format, typename DstDtype,
typename T2>
__global__ void kern_dct(const uint8_t* src, DstDtype* dst, const int n,
const int c, const int h, const int w, const int oh,
const int ow, const int oc_stride, const int oc,
const int* mask_offset, const int* mask_val,
CudaPostProcess<T2> quant_param,
megcore::AsyncErrorInfo* error_info,
void* error_tracker) {
constexpr int block_oh = ker_block_h / dct_block;
constexpr int block_ow = ker_block_w / dct_block;
const int channel_stride = h * w;
const int oc_idx = blockIdx.z % c;
const int oh_idx = blockIdx.y * block_oh + threadIdx.y;
const int ow_idx = blockIdx.x * block_ow + threadIdx.x;
float thread_cache[dct_block][dct_block];
const uint8_t* src_tid =
src + blockIdx.z * channel_stride +
(blockIdx.y * ker_block_h + threadIdx.y * dct_block) * w +
(blockIdx.x * ker_block_w + threadIdx.x * dct_block);
const int inner_channel_offset =
(oh_idx * ow + ow_idx) * ChannelBlockHelper<format>::channel_block;
DstDtype* dst_tid =
dst + blockIdx.z * channel_stride + inner_channel_offset;
if (mask_type != MaskType::NO_MASK) {
const int batch_idx = blockIdx.z / c;
const int batch_stride = oc_stride * oc;
int out_channel_offset = 0;
if (mask_type == MaskType::FIX_32_MASK) {
//! trick out_channel_offset = {0, 16, 24}[oc_idx]; oc_idx = 0, 1, 2
out_channel_offset = 16 * oc_idx - 8 * (oc_idx >> 1);
} else {
out_channel_offset = mask_offset[oc_idx];
}
dst_tid = dst + batch_idx * batch_stride +
out_channel_offset * oc_stride + inner_channel_offset;
}
if (oh_idx < oh && ow_idx < ow) {
load_row(thread_cache[0], src_tid + 0 * w);
load_row(thread_cache[1], src_tid + 1 * w);
load_row(thread_cache[2], src_tid + 2 * w);
load_row(thread_cache[3], src_tid + 3 * w);
load_row(thread_cache[4], src_tid + 4 * w);
load_row(thread_cache[5], src_tid + 5 * w);
load_row(thread_cache[6], src_tid + 6 * w);
load_row(thread_cache[7], src_tid + 7 * w);
//! TMP = A @ C.T
fast_dct_1d(thread_cache[0]);
fast_dct_1d(thread_cache[1]);
fast_dct_1d(thread_cache[2]);
fast_dct_1d(thread_cache[3]);
fast_dct_1d(thread_cache[4]);
fast_dct_1d(thread_cache[5]);
fast_dct_1d(thread_cache[6]);
fast_dct_1d(thread_cache[7]);
//! TMP = C @ TMP
fast_dct_1d_col(thread_cache, 0);
fast_dct_1d_col(thread_cache, 1);
fast_dct_1d_col(thread_cache, 2);
fast_dct_1d_col(thread_cache, 3);
fast_dct_1d_col(thread_cache, 4);
fast_dct_1d_col(thread_cache, 5);
fast_dct_1d_col(thread_cache, 6);
fast_dct_1d_col(thread_cache, 7);
StoreMask<dct_block, block_oh, block_ow, format, mask_type, DstDtype,
T2>::func(thread_cache, dst_tid, oc_stride, oc_idx,
mask_offset, mask_val, quant_param, error_info,
error_tracker);
}
}
} // namespace
template <int dct_block, uint32_t format, typename DstDtype>
void call_kern_dct(const uint8_t* d_src, DstDtype* d_dst, const int n,
const int c, const int h, const int w, const int oc,
bool fix_32_mask, const int* mask_offset,
const int* mask_val, cudaStream_t stream,
megcore::AsyncErrorInfo* error_info, void* error_tracker,
float scale) {
constexpr int ker_block_h = 32;
constexpr int ker_block_w = 256;
const int oh = h / dct_block;
const int ow = w / dct_block;
const int oc_stride = oh * ow;
const dim3 block_dim(DIVUP(w, ker_block_w), DIVUP(h, ker_block_h), n * c);
const dim3 thread_dim(DIVUP(ker_block_w, dct_block),
DIVUP(ker_block_h, dct_block));
auto cuda_dtype_param = CudaPostProcess<DstDtype>(scale);
if (fix_32_mask) {
kern_dct<dct_block, MaskType::FIX_32_MASK, ker_block_h, ker_block_w,
format><<<block_dim, thread_dim, 0, stream>>>(
d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset,
mask_val, cuda_dtype_param, error_info, error_tracker);
} else if (mask_offset && mask_val) {
kern_dct<dct_block, MaskType::USER_DEFINE_MASK, ker_block_h,
ker_block_w, format><<<block_dim, thread_dim, 0, stream>>>(
d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset,
mask_val, cuda_dtype_param, error_info, error_tracker);
} else {
kern_dct<dct_block, MaskType::NO_MASK, ker_block_h, ker_block_w, format>
<<<block_dim, thread_dim, 0, stream>>>(
d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc,
mask_offset, mask_val, cuda_dtype_param, error_info,
error_tracker);
}
}
template void call_kern_dct<8, DctLayoutFormat::NCHW, float>(
const uint8_t* d_src, float* d_dst, const int n, const int c,
const int h, const int w, const int oc, bool fix_32_mask,
const int* mask_offset, const int* mask_val, cudaStream_t stream,
megcore::AsyncErrorInfo* error_info, void* error_tracker, float scale);
template void call_kern_dct<8, DctLayoutFormat::NCHW4, int8_t>(
const uint8_t* d_src, int8_t* d_dst, const int n, const int c,
const int h, const int w, const int oc, bool fix_32_mask,
const int* mask_offset, const int* mask_val, cudaStream_t stream,
megcore::AsyncErrorInfo* error_info, void* error_tracker, float scale);
} // namespace dct
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file dnn/src/cuda/dct/dct_channel_select.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the
"License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
or
* implied.
*/
#pragma once
#include <stdint.h>
#include <cstdio>
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace dct {
using DctLayoutFormat = megdnn::param_enumv::DctChannelSelect::Format;
template <int dct_block, uint32_t format, typename DstDtype>
void call_kern_dct(const uint8_t* d_src, DstDtype* d_dst, const int n,
const int c, const int h, const int w, const int oc,
bool fix_32_mask, const int* mask_offset,
const int* mask_val, cudaStream_t stream,
megcore::AsyncErrorInfo* error_info, void* error_tracker,
float scale = 1.f);
} // namespace dct
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file dnn/src/naive/dct/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/utils.h"
#include "src/cuda/dct/dct_channel_select.cuh"
#include "src/cuda/dct/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
void DctChannelSelectForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val,
_megdnn_tensor_out dst,
_megdnn_workspace /*workspace*/) {
auto stream = cuda_stream(this->handle());
const int in = src.layout.shape[0];
const int ic = src.layout.shape[1];
const int ih = src.layout.shape[2];
const int iw = src.layout.shape[3];
int oc = dst.layout.shape[1];
const bool with_fix_32_mask =
param().fastImpl == Param::FastImpl::FIX_32_MASK;
if (param().format == Param::Format::NCHW4) {
megdnn_assert(dst.layout.ndim == 5 && dst.layout.shape[4] == 4,
"dst must be nchw4");
oc = oc * 4;
}
megdnn_assert(!with_fix_32_mask || (with_fix_32_mask && oc == 32),
"only support specify mask");
megdnn_assert(param().dct_block_size == 8, "only support dct block = 8");
auto error_info =
concrete_handle(this->handle())->megcore_context().error_info;
constexpr int dct_block = 8;
const int* mask_offset_ptr = nullptr;
const int* mask_val_ptr = nullptr;
if (mask_offset.layout.ndim == 1 && mask_offset.layout.shape[0] >= 2) {
mask_offset_ptr = mask_offset.ptr<int32_t>();
mask_val_ptr = mask_val.ptr<int32_t>();
}
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
megdnn_assert(param().format == Param::Format::NCHW,
"fp32 only support nchw");
dct::call_kern_dct<dct_block, dct::DctLayoutFormat::NCHW>(
src.ptr<uint8_t>(), dst.ptr<float>(), in, ic, ih, iw, oc,
with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream,
error_info, m_error_tracker);
} else {
megdnn_assert(dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8,
"only support fp32 and qs8");
megdnn_assert(param().format == Param::Format::NCHW4,
"qint8 only support nchw4");
dct::call_kern_dct<dct_block, dct::DctLayoutFormat::NCHW4>(
src.ptr<uint8_t>(), (int8_t*)dst.raw_ptr, in, ic, ih, iw, oc,
with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream,
error_info, m_error_tracker,
dst.layout.dtype.param<::megdnn::dtype::QuantizedS8>().scale);
}
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/dct/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace cuda {
class DctChannelSelectForwardImpl : public DctChannelSelectForward {
public:
using DctChannelSelectForward::DctChannelSelectForward;
void* m_error_tracker = nullptr;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& /*src*/,
const TensorLayout& /*mask_offset*/,
const TensorLayout& /*mask_val*/,
const TensorLayout& /*dst*/) {
return 0;
};
void set_error_tracker(void* tracker) override {
m_error_tracker = tracker;
}
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "src/cuda/convpooling/opr_impl.h" #include "src/cuda/convpooling/opr_impl.h"
#include "src/cuda/cumsum/opr_impl.h" #include "src/cuda/cumsum/opr_impl.h"
#include "src/cuda/cvt_color/opr_impl.h" #include "src/cuda/cvt_color/opr_impl.h"
#include "src/cuda/dct/opr_impl.h"
#include "src/cuda/deformable_conv/opr_impl.h" #include "src/cuda/deformable_conv/opr_impl.h"
#include "src/cuda/deformable_ps_roi_pooling/opr_impl.h" #include "src/cuda/deformable_ps_roi_pooling/opr_impl.h"
#include "src/cuda/dot/opr_impl.h" #include "src/cuda/dot/opr_impl.h"
......
/**
* \file dnn/src/naive/dct/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <cmath>
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"
#include "midout.h"
#include "src/naive/dct/opr_impl.h"
#include "src/naive/handle.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"
MIDOUT_DECL(megdnn_naive_dct_fwd)
namespace megdnn {
namespace naive {
namespace {
static inline void generate_c_matrix(float* result, int block) {
constexpr float pi = M_PI;
for (int i = 0; i < block; ++i) {
for (int j = 0; j < block; ++j) {
float alpha = i == 0 ? sqrt(1.f / static_cast<float>(block))
: sqrt(2.f / static_cast<float>(block));
result[i * block + j] = alpha * cos((2.f * j + 1.f) * i * pi /
static_cast<float>(2 * block));
}
}
}
template <typename T>
void matmul(int m, int n, int k, int lda, int ldb, int ldc, const float* a,
const T* b, float* c, bool trans_a, bool trans_b) {
for (int m_idx = 0; m_idx < m; ++m_idx) {
for (int n_idx = 0; n_idx < n; ++n_idx) {
float res = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
float av = trans_a ? a[k_idx * lda + m_idx]
: a[m_idx * lda + k_idx];
float bv = trans_b ? b[n_idx * ldb + k_idx]
: b[k_idx * ldb + n_idx];
res += av * bv;
}
c[m_idx * ldc + n_idx] = res;
}
}
}
std::vector<std::vector<int>> mask_offset_to_2dmask(
_megdnn_tensor_in mask_offset, _megdnn_tensor_in mask_val) {
std::vector<std::vector<int>> mask;
if (mask_offset.layout.ndim > 0 && mask_offset.layout[0] >= 2) {
const int offset_len = mask_offset.layout.shape[0];
const int32_t* mask_offset_ptr = mask_offset.ptr<int32_t>();
const int32_t* mask_val_ptr = mask_val.ptr<int32_t>();
megdnn_assert(
mask_val.layout.shape[0] ==
static_cast<size_t>(mask_offset_ptr[offset_len - 1]),
"check mask offset %zu != %zu", mask_val.layout.shape[0],
static_cast<size_t>(mask_offset_ptr[offset_len - 1]));
for (int offset_idx = 1; offset_idx < offset_len; ++offset_idx) {
mask.push_back({});
const int mask_len = mask_offset_ptr[offset_idx] -
mask_offset_ptr[offset_idx - 1];
const int32_t* mask_ptr =
&mask_val_ptr[mask_offset_ptr[offset_idx - 1]];
for (int val_idx = 0; val_idx < mask_len; ++val_idx) {
mask[offset_idx - 1].push_back(mask_ptr[val_idx]);
}
}
}
return mask;
}
inline bool is_layout_nchw4(const TensorLayout& layout) {
if (layout.ndim == 5 && layout[4] == 4) {
return true;
} else {
return false;
}
}
template <typename T>
using QuantizedCType =
std::enable_if_t<DTypeTrait<T>::category == DTypeCategory::QUANTIZED,
typename DTypeTrait<T>::ctype>;
inline int8_t quant_float_2_int8(float val, DType dtype) {
return dtype.param<::megdnn::dtype::QuantizedS8>().quantize(val).as_int8();
}
template <param::DctChannelSelect::Format format, typename Dtype>
inline void dct_output(Dtype* dst_ptr, const int oc_idx, const int img_size,
float val, DType) {
dst_ptr[oc_idx * img_size] = val;
}
template <>
inline void dct_output<param::DctChannelSelect::Format::NCHW4>(
int8_t* dst_ptr, const int oc_idx, const int img_size, float val,
DType dtype) {
dst_ptr[oc_idx / 4 * 4 * img_size + oc_idx % 4] =
quant_float_2_int8(val, dtype);
}
template <param::DctChannelSelect::Format format>
struct ChannleBlock {
static constexpr int block = 1;
};
template <>
struct ChannleBlock<param::DctChannelSelect::Format::NCHW4> {
static constexpr int block = 4;
};
template <param::DctChannelSelect::Format format, typename Dtype>
void naive_dct(const uint8_t* src, Dtype* dst, int n, int c, int h, int w,
int block, const std::vector<std::vector<int>>& mask,
DType dtype) {
constexpr int block_channel = ChannleBlock<format>::block;
const int block_h = block;
const int block_w = block;
std::vector<float> c_matrix(block * block);
std::vector<float> tmp(block * block);
std::vector<float> tmp_result(block * block);
generate_c_matrix(&c_matrix[0], block);
megdnn_assert(h % block_h == 0, "h mod block_h == 0");
megdnn_assert(w % block_w == 0, "w mod block_w == 0");
const int oh = h / block_h;
const int ow = w / block_w;
const int o_img_size = oh * ow;
std::vector<int> mask_offset;
int mask_len_sum = 0;
if (mask.size() > 0) {
for (auto& sub_mask : mask) {
mask_offset.push_back(mask_len_sum);
mask_len_sum += sub_mask.size();
}
} else {
for (int c_idx = 0; c_idx < c; ++c_idx) {
mask_offset.push_back(mask_len_sum);
mask_len_sum += block_h * block_w;
}
}
const size_t o_batch_stride = mask_len_sum * oh * ow;
for (int n_idx = 0; n_idx < n; ++n_idx) {
for (int c_idx = 0; c_idx < c; ++c_idx) {
megdnn_assert(mask_offset[c_idx] % block_channel == 0,
"%d mod %d == 0", mask_offset[c_idx], block_channel);
const size_t src_offset = n_idx * c * h * w + c_idx * h * w;
const uint8_t* src_channel = src + src_offset;
const size_t dst_offset = n_idx * o_batch_stride +
mask_offset[c_idx] / block_channel * oh *
ow * block_channel;
Dtype* dst_channel = dst + dst_offset;
for (int oh_idx = 0; oh_idx < oh; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow; ++ow_idx) {
matmul(block, block, block, block, w, block, &c_matrix[0],
&src_channel[oh_idx * block_h * w +
ow_idx * block_w],
&tmp[0], false, false);
matmul(block, block, block, block, block, block, &tmp[0],
&c_matrix[0], &tmp_result[0], false, true);
Dtype* dst_start = dst_channel +
(oh_idx * ow + ow_idx) * block_channel;
if (mask.size() == 0) {
for (int inner_h_idx = 0; inner_h_idx < block_h;
++inner_h_idx) {
for (int inner_w_idx = 0; inner_w_idx < block_w;
++inner_w_idx) {
const int oc_idx =
inner_h_idx * block_w + inner_w_idx;
dct_output<format>(
dst_start, oc_idx, o_img_size,
tmp_result[inner_h_idx * block +
inner_w_idx],
dtype);
}
}
} else {
//! with mask
auto& sub_mask = mask[c_idx];
int dst_offset = 0;
for (auto mask_idx : sub_mask) {
dct_output<format>(dst_start, dst_offset,
o_img_size, tmp_result[mask_idx],
dtype);
++dst_offset;
}
}
}
}
}
}
}
} // namespace
void DctChannelSelectForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val,
_megdnn_tensor_out dst,
_megdnn_workspace /*workspace*/) {
MIDOUT_BEGIN(megdnn_naive_dct_fwd) {
int in = src.layout.shape[0];
int ic = src.layout.shape[1];
int ih = src.layout.shape[2];
int iw = src.layout.shape[3];
megdnn_assert(dst.raw_ptr, "dst can not be nullptr");
const int block = param().dct_block_size;
auto mask = mask_offset_to_2dmask(mask_offset, mask_val);
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
megdnn_assert(!is_layout_nchw4(dst.layout) &&
param().format == Param::Format::NCHW,
"dst must be nchw");
MEGDNN_DISPATCH_CPU_KERN_OPR(naive_dct<Param::Format::NCHW>(
src.ptr<uint8_t>(), dst.ptr<float>(), in, ic, ih, iw, block,
mask, dst.layout.dtype));
} else {
megdnn_assert(dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8,
"dst must be q8");
megdnn_assert(is_layout_nchw4(dst.layout) &&
param().format == Param::Format::NCHW4,
"dst must be nchw4");
MEGDNN_DISPATCH_CPU_KERN_OPR(naive_dct<Param::Format::NCHW4>(
src.ptr<uint8_t>(), static_cast<int8_t*>(dst.raw_ptr), in,
ic, ih, iw, block, mask, dst.layout.dtype));
}
}
MIDOUT_END();
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/dct/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class DctChannelSelectForwardImpl : public DctChannelSelectForward {
public:
using DctChannelSelectForward::DctChannelSelectForward;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& /*src*/,
const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
};
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/naive/handle.h" #include "src/naive/handle.h"
...@@ -29,6 +30,7 @@ ...@@ -29,6 +30,7 @@
#include "src/naive/convpooling/opr_impl.h" #include "src/naive/convpooling/opr_impl.h"
#include "src/naive/cumsum/opr_impl.h" #include "src/naive/cumsum/opr_impl.h"
#include "src/naive/cvt_color/opr_impl.h" #include "src/naive/cvt_color/opr_impl.h"
#include "src/naive/dct/opr_impl.h"
#include "src/naive/deformable_conv/opr_impl.h" #include "src/naive/deformable_conv/opr_impl.h"
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h" #include "src/naive/deformable_ps_roi_pooling/opr_impl.h"
#include "src/naive/dot/opr_impl.h" #include "src/naive/dot/opr_impl.h"
...@@ -56,6 +58,7 @@ ...@@ -56,6 +58,7 @@
#include "src/naive/reduce/opr_impl.h" #include "src/naive/reduce/opr_impl.h"
#include "src/naive/relayout/opr_impl.h" #include "src/naive/relayout/opr_impl.h"
#include "src/naive/relayout_format/opr_impl.h" #include "src/naive/relayout_format/opr_impl.h"
#include "src/naive/remap/opr_impl.h"
#include "src/naive/repeat/opr_impl.h" #include "src/naive/repeat/opr_impl.h"
#include "src/naive/resize/opr_impl.h" #include "src/naive/resize/opr_impl.h"
#include "src/naive/rng/opr_impl.h" #include "src/naive/rng/opr_impl.h"
...@@ -76,7 +79,6 @@ ...@@ -76,7 +79,6 @@
#include "src/naive/warp_affine/opr_impl.h" #include "src/naive/warp_affine/opr_impl.h"
#include "src/naive/warp_perspective/opr_impl.h" #include "src/naive/warp_perspective/opr_impl.h"
#include "src/naive/winograd_filter_preprocess/opr_impl.h" #include "src/naive/winograd_filter_preprocess/opr_impl.h"
#include "src/naive/remap/opr_impl.h"
static size_t g_image2d_pitch_alignment = 1; static size_t g_image2d_pitch_alignment = 1;
......
...@@ -6,20 +6,21 @@ ...@@ -6,20 +6,21 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
#include <map> #include <map>
#include <memory> #include <memory>
#include <vector>
#include <regex> #include <regex>
#include <vector>
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "megdnn/tensor_format.h" #include "megdnn/tensor_format.h"
#include "test/common/opr_algo_proxy.h"
#include "test/common/opr_proxy.h" #include "test/common/opr_proxy.h"
#include "test/common/rng.h" #include "test/common/rng.h"
#include "test/common/timer.h" #include "test/common/timer.h"
#include "test/common/opr_algo_proxy.h"
namespace megdnn { namespace megdnn {
namespace test { namespace test {
...@@ -31,6 +32,7 @@ public: ...@@ -31,6 +32,7 @@ public:
using TensorValueArray = TensorNDArray; using TensorValueArray = TensorNDArray;
using BeforeExecCallback = using BeforeExecCallback =
std::function<void(Opr*, const TensorValueArray&)>; std::function<void(Opr*, const TensorValueArray&)>;
using TensorsConstriant = std::function<void(TensorValueArray& tensors)>;
BenchmarkerBase(Handle* handle, T timer) BenchmarkerBase(Handle* handle, T timer)
: m_timer(timer), : m_timer(timer),
...@@ -54,6 +56,8 @@ public: ...@@ -54,6 +56,8 @@ public:
} }
float exec(TensorLayoutArray layouts); float exec(TensorLayoutArray layouts);
float exect(const TensorValueArray& testcase_in);
//! disabiguate overloaded exec //! disabiguate overloaded exec
float execs(const TensorShapeArray& shapes) { return exec(shapes); } float execs(const TensorShapeArray& shapes) { return exec(shapes); }
float execl(const TensorLayoutArray& layouts) { return exec(layouts); } float execl(const TensorLayoutArray& layouts) { return exec(layouts); }
...@@ -73,6 +77,11 @@ public: ...@@ -73,6 +77,11 @@ public:
m_fmt[idx] = fmt; m_fmt[idx] = fmt;
return *this; return *this;
} }
BenchmarkerBase& set_tensors_constraint(
const TensorsConstriant& tensor_constraint) {
m_tensor_constraint = tensor_constraint;
return *this;
}
TensorLayoutArray make_layouts(const TensorShapeArray& shapes) { TensorLayoutArray make_layouts(const TensorShapeArray& shapes) {
TensorLayoutArray layouts(shapes.size()); TensorLayoutArray layouts(shapes.size());
for (size_t i = 0; i < shapes.size(); ++i) { for (size_t i = 0; i < shapes.size(); ++i) {
...@@ -142,6 +151,7 @@ private: ...@@ -142,6 +151,7 @@ private:
std::unique_ptr<OprProxy<Opr>> m_proxy; std::unique_ptr<OprProxy<Opr>> m_proxy;
BeforeExecCallback m_before_exec_callback; BeforeExecCallback m_before_exec_callback;
std::unique_ptr<Opr> m_opr; std::unique_ptr<Opr> m_opr;
TensorsConstriant m_tensor_constraint;
}; };
template <typename Opr, typename T> template <typename Opr, typename T>
...@@ -184,10 +194,16 @@ float BenchmarkerBase<Opr, T>::exec(TensorLayoutArray layouts) { ...@@ -184,10 +194,16 @@ float BenchmarkerBase<Opr, T>::exec(TensorLayoutArray layouts) {
auto rng = m_rng[i]; auto rng = m_rng[i];
if (!rng) if (!rng)
rng = m_default_rng.get(); rng = m_default_rng.get();
auto size = tensor.layout.span().high_byte;
rng->gen(tensor); rng->gen(tensor);
}
if (m_tensor_constraint) {
m_tensor_constraint(tensors_cur_host);
}
for (size_t i = 0; i < tensors_cur_host.size(); ++i) {
TensorND& tensor = tensors_cur_host[i];
if (tensor.layout.ndim == 0) if (tensor.layout.ndim == 0)
continue; continue;
auto size = tensor.layout.span().high_byte;
megdnn_memcpy_H2D(m_handle, tensors_cur[i].raw_ptr, tensor.raw_ptr, megdnn_memcpy_H2D(m_handle, tensors_cur[i].raw_ptr, tensor.raw_ptr,
size); size);
} }
...@@ -243,6 +259,105 @@ float BenchmarkerBase<Opr, T>::exec(TensorLayoutArray layouts) { ...@@ -243,6 +259,105 @@ float BenchmarkerBase<Opr, T>::exec(TensorLayoutArray layouts) {
return time_in_ms; return time_in_ms;
} }
template <typename Opr, typename T>
float BenchmarkerBase<Opr, T>::exect(const TensorValueArray& testcase_in) {
auto opr = this->opr();
opr->param() = m_param;
TensorLayoutArray layouts;
TensorNDArray tensors_cur_host;
for (auto& inp : testcase_in) {
layouts.push_back(inp.layout);
tensors_cur_host.emplace_back(inp);
}
auto user_layouts = layouts;
m_proxy->deduce_layout(opr, layouts);
for (size_t i = 0; i < layouts.size(); ++i)
if (user_layouts[i].ndim > 0) {
auto run = [&]() {
ASSERT_TRUE(layouts[i].eq_shape(user_layouts[i]))
<< "User provided shape is "
<< user_layouts[i].TensorShape::to_string()
<< "\nExpected shape is "
<< layouts[i].TensorShape::to_string();
};
run();
}
auto allocate = [&layouts](Handle* handle) {
TensorNDArray tensors(layouts.size());
auto trans_func = [handle](const TensorLayout& layout) {
auto span = layout.span();
TensorND res;
res.raw_ptr = static_cast<uint8_t*>(
megdnn_malloc(handle, span.dist_byte())) +
span.low_byte;
res.layout = layout;
return res;
};
std::transform(layouts.begin(), layouts.end(), tensors.begin(),
trans_func);
return tensors;
};
auto tensors_cur = allocate(m_handle);
//! init
for (size_t i = 0; i < tensors_cur_host.size(); ++i) {
TensorND& tensor = tensors_cur_host[i];
auto size = tensor.layout.span().high_byte;
if (tensor.layout.ndim == 0)
continue;
megdnn_memcpy_H2D(m_handle, tensors_cur[i].raw_ptr, tensor.raw_ptr,
size);
}
if (m_before_exec_callback) {
m_before_exec_callback(opr, tensors_cur);
}
//! run
//! warm up
m_proxy->exec(opr, tensors_cur);
megcoreSynchronize(m_handle->megcore_computing_handle());
if (m_adaptive_secs) {
//! find m_times for adaptive benchmarking
m_times = 0;
int cur_times = 1;
auto remain_time = m_adaptive_secs * 1e6;
while (remain_time > 0) {
m_timer.reset();
m_timer.start();
for (int i = 0; i < cur_times; ++i)
m_proxy->exec(opr, tensors_cur);
megcoreSynchronize(m_handle->megcore_computing_handle());
m_timer.stop();
m_times += cur_times;
auto this_run_time = m_timer.get_time_in_us();
remain_time -= this_run_time;
cur_times = std::min(
cur_times * 2,
std::max<int>(1, remain_time / this_run_time * cur_times));
}
}
m_timer.reset();
m_timer.start();
for (size_t t = 0; t < m_times; ++t)
m_proxy->exec(opr, tensors_cur);
megcoreSynchronize(m_handle->megcore_computing_handle());
m_timer.stop();
auto time_in_ms = m_timer.get_time_in_us() / 1e3;
if (m_display) {
std::cout << "Total time is " << time_in_ms << "ms "
<< "for " << m_times << " run(s)." << std::endl;
}
auto free = [](Handle* handle, TensorNDArray& tensors) {
std::for_each(tensors.begin(), tensors.end(),
[handle](const TensorND& tensor) {
megdnn_free(handle, tensor.raw_ptr);
});
};
free(m_handle, tensors_cur);
if (m_adaptive_secs)
time_in_ms /= m_times;
return time_in_ms;
}
template <typename Opr, typename T = Timer> template <typename Opr, typename T = Timer>
class Benchmarker; class Benchmarker;
......
...@@ -294,8 +294,6 @@ void CheckerHelper::do_exec_with_testcases(const TensorValueArray& testcase_in, ...@@ -294,8 +294,6 @@ void CheckerHelper::do_exec_with_testcases(const TensorValueArray& testcase_in,
ASSERT_TRUE(testcase_in[i].layout.ndim == 0 || ASSERT_TRUE(testcase_in[i].layout.ndim == 0 ||
testcase_out[i].layout.ndim == 0 || testcase_out[i].layout.ndim == 0 ||
testcase_in[i].layout.eq_layout(testcase_out[i].layout)); testcase_in[i].layout.eq_layout(testcase_out[i].layout));
ASSERT_TRUE(testcase_in[i].layout.ndim != 0 ||
testcase_out[i].layout.ndim != 0);
layouts.emplace_back(testcase_in[i].layout.ndim > 0 layouts.emplace_back(testcase_in[i].layout.ndim > 0
? testcase_in[i].layout ? testcase_in[i].layout
: testcase_out[i].layout); : testcase_out[i].layout);
......
...@@ -392,7 +392,8 @@ TensorND TensorValue(const TensorShape& shape, T dtype, ...@@ -392,7 +392,8 @@ TensorND TensorValue(const TensorShape& shape, T dtype,
tensor.layout = {shape, dtype}; tensor.layout = {shape, dtype};
tensor.raw_ptr = tensor.raw_ptr =
static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte())); static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte()));
megdnn_assert(values.size() == tensor.layout.total_nr_elems()); megdnn_assert(values.size() == tensor.layout.total_nr_elems(), "%zu == %zu", values.size(),
tensor.layout.total_nr_elems());
auto ptr = tensor.ptr<typename DTypeTrait<T>::ctype>(); auto ptr = tensor.ptr<typename DTypeTrait<T>::ctype>();
for (const auto& v : values) { for (const auto& v : values) {
*ptr++ = typename DTypeTrait<T>::ctype(v); *ptr++ = typename DTypeTrait<T>::ctype(v);
......
/**
* \file
* dnn/test/common/dct_ref.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/common/dct_ref.h"
namespace megdnn {
namespace test {
struct FixCase {
std::vector<int> mask_offset;
std::vector<int> mask_val;
};
using Param = DctChannelSelectForward::Param;
static inline FixCase get_fix_mask(Param::FastImpl impl) {
std::vector<int> fix_32_mask_offset{0, 16, 24, 32};
std::vector<int> fix_32_mask_val{0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32,
25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2,
3, 10, 0, 1, 8, 16, 9, 2, 3, 10};
megdnn_assert(impl == Param::FastImpl::FIX_32_MASK,
"only support gen FIX_32_MASK");
return {fix_32_mask_offset, fix_32_mask_val};
}
CheckerHelper::TensorsConstriant gen_dct_constriant(
const size_t /* n */, const size_t ic, const size_t ih, const size_t iw,
const size_t oc, Param param) {
auto constraint = [=](CheckerHelper::TensorValueArray& tensors_orig) {
const size_t block = param.dct_block_size;
const int block_c = param.format == Param::Format::NCHW4 ? 4 : 1;
megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0");
std::shared_ptr<DctTestcase> test_case_ptr = DctTestcase::make();
DctTestcase& test_case = *test_case_ptr.get();
UniformIntRNG rng(0, 255);
UniformIntRNG mask_rng(0, 64 / block_c - 1);
const size_t no_mask_oc = ic * block * block;
megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block);
megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block);
TensorND mask_offset;
TensorND mask_val;
std::vector<int>& mask_offset_vec = test_case.mask_offset_vec;
std::vector<int>& mask_val_vec = test_case.mask_val_vec;
UniformIntRNG rng_oc(0, oc);
if (param.fastImpl == Param::FastImpl::FIX_32_MASK) {
auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK);
mask_offset_vec = fix_32_mask.mask_offset;
mask_val_vec = fix_32_mask.mask_val;
megdnn_assert(oc == 32, "oc must eq 32");
} else if (no_mask_oc > oc) {
size_t remain_oc = oc;
mask_offset_vec.resize(ic + 1);
mask_val_vec.resize(oc);
mask_offset_vec[0] = 0;
for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) {
size_t random_len = (int)rng_oc.gen_single_val() * block_c;
size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0)
? remain_oc
: random_len % remain_oc;
megdnn_assert(mask_len % block_c == 0,
"mask_len mod block_c == 0, but %zu mod %d ",
mask_len, block_c);
const size_t oc_idx = mask_offset_vec[ic_idx];
remain_oc -= mask_len;
mask_offset_vec[ic_idx + 1] = oc_idx + mask_len;
for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) {
mask_val_vec[oc_idx + mask_idx] =
(int)mask_rng.gen_single_val();
}
}
}
mask_offset = TensorND(mask_offset_vec.data(),
{{mask_offset_vec.size()}, dtype::Int32()});
mask_val = TensorND(mask_val_vec.data(),
{{mask_val_vec.size()}, dtype::Int32()});
if (tensors_orig.size() > 1) {
megdnn_assert(tensors_orig.size() == 4, "tensors_orig.size() == 4");
megdnn_assert(mask_offset_vec.size() >= 2,
"mask_offset_vec.size() >= 2");
megdnn_assert(tensors_orig[1].layout == mask_offset.layout,
"tensors_orig[1].layout == mask_offset.layout");
megdnn_assert(tensors_orig[2].layout == mask_val.layout,
"tensors_orig[2].layout == mask_val.layout");
auto naive_handle = create_cpu_handle(2, false);
megdnn_memcpy_D2D(naive_handle.get(), tensors_orig[1].raw_ptr,
mask_offset.raw_ptr,
mask_offset.layout.span().dist_byte());
megdnn_memcpy_D2D(naive_handle.get(), tensors_orig[2].raw_ptr,
mask_val.raw_ptr,
mask_val.layout.span().dist_byte());
}
};
return constraint;
}
std::shared_ptr<DctTestcase> gen_dct_case(const size_t n, const size_t ic,
const size_t ih, const size_t iw,
const size_t oc, Param param,
DType dst_dtype,
bool correct_result) {
const size_t block = param.dct_block_size;
const int block_c = param.format == Param::Format::NCHW4 ? 4 : 1;
megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0");
std::shared_ptr<DctTestcase> test_case_ptr = DctTestcase::make();
DctTestcase& test_case = *test_case_ptr.get();
UniformIntRNG rng(0, 255);
UniformIntRNG mask_rng(0, 64 / block_c - 1);
const size_t input_elements = n * ic * ih * iw;
const size_t no_mask_oc = ic * block * block;
megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block);
megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block);
std::vector<uint8_t>& inp_vec = test_case.inp_vec;
inp_vec.resize(input_elements);
TensorShape input_shape{n, ic, ih, iw};
for (auto& elm : inp_vec) {
elm = (uint8_t)rng.gen_single_val();
}
auto src = TensorND(inp_vec.data(), {input_shape, dtype::Uint8()});
TensorND mask_offset;
TensorND mask_val;
std::vector<int>& mask_offset_vec = test_case.mask_offset_vec;
std::vector<int>& mask_val_vec = test_case.mask_val_vec;
UniformIntRNG rng_oc(0, oc);
if (param.fastImpl == Param::FastImpl::FIX_32_MASK) {
auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK);
mask_offset_vec = fix_32_mask.mask_offset;
mask_val_vec = fix_32_mask.mask_val;
megdnn_assert(oc == 32, "oc must eq 32");
} else if (no_mask_oc > oc) {
size_t remain_oc = oc;
mask_offset_vec.resize(ic + 1);
mask_val_vec.resize(oc);
mask_offset_vec[0] = 0;
for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) {
size_t random_len = (int)rng_oc.gen_single_val() * block_c;
size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0)
? remain_oc
: random_len % remain_oc;
megdnn_assert(mask_len % block_c == 0,
"mask_len mod block_c == 0, but %zu mod %d ",
mask_len, block_c);
const size_t oc_idx = mask_offset_vec[ic_idx];
remain_oc -= mask_len;
mask_offset_vec[ic_idx + 1] = oc_idx + mask_len;
for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) {
mask_val_vec[oc_idx + mask_idx] =
(int)mask_rng.gen_single_val();
}
}
}
mask_offset = TensorND(mask_offset_vec.data(),
{{mask_offset_vec.size()}, dtype::Int32()});
mask_val = TensorND(mask_val_vec.data(),
{{mask_val_vec.size()}, dtype::Int32()});
if (mask_offset_vec.size() >= 2) {
test_case.testcase_in = {
src, mask_offset, mask_val, {nullptr, {{}, dst_dtype}}};
} else {
test_case.testcase_in = {src, {}, {}, {nullptr, {{}, dst_dtype}}};
}
auto naive_handle = create_cpu_handle(2, false);
auto opr_naive = naive_handle->create_operator<DctChannelSelectForward>();
opr_naive->param() = param;
using Proxy = OprProxy<DctChannelSelectForward>;
Proxy naive_proxy;
TensorLayout temp_dst_layout;
temp_dst_layout.dtype = dst_dtype;
TensorLayoutArray layouts{src.layout, mask_offset.layout, mask_val.layout,
temp_dst_layout};
naive_proxy.deduce_layout(opr_naive.get(), layouts);
const size_t output_elements = layouts[3].total_nr_elems();
std::vector<float>& output_vec = test_case.output_vec;
output_vec.resize(output_elements);
auto dst = TensorND(output_vec.data(), layouts[3]);
DctTestcase::TensorValueArray testcase_naive;
testcase_naive.emplace_back(test_case.testcase_in[0]);
testcase_naive.emplace_back(test_case.testcase_in[1]);
testcase_naive.emplace_back(test_case.testcase_in[2]);
testcase_naive.emplace_back(dst);
if (correct_result) {
naive_proxy.exec(opr_naive.get(), testcase_naive);
}
test_case.testcase_out = {{}, {}, {}, dst};
return test_case_ptr;
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file
* dnn/test/common/dct_ref.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <math.h>
#include <vector>
#include "megdnn/dtype.h"
#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
#include "test/common/opr_proxy.h"
#include "test/common/rng.h"
namespace megdnn {
namespace test {
using Param = DctChannelSelectForward::Param;
struct DctTestcase {
using TensorValueArray = TensorNDArray;
TensorValueArray testcase_in;
TensorValueArray testcase_out;
std::vector<uint8_t> inp_vec;
std::vector<int> mask_offset_vec;
std::vector<int> mask_val_vec;
std::vector<float> output_vec;
static std::shared_ptr<DctTestcase> make() {
return std::make_shared<DctTestcase>();
}
};
CheckerHelper::TensorsConstriant gen_dct_constriant(
const size_t n, const size_t ic, const size_t ih, const size_t iw,
const size_t oc, Param param);
std::shared_ptr<DctTestcase> gen_dct_case(const size_t n, const size_t ic,
const size_t ih, const size_t iw,
const size_t oc, Param param,
DType dst_dtype = dtype::Float32(),
bool correct_result = true);
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -110,6 +110,7 @@ DEF(BatchConvBiasForward, 5, true, true); ...@@ -110,6 +110,7 @@ DEF(BatchConvBiasForward, 5, true, true);
DEF(Remap, 3, true, true); DEF(Remap, 3, true, true);
DEF(RemapBackwardData, 3, true, false); DEF(RemapBackwardData, 3, true, false);
DEF(RemapBackwardMat, 4, true, false); DEF(RemapBackwardMat, 4, true, false);
DEF(DctChannelSelectForward, 4, true, true);
} // namespace test } // namespace test
} // namespace megdnn } // namespace megdnn
......
/**
* \file dnn/test/cuda/dct.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs/nn.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/dct_ref.h"
#include "test/common/rng.h"
#include "test/cuda/fixture.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, DCT) {
DctChannelSelectForward::Param param;
Checker<DctChannelSelectForward> checker(handle_cuda());
for (size_t n : {1, 3}) {
for (size_t ic : {1, 3}) {
for (size_t ih : {8, 16, 32, 512, 1024}) {
for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
checker.set_param(param)
.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Int32())
.set_dtype(2, dtype::Int32())
.execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
}
}
}
}
}
TEST_F(CUDA, DCT_QINT8) {
DctChannelSelectForward::Param param;
Checker<DctChannelSelectForward> checker(handle_cuda());
param.format = Param::Format::NCHW4;
for (size_t n : {1, 3}) {
for (size_t ic : {1, 3}) {
for (size_t ih : {8, 16, 32, 512, 1024}) {
for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
checker.set_param(param)
.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Int32())
.set_dtype(2, dtype::Int32())
.set_dtype(3, dtype::QuantizedS8(10.f))
.set_epsilon(1)
.execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
}
}
}
}
}
TEST_F(CUDA, DCT_WITH_FIX_32_MASK) {
using Param = DctChannelSelectForward::Param;
Param param;
Checker<DctChannelSelectForward> checker(handle_cuda(), false);
param.fastImpl = Param::FastImpl::FIX_32_MASK;
auto test_case = gen_dct_case(3, 3, 1024, 768, 32, param);
checker.set_param(param).exect(test_case->testcase_in,
test_case->testcase_out);
}
TEST_F(CUDA, DCT_WITH_FIX_32_MASK_QINT8) {
using Param = DctChannelSelectForward::Param;
Param param;
Checker<DctChannelSelectForward> checker(handle_cuda(), false);
param.fastImpl = Param::FastImpl::FIX_32_MASK;
param.format = Param::Format::NCHW4;
auto test_case =
gen_dct_case(3, 3, 1024, 768, 32, param, dtype::QuantizedS8(10.f));
checker.set_param(param).set_epsilon(1).exect(test_case->testcase_in,
test_case->testcase_out);
}
TEST_F(CUDA, DCT_WITH_MASK) {
Checker<DctChannelSelectForward> checker(handle_cuda(), false);
DctChannelSelectForward::Param param;
checker.set_param(param).exect(
Testcase{TensorValue(
{1, 3, 8, 16}, dtype::Uint8(),
{109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204,
109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204,
109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204}),
TensorValue({4}, dtype::Int32(), {0, 14, 22, 30}),
TensorValue({30}, dtype::Int32(),
{8, 16, 9, 2, 3, 10, 17, 24, 32, 25,
18, 11, 4, 5, 0, 1, 8, 16, 9, 2,
3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
{}},
Testcase{{},
{},
{},
TensorValue({1, 30, 1, 2}, dtype::Float32(),
{-22.850792, -97.862236, -101.043236,
-4.727012, 28.275675, -157.96654,
42.1377, 45.06531, -149.77373,
24.487143, -8.054966, -13.990831,
-6.9395194, -3.9211385, 64.79172,
-12.363858, -47.875, 59.,
56.271786, -62.725567, 120.522675,
16.559765, 85.74334, 112.904495,
99.375, 29.499973, 2.0220923,
-19.681704, 890.12494, 941.25,
-7.0498576, 99.47632, -22.850792,
-97.862236, -101.043236, -4.727012,
28.275675, -157.96654, 42.1377,
45.06531, -149.77373, 24.487143,
-8.054966, -13.990831, 890.12494,
941.25, -7.0498576, 99.47632,
-22.850792, -97.862236, -101.043236,
-4.727012, 28.275675, -157.96654,
42.1377, 45.06531, -149.77373,
24.487143, -8.054966, -13.990831})});
}
TEST_F(CUDA, DCT_WITH_MASK2) {
Checker<DctChannelSelectForward> checker(handle_cuda(), false);
DctChannelSelectForward::Param param;
UniformIntRNG rng_oc(0, 3 * 64);
for (size_t n : {1, 3}) {
for (size_t ic : {1, 3}) {
for (size_t ih : {8, 16, 32, 512, 1024}) {
for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
int random_oc = static_cast<int>(rng_oc.gen_single_val());
int max_oc = ic * 64;
int mask_oc = (random_oc % max_oc) + 1;
auto test_case =
gen_dct_case(n, ic, ih, iw, mask_oc, param);
checker.set_param(param).exect(test_case->testcase_in,
test_case->testcase_out);
}
}
}
}
}
TEST_F(CUDA, DCT_WITH_MASK2_QINT8) {
Checker<DctChannelSelectForward> checker(handle_cuda(), false);
DctChannelSelectForward::Param param;
param.format = DctChannelSelectForward::Param::Format::NCHW4;
UniformIntRNG rng_oc(0, 3 * 64);
for (size_t n : {1, 3}) {
for (size_t ic : {1, 3}) {
for (size_t ih : {8, 16, 32, 512, 1024}) {
for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
int random_oc = static_cast<int>(rng_oc.gen_single_val());
int max_oc = ic * 64;
int mask_oc = (random_oc % max_oc) + 1;
mask_oc = (mask_oc + 3) / 4 * 4;
auto test_case = gen_dct_case(n, ic, ih, iw, mask_oc, param,
dtype::QuantizedS8(10.f));
checker.set_param(param).set_epsilon(1).exect(
test_case->testcase_in, test_case->testcase_out);
}
}
}
}
}
TEST_F(CUDA, DCT_WITH_MASK2_QINT8_CONSTRAINT) {
DctChannelSelectForward::Param param;
param.format = DctChannelSelectForward::Param::Format::NCHW4;
Checker<DctChannelSelectForward> checker(handle_cuda(), false);
checker.set_param(param)
.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Int32())
.set_dtype(2, dtype::Int32())
.set_dtype(3, dtype::QuantizedS8(10.f))
.set_epsilon(1);
UniformIntRNG rng_oc(0, 3 * 64);
for (size_t n : {1, 3}) {
for (size_t ic : {1, 3}) {
for (size_t ih : {8, 16, 32, 512, 1024}) {
for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
int random_oc = static_cast<int>(rng_oc.gen_single_val());
int max_oc = ic * 64;
int mask_oc = (random_oc % max_oc) + 1;
mask_oc = (mask_oc + 3) / 4 * 4;
if (mask_oc < max_oc) {
checker
.set_tensors_constraint(gen_dct_constriant(
n, ic, ih, iw, mask_oc, param))
.exec({TensorShape{n, ic, ih, iw},
TensorShape{ic + 1},
TensorShape{(size_t)mask_oc},
{}});
} else {
checker.set_tensors_constraint({}).exec(
{TensorShape{n, ic, ih, iw}, {}, {}, {}});
}
}
}
}
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_DCT) {
using Param = DctChannelSelectForward::Param;
auto run = [&](const TensorShapeArray& shapes, Param param) {
Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
benchmarker.set_param(param);
benchmarker.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Int32())
.set_dtype(2, dtype::Int32());
for (auto&& shape : shapes) {
double computation = double(shape[0]) * shape[1] * shape[2] *
shape[3] * 32.0 * 1e-6;
auto time_ms = benchmarker.execs({shape, {}, {}, {}});
printf("execute %s, %.4f Gops\n", shape.to_string().c_str(),
computation / time_ms);
}
};
auto run_case = [&](const DctTestcase& testcase, Param param,
std::string comment = "") {
Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
benchmarker.set_param(param);
benchmarker.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Int32())
.set_dtype(2, dtype::Int32())
.set_dtype(3, testcase.testcase_out[3].layout.dtype);
auto src_shape = testcase.testcase_in[0].layout;
double computation = double(src_shape[0]) * src_shape[1] *
src_shape[2] * src_shape[3] * 32.0 * 1e-6;
auto time_ms = benchmarker.exect(testcase.testcase_in);
printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
src_shape.to_string().c_str(), computation / time_ms);
};
auto run_case_constraint =
[&](const Benchmarker<DctChannelSelectForward>::TensorsConstriant&
constraint,
Param param, const TensorShapeArray& shapes,
std::string comment = "", DType output_dtype) {
Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
benchmarker.set_param(param)
.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Int32())
.set_dtype(2, dtype::Int32())
.set_dtype(3, output_dtype)
.set_tensors_constraint(constraint);
auto src_shape = shapes[0];
double computation = double(src_shape[0]) * src_shape[1] *
src_shape[2] * src_shape[3] * 32.0 * 1e-6;
auto time_ms = benchmarker.exec(shapes);
printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
src_shape.to_string().c_str(), computation / time_ms);
};
TensorShapeArray shapes = {
{1, 3, 512, 512},
{8, 3, 2176, 3840},
};
{
Param param;
run(shapes, param);
}
Param fix_32_param;
fix_32_param.fastImpl = Param::FastImpl::FIX_32_MASK;
{
auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
run_case(*test_case, fix_32_param, "FIX_32_MASK");
}
{
Param param;
auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
run_case(*test_case, param, "MASK 32");
}
{
Param fix_32_nchw4_param;
fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
fix_32_nchw4_param.format = Param::Format::NCHW4;
auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_nchw4_param,
dtype::QuantizedS8(10.f));
run_case(*test_case, fix_32_nchw4_param, "FIX_32_MASK QINT8");
}
{
Param fix_32_nchw4_param;
fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
fix_32_nchw4_param.format = Param::Format::NCHW4;
auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_nchw4_param,
dtype::QuantizedS8(10.f));
fix_32_nchw4_param.fastImpl = Param::FastImpl::NONE;
run_case(*test_case, fix_32_nchw4_param, "MASK 32 QINT8");
}
{
Param fix_32_nchw4_param;
fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
fix_32_nchw4_param.format = Param::Format::NCHW4;
TensorShapeArray shapes = {{8, 3, 2176, 3840}, {4}, {32}, {}};
auto constraint =
gen_dct_constriant(8, 3, 2176, 3840, 32, fix_32_nchw4_param);
run_case_constraint(constraint, fix_32_nchw4_param, shapes,
"FIX_32_MASK QINT8 Constraint",
dtype::QuantizedS8(10.f));
}
}
#endif
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/naive/dct.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
#include "test/common/dct_ref.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/naive/fixture.h"
namespace megdnn {
namespace test {
TEST_F(NAIVE, DCT) {
Checker<DctChannelSelectForward> checker(handle(),
/* check_dispatch */ false);
DctChannelSelectForward::Param param;
checker.set_param(param).exect(
Testcase{TensorValue(
{1, 1, 16, 16}, dtype::Uint8(),
{87, 155, 59, 161, 24, 200, 58, 3, 40, 43,
156, 7, 176, 232, 226, 78, 73, 236, 185, 109,
196, 169, 62, 32, 167, 180, 96, 157, 101, 53,
150, 47, 26, 238, 218, 210, 204, 236, 249, 111,
16, 35, 169, 204, 117, 16, 3, 147, 12, 233,
135, 162, 58, 118, 184, 237, 90, 105, 156, 195,
196, 104, 138, 19, 82, 62, 126, 140, 220, 171,
206, 232, 105, 123, 2, 135, 137, 41, 26, 219,
167, 245, 104, 103, 24, 144, 141, 210, 208, 114,
169, 170, 22, 11, 69, 106, 236, 150, 57, 184,
75, 241, 28, 175, 178, 186, 190, 124, 187, 116,
112, 162, 214, 154, 207, 31, 43, 40, 15, 188,
81, 197, 20, 199, 246, 132, 159, 111, 79, 95,
148, 184, 171, 173, 203, 146, 150, 33, 178, 9,
141, 49, 237, 222, 72, 5, 23, 38, 248, 82,
93, 229, 70, 180, 149, 232, 245, 72, 196, 138,
4, 31, 160, 30, 8, 109, 153, 252, 204, 126,
15, 182, 145, 130, 179, 234, 21, 240, 144, 105,
77, 116, 155, 232, 168, 99, 159, 92, 251, 223,
119, 173, 166, 39, 228, 91, 34, 5, 62, 172,
131, 164, 143, 10, 161, 165, 221, 214, 178, 110,
185, 254, 152, 149, 46, 144, 173, 237, 76, 210,
221, 45, 200, 113, 58, 20, 47, 135, 228, 80,
91, 51, 238, 194, 222, 231, 174, 244, 139, 96,
71, 25, 25, 62, 172, 181, 71, 27, 86, 0,
121, 38, 199, 236, 93, 158}),
{},
{},
{}},
Testcase{{},
{},
{},
TensorValue(
{1, 64, 2, 2}, dtype::Float32(),
{1.10687500e+03, 9.59500000e+02, 8.98125000e+02,
1.21912500e+03, 1.38846378e+01, 3.91629181e+01,
-1.50343018e+02, -1.02085358e+02, 2.34341068e+01,
-8.40960388e+01, -4.23510742e+01, 1.72630596e+01,
-4.66624413e+01, -4.87857285e+01, -7.06332016e+01,
6.31493912e+01, -9.96249924e+01, 7.72499924e+01,
7.46250153e+01, 5.81250114e+01, -9.07061768e+01,
-7.68266630e+00, -3.15778809e+01, -3.35406876e+01,
8.55864143e+00, -7.36760712e+01, 6.20557327e+01,
-2.92043419e+01, -1.39985870e+02, 2.56675129e+01,
5.21866226e+01, 1.07624054e+02, -6.16851950e+00,
-8.56008530e+01, 7.35654449e+01, -2.56767311e+01,
-2.09981880e+01, -6.22950821e+01, -1.31617493e+02,
-6.30962448e+01, -2.21552780e+02, -4.79528542e+01,
1.04179153e+02, 7.45253448e+01, 3.19730816e+01,
1.24306192e+01, -9.93905945e+01, -8.95680237e+01,
-1.44870041e+02, -9.44738235e+01, -4.09417763e+01,
4.50356903e+01, -3.65339231e+00, 5.79474449e+01,
-2.46253452e+01, 3.29394951e+01, -1.09065903e+02,
5.23808861e+01, -1.00386992e+01, -7.92311325e+01,
-1.44292374e+01, 5.74285736e+01, 2.28798485e+01,
6.84826508e+01, -1.49241837e+02, 9.35751495e+01,
-4.02763329e+01, -6.63586197e+01, 2.15622040e+02,
-7.83887939e+01, -8.06824951e+01, -2.51097183e+01,
1.58941059e+01, -5.66967869e+00, -1.53566467e+02,
-4.33494377e+01, 8.12108078e+01, 1.21169144e+02,
2.14673615e+02, -3.72018318e+01, 2.45811577e+01,
-1.27189613e+02, 4.98553581e+01, -5.83694696e+00,
-4.80477619e+00, -2.24601650e+01, -5.02191353e+00,
5.16259460e+01, 1.07266571e+02, -3.41748886e+01,
-5.44621315e+01, 6.25573196e+01, -4.24649086e+01,
4.42625465e+01, 2.71147366e+01, 4.83264275e+01,
-6.99711227e+01, -1.00299120e+01, 1.33173111e+02,
2.48003254e+01, -1.74687519e+01, 9.44530487e-01,
1.35930038e+02, 6.72219162e+01, 4.53297043e+01,
1.37072708e+02, -7.73253784e+01, 6.12967606e+01,
9.78184891e+01, 3.63894577e+01, -1.64039135e+01,
-6.67858887e+01, 5.27859840e+01, -4.99117432e+01,
8.77927475e+01, -5.86666260e+01, 3.86430244e+01,
2.17759323e+01, 8.34562683e+01, 3.06256886e+01,
1.61030369e+01, 8.11268158e+01, 1.36932516e+01,
-1.06112595e+02, -9.31621475e+01, 3.13674717e+01,
-4.90609503e+00, 7.96453857e+01, -1.02625000e+02,
1.40000076e+01, 3.18749981e+01, -1.08375000e+02,
-5.44420319e+01, -1.50944397e+02, 5.29974670e+01,
-1.44041641e+02, 4.86086197e+01, -7.13610382e+01,
3.06417294e+01, 7.20477829e+01, -6.95384140e+01,
1.25441925e+02, -1.54897385e+01, 3.78566666e+01,
4.23749886e+01, -3.37500000e+01, -9.96250000e+01,
-6.73750076e+01, 3.34241295e+01, -6.24825974e+01,
1.76387348e+01, -6.45708389e+01, 1.70728874e+01,
-5.73032570e+01, -1.71570969e+01, 1.84064590e+02,
4.17566071e+01, 7.08248520e+00, -2.59306641e+01,
1.37766739e+02, -2.16669798e+00, 6.03565750e+01,
6.84421844e+01, 6.19825096e+01, -1.44220114e+01,
-3.12404213e+01, -2.50061111e+01, 6.73021851e+01,
2.52050266e+01, -8.35850677e+01, -4.70746574e+01,
1.73889160e+01, 1.18955564e+01, 6.16792488e+00,
-3.29667168e+01, 4.55779572e+01, -4.17868996e+00,
-9.40233841e+01, -9.77727051e+01, 1.74934635e+01,
5.25992851e+01, 1.23662634e+01, 5.26129305e-01,
4.69518929e+01, -1.52657738e+01, 9.96897888e+01,
-9.51726151e+01, 9.99432602e+01, -1.75949844e+02,
1.00472336e+02, -5.89417953e+01, -1.72231483e+01,
1.89282093e+01, -8.17851868e+01, 7.22908936e+01,
-9.06294174e+01, 2.46093607e+00, -4.03946457e+01,
2.17710762e+01, -5.62999649e+01, 4.77665749e+01,
-4.04248848e+01, 4.78787374e+00, 1.05557320e+02,
-4.60584450e+01, -7.33774490e+01, -4.25107193e+01,
1.71907139e+01, -8.01314316e+01, 1.69647141e+01,
-8.24824219e+01, 8.29206543e+01, 3.72900200e+01,
3.77470016e+01, 6.70151443e+01, 1.79784470e+01,
-4.01441078e+01, 6.29196739e+01, 7.60664597e+01,
-5.59005699e+01, 8.81600475e+00, -6.89491081e+00,
-8.03825378e+01, -5.33856511e-01, 7.26196136e+01,
-3.76809120e+01, -1.08401566e+02, 6.35455990e+00,
-8.66767120e+01, -1.02679443e+02, -9.54313660e+00,
-3.55650787e+01, -1.21355652e+02, 2.32628040e+01,
3.94072838e+01, 1.24754738e+02, 9.51344986e+01,
-5.84752541e+01, -4.65028038e+01, 6.00556993e+00,
4.94889374e+01, 7.64868622e+01, -1.49546280e+01,
-3.70648766e+01, 5.55572205e+01, -1.17196434e+02,
9.20216217e+01, 3.29843826e+01, 3.25113411e+01,
5.62059135e+01, 6.30202141e+01, 4.99030991e+01,
2.85804024e+01, -1.44606361e+01, 7.64952774e+01,
-2.95697536e+01})});
}
TEST_F(NAIVE, DCT_INT8) {
Checker<DctChannelSelectForward> checker(handle(),
/* check_dispatch */ false);
DctChannelSelectForward::Param param;
param.format = DctChannelSelectForward::Param::Format::NCHW4;
checker.set_param(param).exect(
Testcase{TensorValue(
{1, 1, 16, 16}, dtype::Uint8(),
{113, 223, 229, 159, 249, 252, 89, 84, 45, 16,
41, 72, 184, 236, 70, 184, 86, 172, 218, 211,
47, 177, 18, 85, 174, 226, 37, 109, 38, 135,
228, 195, 133, 238, 47, 246, 244, 118, 175, 143,
34, 10, 28, 4, 82, 103, 89, 55, 235, 78,
151, 178, 249, 62, 183, 84, 105, 0, 121, 98,
249, 90, 161, 114, 121, 241, 21, 199, 196, 119,
231, 209, 250, 180, 192, 213, 116, 105, 114, 169,
1, 142, 3, 30, 140, 245, 201, 109, 19, 26,
224, 68, 123, 228, 64, 150, 184, 212, 136, 172,
241, 152, 222, 233, 15, 72, 130, 144, 107, 130,
242, 79, 195, 46, 226, 57, 183, 36, 88, 161,
121, 170, 2, 215, 109, 212, 35, 18, 76, 197,
117, 81, 208, 8, 237, 75, 15, 20, 16, 192,
61, 113, 96, 126, 211, 57, 49, 62, 185, 211,
155, 87, 233, 163, 164, 84, 61, 28, 1, 11,
190, 253, 145, 30, 38, 98, 153, 56, 231, 152,
12, 204, 96, 8, 47, 87, 25, 237, 21, 150,
173, 19, 41, 175, 164, 231, 39, 145, 39, 187,
210, 123, 165, 98, 87, 242, 38, 136, 182, 145,
41, 47, 147, 171, 172, 35, 170, 148, 26, 89,
107, 151, 130, 232, 65, 217, 27, 206, 68, 219,
60, 106, 3, 209, 175, 189, 191, 32, 119, 141,
56, 48, 105, 58, 94, 163, 185, 60, 83, 249,
112, 245, 137, 60, 178, 51, 177, 106, 199, 209,
4, 247, 3, 127, 88, 46}),
{},
{},
{}},
Testcase{{},
{},
{},
TensorValue(
{1, 16, 2, 2, 4}, dtype::QuantizedS8(10.f),
{122, -1, -8, 4, 92, -13, -5, 7, 99, 4,
5, 3, 89, 7, 2, -6, 3, -8, -10, 2,
-1, 0, 4, -3, -5, -8, -11, 1, 14, 4,
-10, -18, 3, 12, -14, -2, -4, -9, 12, 4,
-2, -2, 2, 6, -9, 6, 1, 5, -5, -1,
2, -12, 4, -5, -0, 4, 1, 5, -8, 5,
-3, 4, 2, 6, -0, 9, -4, -7, -4, -5,
-2, 8, 2, 4, 0, 7, -8, 4, -2, 3,
-6, -5, 19, 5, -4, -4, -5, -16, -8, -3,
-5, 19, 4, 3, 4, -6, 1, -12, -1, 7,
11, -5, -1, -8, 2, -12, -9, -2, -4, -20,
-11, -15, -15, -9, -2, -9, -2, -3, 13, 2,
5, 6, 7, -4, 1, -7, 6, 4, 2, 6,
0, -0, 8, 8, -6, 5, 1, -2, -2, -12,
2, -12, -2, 6, 7, 3, 4, 14, 14, -3,
1, -3, 6, 0, -20, 2, -10, 10, -5, -5,
13, 0, -3, 7, -12, -17, -13, 1, -6, 10,
-1, -9, 4, -16, 3, 2, 5, 1, -4, 9,
-0, 1, 3, 15, -4, -13, -6, 4, 3, -2,
-1, -4, -7, -7, -2, 8, -16, -4, -10, 5,
1, -3, 2, -9, -4, 1, -1, -1, -4, -6,
-4, 1, 0, -9, 15, -1, -7, -3, -5, -0,
3, -0, -6, -17, 16, -3, 3, -2, -3, 5,
3, -2, 3, 13, 8, 1, -3, -8, -7, -4,
6, -6, -15, -7, 0, 4, -3, -3, -10, 14,
1, 3, 14, 4, -1, 14})});
}
TEST_F(NAIVE, DCT_INT8_MASK) {
Checker<DctChannelSelectForward> checker(handle(),
/* check_dispatch */ false);
DctChannelSelectForward::Param param;
param.format = DctChannelSelectForward::Param::Format::NCHW4;
auto src_tensor = TensorValue(
{1, 3, 8, 16}, dtype::Uint8(),
{195, 165, 82, 30, 154, 60, 175, 195, 179, 165, 132, 37, 250,
107, 36, 80, 5, 54, 247, 218, 191, 211, 239, 76, 140, 33,
253, 85, 132, 101, 105, 177, 46, 183, 102, 99, 19, 175, 108,
252, 42, 238, 48, 251, 108, 90, 176, 2, 35, 46, 161, 252,
38, 225, 195, 174, 58, 165, 198, 249, 162, 118, 198, 41, 154,
10, 87, 24, 201, 12, 188, 1, 93, 179, 246, 134, 18, 178,
173, 36, 122, 89, 115, 46, 43, 205, 232, 55, 149, 30, 206,
97, 186, 125, 35, 209, 51, 48, 222, 222, 130, 173, 63, 0,
223, 19, 5, 162, 154, 143, 134, 63, 123, 102, 102, 212, 145,
80, 87, 212, 42, 26, 219, 225, 120, 94, 213, 238,
25, 172, 141, 45, 182, 203, 50, 94, 44, 88, 74, 76, 151,
105, 138, 87, 125, 55, 60, 211, 15, 158, 198, 37, 54, 203,
239, 79, 56, 6, 53, 201, 97, 233, 178, 74, 193, 46, 249,
65, 5, 208, 130, 67, 191, 168, 152, 129, 253, 195, 231, 3,
109, 229, 254, 193, 229, 202, 108, 22, 89, 251, 13, 53, 47,
192, 12, 81, 19, 53, 93, 104, 41, 217, 215, 184, 136, 249,
14, 244, 4, 220, 33, 53, 142, 219, 43, 28, 68, 198, 202,
88, 235, 7, 233, 47, 84, 127, 28, 17, 189, 135, 183, 192,
239, 116, 31, 118, 186, 49, 251, 233, 220, 27, 97, 30, 43,
193, 217, 48, 24, 225, 15, 3, 26, 71, 82, 104,
175, 125, 79, 195, 50, 236, 114, 179, 180, 177, 230, 173, 43,
195, 123, 111, 106, 5, 91, 254, 34, 76, 52, 82, 193, 179,
185, 71, 57, 215, 18, 5, 151, 13, 59, 206, 154, 95, 149,
40, 229, 16, 116, 144, 249, 67, 97, 223, 208, 144, 92, 174,
246, 77, 196, 211, 20, 123, 239, 250, 235, 65, 184, 54, 239,
168, 135, 17, 79, 117, 171, 173, 109, 39, 57, 13, 129, 79,
236, 117, 134, 123, 149, 113, 198, 160, 249, 242, 220, 226, 44,
113, 164, 217, 46, 249, 182, 22, 98, 228, 49, 78, 101, 236,
181, 5, 245, 72, 62, 182, 151, 210, 254, 190, 35, 73, 190,
247, 50, 81, 49, 217, 86, 229, 139, 203, 57, 194});
checker.set_param(param).exect(
Testcase{src_tensor,
TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
TensorValue({32}, dtype::Int32(),
{0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32,
25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2,
3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
{}},
Testcase{{},
{},
{},
TensorValue(
{1, 8, 1, 2, 4}, dtype::QuantizedS8(10.f),
{100, -12, 7, 7, 104, 2, -2, -2, -7, -7, -3,
8, 12, -12, -5, -1, 5, -7, -1, 7, -7, -3,
6, 7, -0, -2, -7, 11, 6, 3, -1, 7, 94,
-5, 6, -5, 98, 0, -3, -16, 5, 7, 13, -8,
1, 5, -5, -8, 108, -3, -8, -7, 110, 1, -2,
5, -0, 7, 8, -9, 14, -0, 1, -4})});
checker.set_param(param).exect(
Testcase{TensorValue(
{1, 3, 8, 16}, dtype::Uint8(),
{195, 165, 82, 30, 154, 60, 175, 195, 179, 165,
132, 37, 250, 107, 36, 80, 5, 54, 247, 218,
191, 211, 239, 76, 140, 33, 253, 85, 132, 101,
105, 177, 46, 183, 102, 99, 19, 175, 108, 252,
42, 238, 48, 251, 108, 90, 176, 2, 35, 46,
161, 252, 38, 225, 195, 174, 58, 165, 198, 249,
162, 118, 198, 41, 154, 10, 87, 24, 201, 12,
188, 1, 93, 179, 246, 134, 18, 178, 173, 36,
122, 89, 115, 46, 43, 205, 232, 55, 149, 30,
206, 97, 186, 125, 35, 209, 51, 48, 222, 222,
130, 173, 63, 0, 223, 19, 5, 162, 154, 143,
134, 63, 123, 102, 102, 212, 145, 80, 87, 212,
42, 26, 219, 225, 120, 94, 213, 238,
25, 172, 141, 45, 182, 203, 50, 94, 44, 88,
74, 76, 151, 105, 138, 87, 125, 55, 60, 211,
15, 158, 198, 37, 54, 203, 239, 79, 56, 6,
53, 201, 97, 233, 178, 74, 193, 46, 249, 65,
5, 208, 130, 67, 191, 168, 152, 129, 253, 195,
231, 3, 109, 229, 254, 193, 229, 202, 108, 22,
89, 251, 13, 53, 47, 192, 12, 81, 19, 53,
93, 104, 41, 217, 215, 184, 136, 249, 14, 244,
4, 220, 33, 53, 142, 219, 43, 28, 68, 198,
202, 88, 235, 7, 233, 47, 84, 127, 28, 17,
189, 135, 183, 192, 239, 116, 31, 118, 186, 49,
251, 233, 220, 27, 97, 30, 43, 193, 217, 48,
24, 225, 15, 3, 26, 71, 82, 104,
175, 125, 79, 195, 50, 236, 114, 179, 180, 177,
230, 173, 43, 195, 123, 111, 106, 5, 91, 254,
34, 76, 52, 82, 193, 179, 185, 71, 57, 215,
18, 5, 151, 13, 59, 206, 154, 95, 149, 40,
229, 16, 116, 144, 249, 67, 97, 223, 208, 144,
92, 174, 246, 77, 196, 211, 20, 123, 239, 250,
235, 65, 184, 54, 239, 168, 135, 17, 79, 117,
171, 173, 109, 39, 57, 13, 129, 79, 236, 117,
134, 123, 149, 113, 198, 160, 249, 242, 220, 226,
44, 113, 164, 217, 46, 249, 182, 22, 98, 228,
49, 78, 101, 236, 181, 5, 245, 72, 62, 182,
151, 210, 254, 190, 35, 73, 190, 247, 50, 81,
49, 217, 86, 229, 139, 203, 57, 194}),
TensorValue({4}, dtype::Int32(), {0, 12, 20, 28}),
TensorValue({28}, dtype::Int32(),
{0, 1, 8, 16, 9, 2, 3, 10, 17, 24,
32, 25, 0, 1, 8, 16, 9, 2, 3, 10,
0, 1, 8, 16, 9, 2, 3, 10}),
{}},
Testcase{{},
{},
{},
TensorValue(
{1, 7, 1, 2, 4}, dtype::QuantizedS8(10.f),
{100, -12, 7, 7, 104, 2, -2, -2, -7, -7, -3,
8, 12, -12, -5, -1, 5, -7, -1, 7, -7, -3,
6, 7,
94, -5, 6, -5, 98, 0, -3, -16, 5, 7, 13,
-8, 1, 5, -5, -8, 108, -3, -8, -7, 110, 1,
-2, 5, -0, 7, 8, -9, 14, -0, 1, -4})});
}
TEST_F(NAIVE, DCT_4x4) {
Checker<DctChannelSelectForward> checker(handle(),
/* check_dispatch */ false);
DctChannelSelectForward::Param param;
param.dct_block_size = 4;
checker.set_param(param).exect(
Testcase{TensorValue(
{1, 1, 8, 8}, dtype::Uint8(),
{186, 120, 112, 220, 69, 80, 201, 127, 246, 254,
175, 50, 240, 251, 76, 37, 34, 166, 250, 195,
231, 139, 128, 233, 75, 80, 3, 2, 19, 140,
193, 203, 115, 107, 250, 209, 14, 243, 199, 60,
234, 107, 174, 156, 81, 87, 13, 116, 96, 140,
197, 253, 113, 223, 229, 159, 249, 252, 89, 84,
45, 16, 41, 72}),
{},
{},
{}},
Testcase{{},
{},
{},
TensorValue(
{1, 16, 2, 2}, dtype::Float32(),
{5.42000000e+02, 5.91750000e+02, 6.78000000e+02,
4.27750000e+02, 3.49953423e+01, -1.17686939e+01,
-1.66842098e+01, -3.85316620e+01, -3.80000000e+01,
-1.22500000e+01, 2.00000000e+01, -9.77500000e+01,
-1.61191311e+01, -9.46695328e+00, 3.28882408e+01,
-4.92537880e+01, 1.66958221e+02, -4.26609573e+01,
2.56999969e-01, 5.39384537e+01, 1.71819706e+01,
9.00009003e+01, -1.23818558e+02, 1.18912420e+01,
6.61014938e+01, -2.49261990e+01, 4.95798302e+00,
-1.02324417e+02, 7.85859919e+00, 3.73140755e+01,
1.03783745e+02, -4.61430321e+01, -1.43000000e+02,
-7.57500000e+01, -5.00000000e-01, -8.27500000e+01,
1.34834738e+01, -1.93409515e+02, 6.84791718e+01,
-4.01652241e+00, 1.22000000e+02, -8.57500000e+01,
-4.05000000e+01, -5.62500000e+01, -2.88564739e+01,
5.76532059e+01, -2.67414131e+01, 1.70877876e+01,
3.85416756e+01, 3.09300461e+01, 5.84670639e+00,
1.85747864e+02, -2.05141403e+02, -9.91859360e+01,
-1.66716263e+02, -1.71430378e+01, 6.71520996e+00,
8.41980438e+01, -3.50666313e+01, -1.48387482e+02,
1.08180256e+01, 5.49991112e+01, -1.06814528e+01,
1.86087704e+01})});
checker.set_param(param).exect(
Testcase{TensorValue(
{1, 1, 8, 8}, dtype::Uint8(),
{186, 120, 112, 220, 69, 80, 201, 127, 246, 254,
175, 50, 240, 251, 76, 37, 34, 166, 250, 195,
231, 139, 128, 233, 75, 80, 3, 2, 19, 140,
193, 203, 115, 107, 250, 209, 14, 243, 199, 60,
234, 107, 174, 156, 81, 87, 13, 116, 96, 140,
197, 253, 113, 223, 229, 159, 249, 252, 89, 84,
45, 16, 41, 72}),
TensorValue({2}, dtype::Int32(), {0, 6}),
TensorValue({6}, dtype::Int32(), {0, 1, 8, 4, 2, 3}),
{}},
Testcase{
{},
{},
{},
TensorValue(
{1, 6, 2, 2}, dtype::Float32(),
{5.4200000e+02, 5.9175000e+02, 6.7800000e+02,
4.2775000e+02, 3.4995342e+01, -1.1768694e+01,
-1.6684210e+01, -3.8531662e+01, -1.4300000e+02,
-7.5750000e+01, -5.0000000e-01, -8.2750000e+01,
1.6695822e+02, -4.2660957e+01, 2.5699997e-01,
5.3938454e+01, -3.8000000e+01, -1.2250000e+01,
2.0000000e+01, -9.7750000e+01, -1.6119131e+01,
-9.4669533e+00, 3.2888241e+01, -4.9253788e+01})});
}
TEST_F(NAIVE, DCT_WITH_MASK) {
Checker<DctChannelSelectForward> checker(handle(),
/* check_dispatch */ false);
DctChannelSelectForward::Param param;
checker.set_param(param).exect(
Testcase{TensorValue(
{1, 3, 8, 16}, dtype::Uint8(),
{109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204,
109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204,
109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204}),
TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
TensorValue({32}, dtype::Int32(),
{0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32,
25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2,
3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
{}},
Testcase{{},
{},
{},
TensorValue({1, 32, 1, 2}, dtype::Float32(),
{890.12494, 941.25, -7.0498576,
99.47632, -22.850792, -97.862236,
-101.043236, -4.727012, 28.275675,
-157.96654, 42.1377, 45.06531,
-149.77373, 24.487143, -8.054966,
-13.990831, -6.9395194, -3.9211385,
64.79172, -12.363858, -47.875,
59., 56.271786, -62.725567,
120.522675, 16.559765, 85.74334,
112.904495, 99.375, 29.499973,
2.0220923, -19.681704, 890.12494,
941.25, -7.0498576, 99.47632,
-22.850792, -97.862236, -101.043236,
-4.727012, 28.275675, -157.96654,
42.1377, 45.06531, -149.77373,
24.487143, -8.054966, -13.990831,
890.12494, 941.25, -7.0498576,
99.47632, -22.850792, -97.862236,
-101.043236, -4.727012, 28.275675,
-157.96654, 42.1377, 45.06531,
-149.77373, 24.487143, -8.054966,
-13.990831})});
checker.set_param(param).exect(
Testcase{TensorValue(
{1, 3, 8, 16}, dtype::Uint8(),
{109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204,
109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204,
109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204}),
TensorValue({4}, dtype::Int32(), {0, 8, 16, 24}),
TensorValue({24}, dtype::Int32(),
{17, 24, 32, 25, 18, 11, 4, 5, 0, 1, 8, 16,
9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
{}},
Testcase{{},
{},
{},
TensorValue({1, 24, 1, 2}, dtype::Float32(),
{-6.9395194, -3.9211385, 64.79172,
-12.363858, -47.875, 59.,
56.271786, -62.725567, 120.522675,
16.559765, 85.74334, 112.904495,
99.375, 29.499973, 2.0220923,
-19.681704, 890.12494, 941.25,
-7.0498576, 99.47632, -22.850792,
-97.862236, -101.043236, -4.727012,
28.275675, -157.96654, 42.1377,
45.06531, -149.77373, 24.487143,
-8.054966, -13.990831, 890.12494,
941.25, -7.0498576, 99.47632,
-22.850792, -97.862236, -101.043236,
-4.727012, 28.275675, -157.96654,
42.1377, 45.06531, -149.77373,
24.487143, -8.054966, -13.990831})});
}
TEST_F(NAIVE, DCT_WITH_FIX_32_MASK) {
Checker<DctChannelSelectForward> checker(handle(),
/* check_dispatch */ false);
using Param = DctChannelSelectForward::Param;
Param param;
param.fastImpl = Param::FastImpl::FIX_32_MASK;
checker.set_param(param).exect(
Testcase{TensorValue(
{1, 3, 8, 16}, dtype::Uint8(),
{109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204,
109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204,
109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
238, 181, 232, 191, 161, 57, 23, 204}),
TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
TensorValue({32}, dtype::Int32(),
{0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32,
25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2,
3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
{}},
Testcase{{},
{},
{},
TensorValue({1, 32, 1, 2}, dtype::Float32(),
{890.12494, 941.25, -7.0498576,
99.47632, -22.850792, -97.862236,
-101.043236, -4.727012, 28.275675,
-157.96654, 42.1377, 45.06531,
-149.77373, 24.487143, -8.054966,
-13.990831, -6.9395194, -3.9211385,
64.79172, -12.363858, -47.875,
59., 56.271786, -62.725567,
120.522675, 16.559765, 85.74334,
112.904495, 99.375, 29.499973,
2.0220923, -19.681704, 890.12494,
941.25, -7.0498576, 99.47632,
-22.850792, -97.862236, -101.043236,
-4.727012, 28.275675, -157.96654,
42.1377, 45.06531, -149.77373,
24.487143, -8.054966, -13.990831,
890.12494, 941.25, -7.0498576,
99.47632, -22.850792, -97.862236,
-101.043236, -4.727012, 28.275675,
-157.96654, 42.1377, 45.06531,
-149.77373, 24.487143, -8.054966,
-13.990831})});
}
TEST_F(NAIVE, DCT_WITH_MASK2) {
Checker<DctChannelSelectForward> checker(handle(), false);
DctChannelSelectForward::Param param;
UniformIntRNG rng_oc(0, 3 * 64);
for (size_t n : {1, 3}) {
for (size_t ic : {1, 3}) {
for (size_t ih : {8, 16, 32, 512, 1024}) {
for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
int random_oc = static_cast<int>(rng_oc.gen_single_val());
int max_oc = ic * 64;
int mask_oc = (random_oc % max_oc) + 1;
auto test_case =
gen_dct_case(n, ic, ih, iw, mask_oc, param);
checker.set_param(param).exect(test_case->testcase_in,
test_case->testcase_out);
}
}
}
}
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册