Commit 4a802d21 authored by Megvii Engine Team

feat(dnn/cuda): add conv u4xs4 sass kernel

GitOrigin-RevId: 4defcf5f1f33f91c5d92df0282f337e9080dc322
Parent adf75a29
......@@ -35,7 +35,9 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter) {
megdnn_assert(src.dtype.enumv() == filter.dtype.enumv());
megdnn_assert((src.dtype.enumv() == filter.dtype.enumv()) ||
(src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
filter.dtype.enumv() == DTypeEnum::QuantizedS4));
// check compatibility of bias's scale
if (src.dtype.category() == DTypeCategory::QUANTIZED) {
if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
......
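For context on the commit title: "u4xs4" means Quantized4Asymm (u4) activations convolved with QuantizedS4 (s4) weights, which is exactly the extra pairing the relaxed assertion above (and the matching one in deduce_layout_fwd below) now admits. A minimal standalone sketch of the rule, using a hypothetical enum rather than the real megdnn::DTypeEnum:

```cpp
#include <cstdint>

// hypothetical stand-in for megdnn::DTypeEnum, for illustration only
enum class DTypeEnum : uint8_t { QuantizedS4, Quantized4Asymm, QuantizedS8 };

// mirrors the relaxed check in ConvBiasForward::check_exec and
// ConvolutionBase::deduce_layout_fwd: same dtype, or u4 data with s4 weights
bool conv_dtype_pair_allowed(DTypeEnum src, DTypeEnum filter) {
    return src == filter || (src == DTypeEnum::Quantized4Asymm &&
                             filter == DTypeEnum::QuantizedS4);
}
```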
......@@ -598,8 +598,10 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(filter);
megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
megdnn_assert(src.dtype.enumv() == filter.dtype.enumv(), "%s",
errmsg().c_str());
megdnn_assert(((src.dtype.enumv() == filter.dtype.enumv()) ||
(src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
filter.dtype.enumv() == DTypeEnum::QuantizedS4)),
"%s", errmsg().c_str());
check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
size_t img_dim;
if (param().format == Param::Format::NCHW ||
......
......@@ -488,6 +488,10 @@ void LowbitsAlignedTensorFormatBase::assert_valid(
"bad stride:%s, %zu", layout.to_string().c_str(),
layout.stride[i]);
}
if (!has_dim_unity_stride &&
(int)layout.stride[layout.ndim - 1] ==
round_up(1, (int)m_align_size_in_elements))
has_dim_unity_stride = true;
megdnn_assert(layout.ndim == 0 || has_dim_unity_stride,
"innermost dim not contiguous");
}
......@@ -546,7 +550,12 @@ bool LowbitsAlignedTensorFormatBase::is_contiguous_spec(
assert_valid(layout);
ptrdiff_t expected = 1;
for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) {
if (layout.shape[i] != 1 && layout.stride[i] != expected)
bool is_valid_stride =
(layout.stride[i] == expected) ||
(expected == 1 &&
(int)layout.stride[i] ==
round_up(1, (int)m_align_size_in_elements));
if (layout.shape[i] != 1 && !is_valid_stride)
return false;
auto multiplier = layout.shape[i];
if (i == static_cast<int>(layout.ndim) - 1)
......@@ -568,7 +577,7 @@ TensorLayout LowbitsAlignedTensorFormatBase::collapse_contiguous_spec(
res.stride[0] = 1;
return res;
}
if (res.shape[i] == 1 && res.stride[i] != 1) {
if (res.shape[i] == 1) {
res.remove_axis_inplace(i);
}
}
......
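The two tensor-format changes above relax the contiguity rules for low-bit layouts: an innermost stride equal to round_up(1, m_align_size_in_elements) now also counts as contiguous, and axes of size 1 are always collapsible regardless of their stride. A small sketch of the innermost-stride rule, assuming 4-bit elements packed two per byte (so align_size_in_elements == 2); helper names are illustrative, not the real LowbitsAlignedTensorFormatBase API:

```cpp
// smallest multiple of `align` that is >= x, as used by the format check
static int round_up(int x, int align) {
    return (x + align - 1) / align * align;
}

// For byte-aligned 4-bit data, align_size_in_elements == 2, so an innermost
// stride of either 1 or round_up(1, 2) == 2 elements is accepted as contiguous.
bool innermost_dim_contiguous(int last_stride, int align_size_in_elements) {
    return last_stride == 1 ||
           last_stride == round_up(1, align_size_in_elements);
}
```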
......@@ -232,6 +232,7 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
(rhs.enumv() == DTypeTrait<dt2>::enumv)) \
return lhs.param<dt1>().scale * rhs.param<dt2>().scale;
cb_binary(::megdnn::dtype::QuantizedS8, ::megdnn::dtype::QuantizedS16)
cb_binary(::megdnn::dtype::Quantized4Asymm, ::megdnn::dtype::QuantizedS4)
#undef cb_binary
megdnn_assert(lhs.enumv() == rhs.enumv());
......
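The new cb_binary entry lets mul_scale handle the mixed Quantized4Asymm x QuantizedS4 pair, which the bias-scale compatibility check in check_exec relies on. A small numeric sanity check under the usual quantized-conv convention (bias/accumulator scale = src scale * filter scale); the concrete scales are the ones the naive tests later in this diff use:

```cpp
#include <cassert>
#include <cmath>

int main() {
    float src_scale = 0.1f;     // e.g. dtype::Quantized4Asymm(0.1, 8)
    float filter_scale = 0.1f;  // e.g. dtype::QuantizedS4(0.1)
    float bias_scale = src_scale * filter_scale;  // what mul_scale returns
    assert(std::fabs(bias_scale - 0.01f) < 1e-6f);  // dtype::QuantizedS32(0.01)
    return 0;
}
```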
......@@ -66,7 +66,8 @@ public:
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8,
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8,
CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4,
CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4_INT4,
CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_UINT4_INT4,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
......
......@@ -15,7 +15,7 @@
#include "./quint4x4x32_wmma/activation_u4.cuh"
#include "./quint4x4x32_wmma/reduce_with_scale_data.cuh"
#include "./quint4x4x32_wmma/reduce_with_scale_filter.cuh"
#include "./reduce_with_scale_filter.cuh"
#include "./quint4x4x32_wmma/wmma_conv_integer_u4.cuh"
using namespace megdnn;
......@@ -75,7 +75,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_bundle(
// for reduce filter
{
size_t A = OC, B = IC * FH * FW / 8, C = 1;
ws_size_zp_filter += _do_dispatch_reduce_workspace_in_bytes(A, B, C);
ws_size_zp_filter += do_dispatch_reduce_workspace_in_bytes(A, B, C);
}
size_t ws_size_zp_data = N * OH * OW * sizeof(int32_t);
size_t ws_size_relayout_filter = get_workspace_in_bytes_do_conv(args);
......@@ -135,11 +135,11 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec(
int32_t zp_data_filter = zp_data * zp_filter * FH * FW * IC;
auto&& stream = cuda_stream(handle);
// zp filter
_do_dispatch_reduce_with_scale_filter_u4(
do_dispatch_reduce_with_scale_filter_4bit<false>(
static_cast<uint8_t*>(args.filter_tensor->raw_ptr), -zp_data, OC,
FH * FW * IC / 8, ws_zp_filter.ptr<int32_t>(), stream);
// zp data
_do_dispatch_reduce_with_scale_data_u4(
do_dispatch_reduce_with_scale_data_u4(
ws_zp_data.ptr<int32_t>(),
static_cast<uint8_t*>(args.src_tensor->raw_ptr), N, IH, IW, OH, OW,
PH, PW, FH, FW, SH, SW, IC, -zp_filter,
......@@ -173,12 +173,12 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec(
args.bias_tensor->compatible_ptr<int32_t>(), s0, s1, s2, s3};
auto&& param = args.opr->param();
if (param.nonlineMode == Param::NonlineMode::RELU) {
_do_dispatch_activation_u4<ActivationRELU>(
do_dispatch_activation_u4<ActivationRELU>(
args.dst_tensor->compatible_ptr<int32_t>(), visitor,
ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(),
zp_data_filter, N, OC, OH, OW, stream);
} else if (param.nonlineMode == Param::NonlineMode::IDENTITY) {
_do_dispatch_activation_u4<ActivationIdentity>(
do_dispatch_activation_u4<ActivationIdentity>(
args.dst_tensor->compatible_ptr<int32_t>(), visitor,
ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(),
zp_data_filter, N, OC, OH, OW, stream);
......
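The two reductions dispatched above implement the usual zero-point expansion for asymmetric 4-bit data, and zp_data_filter is the constant cross term added back per output element. A scalar sketch of the identity the WMMA kernel relies on, with hypothetical names (zp_x corresponds to zp_data, zp_w to zp_filter) rather than the CUDA entry points:

```cpp
#include <cstdint>

// dot(x - zp_x, w - zp_w) = dot(x, w)
//                           - zp_w * sum(x)    // "zp data" reduction over src
//                           - zp_x * sum(w)    // "zp filter" reduction over weights
//                           + zp_x * zp_w * K  // K = FH * FW * IC (zp_data_filter)
int32_t conv_acc_with_zero_points(const uint8_t* x, const uint8_t* w, int K,
                                  int32_t zp_x, int32_t zp_w) {
    int32_t dot = 0, sum_x = 0, sum_w = 0;
    for (int i = 0; i < K; ++i) {
        dot += int32_t(x[i]) * int32_t(w[i]);
        sum_x += x[i];
        sum_w += w[i];
    }
    return dot - zp_w * sum_x - zp_x * sum_w + zp_x * zp_w * K;
}
```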
......@@ -87,11 +87,10 @@ __global__ void kern_activation_u4(int32_t* dst, const int32_t* zp_data,
} // namespace
template <typename ActivationOp>
void _do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor,
const int32_t* zp_data,
const int32_t* zp_filter,
int32_t zp_data_filter, int batch_size, int co,
int ho, int wo, cudaStream_t stream) {
void do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor,
const int32_t* zp_data, const int32_t* zp_filter,
int32_t zp_data_filter, int batch_size, int co,
int ho, int wo, cudaStream_t stream) {
void (*fptr)(int32_t*, const int32_t*, const int32_t*, int32_t, int, int OC,
int, int, BiasVisitor) = kern_activation_u4<ActivationOp>;
dim3 grids{0, 0, 0};
......@@ -105,7 +104,7 @@ void _do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor,
}
#define INST(_op) \
template void _do_dispatch_activation_u4<_op>( \
template void do_dispatch_activation_u4<_op>( \
int32_t * dst, BiasVisitor visitor, const int32_t* zp_data, \
const int32_t* zp_filter, int32_t zp_data_filter, int batch_size, \
int co, int ho, int wo, cudaStream_t stream);
......
......@@ -82,12 +82,10 @@ struct ActivationIdentity {
} // namespace activation_u4
template <typename ActivationOp>
void _do_dispatch_activation_u4(int32_t* dst,
activation_u4::BiasVisitor visitor,
const int32_t* zp_data,
const int32_t* zp_filter,
int32_t zp_data_filter, int batch_size, int co,
int ho, int wo, cudaStream_t stream);
void do_dispatch_activation_u4(int32_t* dst, activation_u4::BiasVisitor visitor,
const int32_t* zp_data, const int32_t* zp_filter,
int32_t zp_data_filter, int batch_size, int co,
int ho, int wo, cudaStream_t stream);
} // namespace cuda
} // namespace megdnn
......
......@@ -444,7 +444,7 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels(
} // namespace
void megdnn::cuda::_do_dispatch_reduce_with_scale_data_u4(
void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4(
int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw,
int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic,
int32_t scale, uint8_t zp_data, cudaStream_t stream) {
......
......@@ -37,7 +37,7 @@
namespace megdnn {
namespace cuda {
void _do_dispatch_reduce_with_scale_data_u4(
void do_dispatch_reduce_with_scale_data_u4(
int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw,
int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic,
int32_t scale, uint8_t zp_data, cudaStream_t stream);
......
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cu
* \file dnn/src/cuda/conv_bias/reduce_with_scale_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./reduce_with_scale_filter.cuh"
#include "src/cuda/reduce_helper.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
using namespace megdnn;
using namespace cuda;
namespace {
struct ReduceWithScaleUInt4Op {
template <bool signedness>
struct ReduceWithScaleInt4Op {
typedef int32_t wtype;
const uint8_t* src;
int32_t* dst;
......@@ -63,9 +68,8 @@ struct ReduceWithScaleUInt4Op {
int32_t ret = 0;
#pragma unroll
for (int j = 0; j < 8; j++) {
uint8_t cur = (val & 0xF);
ret += cur;
val = (val >> 4);
ret += integer_subbyte::unpack_integer_4bits<signedness>(val,
(j << 2));
}
return ret;
}
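The loop body above switches from manual shift-and-mask to integer_subbyte::unpack_integer_4bits<signedness>, so the same reduce op can serve both unsigned and signed 4-bit filters. A sketch of the assumed unpack semantics (sign extension when signedness is true); the real helper lives in src/cuda/integer_subbyte_utils.cuh:

```cpp
#include <cstdint>

// assumed behaviour of unpack_integer_4bits<signedness>(val, bit_offset):
// extract 4 bits starting at bit_offset, sign-extending when signed
template <bool signedness>
int unpack4(uint32_t packed, int bit_offset) {
    int v = int((packed >> bit_offset) & 0xFu);
    if (signedness && (v & 0x8))  // high bit set -> negative int4
        v -= 16;
    return v;
}
// e.g. unpack4<false>(0xF, 0) == 15, while unpack4<true>(0xF, 0) == -1
```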
......@@ -74,13 +78,14 @@ struct ReduceWithScaleUInt4Op {
} // namespace
void megdnn::cuda::_do_dispatch_reduce_with_scale_filter_u4(
template <bool signedness>
void megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit(
const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols,
int32_t* dst, cudaStream_t stream) {
// rows = OC
    // cols is measured in pixels, i.e. IC * FH * FW / 8; each pixel packs 8
    // sub-byte values
ReduceWithScaleUInt4Op op;
ReduceWithScaleInt4Op<signedness> op;
op.src = src;
op.scale = scale;
op.dst = dst;
......@@ -88,13 +93,22 @@ void megdnn::cuda::_do_dispatch_reduce_with_scale_filter_u4(
static_cast<void>(stream);
static_cast<void>(rows);
static_cast<void>(cols);
run_reduce<ReduceWithScaleUInt4Op, false>(dst + rows, rows, cols, 1, stream,
op);
run_reduce<ReduceWithScaleInt4Op<signedness>, false>(dst + rows, rows, cols,
1, stream, op);
}
size_t megdnn::cuda::_do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B,
size_t C) {
return get_reduce_workspace_in_bytes<ReduceWithScaleUInt4Op>(A, B, C);
#define INST(signedness) \
template void \
megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit<signedness>( \
const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, \
int32_t* dst, cudaStream_t stream)
INST(false);
INST(true);
#undef INST
size_t megdnn::cuda::do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B,
size_t C) {
return get_reduce_workspace_in_bytes<ReduceWithScaleInt4Op<false>>(A, B, C);
}
// vim: ft=cpp syntax=cuda.doxygen
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cuh
* \file dnn/src/cuda/conv_bias/reduce_with_scale_filter.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
void _do_dispatch_reduce_with_scale_filter_u4(const uint8_t* src, int32_t scale,
uint32_t rows, uint32_t cols,
int32_t* dst,
cudaStream_t stream);
size_t _do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, size_t C);
template <bool signedness>
void do_dispatch_reduce_with_scale_filter_4bit(const uint8_t* src,
int32_t scale, uint32_t rows,
uint32_t cols, int32_t* dst,
cudaStream_t stream);
size_t do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, size_t C);
} // namespace cuda
} // namespace megdnn
......
/**
* \file dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./algo.h"
#include "src/cuda/conv_bias/sass_helper.cuh"
#include "src/cuda/sass_loader.h"
#include "src/cuda/utils.h"
#include "src/common/conv_bias.h"
using namespace megdnn;
using namespace cuda;
using namespace sass;
namespace {
#if !MEGDNN_TEGRA_X1
// all strides are in bytes
void compute_conv2d_offset(size_t fh, size_t fw, size_t ics, size_t ihs,
Conv2dConstantOffset& constant_offset) {
constexpr int interleaved = 64;
constexpr int size_bits = 4;
    constexpr int threadblock_k = 128;
    constexpr int inc_step = threadblock_k / interleaved;
size_t i = 0;
int* s32 = reinterpret_cast<int*>(&(constant_offset.c_offset[0]));
for (; i < inc_step; i++) {
int c = i / (fh * fw);
int khkw = i % (fh * fw);
int kh = khkw / fw;
int kw = khkw % fw;
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
s8[0] = kh;
s8[1] = kw;
s8[2] = -kh;
s8[3] = -kw;
}
for (; i < (inc_step + fh * fw * inc_step); i++) {
int c = i / (fh * fw);
int khkw = i % (fh * fw);
int kh = khkw / fw;
int kw = khkw % fw;
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
s8[0] = kh;
s8[1] = kw;
s8[2] = -kh;
s8[3] = -kw;
int i_ = i - inc_step;
c = i_ / (fh * fw);
khkw = i_ % (fh * fw);
kh = khkw / fw;
kw = khkw % fw;
s32[2 * i] -= c * ics + kh * ihs + kw * interleaved * size_bits / 8;
}
}
#endif
}; // namespace
std::string ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::kernel_key(
const SizeArgs& args) const {
std::string kernel_key;
using NonlineMode = Param::NonlineMode;
auto&& param = args.opr->param();
if (args.z_layout->ndim > 0) {
kernel_key =
ssprintf("%s_conv_bias_int4_fuse_z_imma8832_ldg16_%ux%u",
current_device_arch_name(), m_tile_nhw, m_tile_oc);
} else {
kernel_key =
ssprintf("%s_conv_bias_int4_imma8832_ldg16_%ux%u",
current_device_arch_name(), m_tile_nhw, m_tile_oc);
}
if (param.nonlineMode == NonlineMode::H_SWISH) {
kernel_key += "_hswish";
} else {
megdnn_assert(param.nonlineMode == NonlineMode::RELU ||
param.nonlineMode == NonlineMode::IDENTITY);
kernel_key += "_relu";
}
return kernel_key;
}
bool ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::is_available(
const SizeArgs& args) const {
if (args.bias_layout->ndim <= 0)
return false;
using Param = param::ConvBias;
using Format = Param::Format;
using Sparse = Param::Sparse;
using Mode = Param::Mode;
bool available = true;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
if (!check_bias_share_in_channel(*(args.bias_layout), param.format))
return false;
if (param.format != Format::NCHW64)
return false;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
// TODO support group conv
available &= param.sparse == Sparse::DENSE;
// mode must be cross correlation
available &= param.mode == Mode::CROSS_CORRELATION;
// check data type
auto src_dtype = args.src_layout->dtype,
filter_dtype = args.filter_layout->dtype,
bias_dtype = args.bias_layout->dtype,
dst_dtype = args.dst_layout->dtype;
available &= (src_dtype.enumv() == DTypeEnum::QuantizedS4 &&
filter_dtype.enumv() == DTypeEnum::QuantizedS4 &&
bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
dst_dtype.enumv() == DTypeEnum::QuantizedS4);
    // TODO: support dilation
available &= dh == 1 && dw == 1;
// ensure precomputed offsets are positive integers
available &= hi >= fh && wi >= fw;
    // only supported on sm_75 or later; the platform should have int8 tensor
    // core support
available &= is_compute_capability_required(7, 5);
    // the param buffer is 4KB; 3KB of it stores the precomputed offsets, so
    // fh * fw <= (3*1024/4/2/2) - 1
available &= fh * fw <= 191;
return available;
}
size_t
ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::get_workspace_in_bytes(
const SizeArgs& args) const {
if (args.preprocessed_filter == nullptr) {
return args.filter_layout->span().dist_byte() +
args.bias_layout->span().dist_byte();
}
return 0_z;
}
void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec(
const ExecArgs& args) const {
#if MEGDNN_TEGRA_X1
megdnn_throw("sass kernel is disabled at compile time for TX1");
#else
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
auto&& stream = cuda_stream(args.opr->handle());
constexpr int interleaved = 64;
void* bias_ptr = nullptr;
void* filter_ptr = nullptr;
if (args.preprocessed_filter) {
megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr;
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr;
} else {
// reorder filter and bias
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr);
bias_ptr =
reinterpret_cast<void*>(args.workspace.raw_ptr +
args.filter_layout->span().dist_byte());
if (args.z_layout->ndim > 0) {
reorder_imma_filter_bias<4, 64>(
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
} else {
reorder_imma_filter_bias<4, 64, true>(
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
}
}
uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh,
u32_fw = fw, u32_sh = sh, u32_sw = sw, u32_ph = ph, u32_pw = pw,
u32_co = co, u32_ho = ho, u32_wo = wo;
Conv2dInt4Param kern_param(u32_n, u32_ci, u32_hi, u32_wi, u32_fh, u32_fw,
u32_sh, u32_sw, u32_ph, u32_pw, u32_co, u32_ho,
u32_wo, interleaved);
Conv2dConstantOffset kern_coffset;
compute_conv2d_offset(fh, fw, kern_param.ics, kern_param.ihs, kern_coffset);
// The starting address of Turing param buffer is c[0x0][0x160]
kern_coffset.c_offset_param.begin = param_buffer_start_address();
kern_coffset.c_offset_param.size = 16 * (1 + fh * fw);
kern_coffset.c_offset_param.max = 16 * fh * fw;
kern_coffset.c_offset_param.rewind = 16 * (1 - fh * fw);
auto kern_key = kernel_key(args);
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS4>().scale,
filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS4>().scale,
bias_scale =
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
float alpha = src_scale * filter_scale / dst_scale,
beta = bias_scale / dst_scale;
float inv_dst_scale = 1.f / dst_scale;
unsigned int tx = m_threads, ty = 1;
unsigned int gridx = div_ceil<unsigned int>(
static_cast<unsigned int>(n * ho * wo), m_tile_nhw);
unsigned int gridy =
div_ceil<unsigned int>(static_cast<unsigned int>(co), m_tile_oc);
void* src_ptr = const_cast<void*>(args.src_tensor->raw_ptr);
void* dst_ptr = const_cast<void*>(args.dst_tensor->raw_ptr);
using NonlineMode = Param::NonlineMode;
auto&& kernel = SASSKernelLoader::instance().get_kernel(kern_key, kern_key);
if (args.z_layout->ndim > 0) {
void* z_ptr = const_cast<void*>(args.z_tensor->raw_ptr);
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS4>().scale;
float gamma = z_scale / dst_scale;
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, &z_ptr,
&dst_ptr, &alpha, &beta, &gamma};
kern_coffset.c_offset_param.begin +=
sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
sizeof(z_ptr) + sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta) +
sizeof(gamma);
uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
if (param.nonlineMode == NonlineMode::H_SWISH) {
params.push_back(&dst_scale);
params.push_back(&inv_dst_scale);
kern_coffset.c_offset_param.begin +=
sizeof(dst_scale) + sizeof(inv_dst_scale);
} else {
params.push_back(&relu);
kern_coffset.c_offset_param.begin += sizeof(relu);
}
params.push_back(&kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_param);
kern_coffset.c_offset_param.begin +=
sizeof(kern_coffset.c_offset_param);
kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
params.push_back(&kern_coffset);
cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
params.data(), 0));
} else {
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr,
&dst_ptr, &alpha, &beta};
kern_coffset.c_offset_param.begin +=
sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta);
uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
if (param.nonlineMode == NonlineMode::H_SWISH) {
params.push_back(&dst_scale);
params.push_back(&inv_dst_scale);
kern_coffset.c_offset_param.begin +=
sizeof(dst_scale) + sizeof(inv_dst_scale);
} else {
params.push_back(&relu);
kern_coffset.c_offset_param.begin += sizeof(relu);
}
params.push_back(&kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_param);
kern_coffset.c_offset_param.begin +=
sizeof(kern_coffset.c_offset_param);
kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
params.push_back(&kern_coffset);
cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
params.data(), 0));
}
after_kernel_launch();
#endif
}
size_t ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::
get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
return 0_z;
}
SmallVector<TensorLayout> ConvBiasForwardImpl::
AlgoSASSInt4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout(
const SizeArgs& args) const {
return {args.filter_layout->collapse_contiguous(),
args.bias_layout->collapse_contiguous()};
}
void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec_preprocess(
const ExecArgs& args) const {
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
auto&& stream = cuda_stream(args.opr->handle());
reorder_imma_filter_bias<4, 64>(
reinterpret_cast<int8_t*>(
args.preprocessed_filter->tensors[0].raw_ptr),
args.preprocessed_filter->tensors[1].compatible_ptr<int32_t>(),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
}
// vim: syntax=cpp.doxygen
......@@ -161,6 +161,38 @@ void forward_bias<dt_qint4, dt_qint4, dt_qint32, dt_qint32>(
forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>(
new_src, new_flt, bias, dst, nullptr, new_filter_meta);
}
template <>
void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>(
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, dt_byte* workspace_ptr,
const ConvBiasForward::CanonizedFilterMeta& filter_meta) {
auto convert_layout_src = [](const TensorLayout& layout) {
auto ret = layout;
auto param = layout.dtype.param<dtype::Quantized4Asymm>();
ret.dtype = dtype::QuantizedS8(param.scale);
ret.format = TensorFormat(ret.dtype);
ret.init_contiguous_stride();
return ret;
};
auto convert_layout_flt = [](const TensorLayout& layout) {
auto ret = layout;
auto param = layout.dtype.param<dtype::QuantizedS4>();
ret.dtype = dtype::QuantizedS8(param.scale);
ret.format = TensorFormat(ret.dtype);
ret.init_contiguous_stride();
return ret;
};
TensorND new_src = {workspace_ptr, convert_layout_src(src.layout)};
TensorND new_flt = {workspace_ptr + new_src.layout.span().dist_byte(),
convert_layout_flt(filter.layout)};
uint4_to_int8(src, new_src);
int4_to_int8(filter, new_flt);
auto new_filter_meta = filter_meta;
new_filter_meta.dtype = new_flt.layout.dtype;
forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>(
new_src, new_flt, bias, dst, nullptr, new_filter_meta);
}
} // namespace convolution
size_t ConvBiasForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
......@@ -211,9 +243,10 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
TensorLayout{dst.layout, bias.layout.dtype}};
workspace_ptr += sfb.layout.span().dist_byte();
}
#define DISPATCH_RAW(in_dt, bias_dt, out_dt, cmode, func) \
#define DISPATCH_RAW(in_dt, flt_dt, bias_dt, out_dt, cmode, func) \
else if (src.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv && \
filter.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv && \
filter.layout.dtype.enumv() == \
DTypeTrait<dtype::flt_dt>::enumv && \
bias.layout.dtype.enumv() == DTypeTrait<dtype::bias_dt>::enumv && \
sfb.layout.dtype.enumv() == DTypeTrait<dtype::out_dt>::enumv && \
param().compute_mode == Param::ComputeMode::cmode) { \
......@@ -222,7 +255,7 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
}
#define DISPATCH(in_dt, out_dt) \
DISPATCH_RAW( \
in_dt, out_dt, out_dt, DEFAULT, \
in_dt, in_dt, out_dt, out_dt, DEFAULT, \
(convolution::forward_bias<DTypeTrait<dtype::in_dt>::ctype, \
DTypeTrait<dtype::in_dt>::ctype, \
DTypeTrait<dtype::out_dt>::ctype, \
......@@ -236,16 +269,21 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
DISPATCH(QuantizedS8, Float32)
DISPATCH(Quantized8Asymm, QuantizedS32)
DISPATCH(Quantized4Asymm, QuantizedS32)
DISPATCH_RAW(QuantizedS8, QuantizedS32, QuantizedS32, FLOAT32,
DISPATCH_RAW(QuantizedS8, QuantizedS8, QuantizedS32, QuantizedS32,
FLOAT32,
(convolution::forward_bias<dt_int8, dt_int8, dt_int32,
dt_int32>))
DISPATCH(QuantizedS4, QuantizedS32)
DISPATCH_RAW(Quantized4Asymm, QuantizedS4, QuantizedS32, QuantizedS32,
DEFAULT,
(convolution::forward_bias<dt_quint4, dt_qint4, dt_qint32,
dt_qint32>))
#if !MEGDNN_DISABLE_FLOAT16
DISPATCH(Float16, Float16)
DISPATCH_RAW(Float16, Float16, Float16, FLOAT32,
DISPATCH_RAW(Float16, Float16, Float16, Float16, FLOAT32,
(convolution::forward_bias<dt_float16, dt_float16,
dt_float16, dt_float32>))
DISPATCH_RAW(BFloat16, BFloat16, BFloat16, FLOAT32,
DISPATCH_RAW(BFloat16, BFloat16, BFloat16, BFloat16, FLOAT32,
(convolution::forward_bias<dt_bfloat16, dt_bfloat16,
dt_bfloat16, dt_float32>))
#endif
......
......@@ -57,6 +57,54 @@ void megdnn::naive::uint8_to_uint4(const TensorND& in, const TensorND& out) {
}
}
void megdnn::naive::uint4_to_int8(const TensorND& in, const TensorND& out) {
auto in_ptr = static_cast<uint8_t*>(in.raw_ptr) + in.layout.span().low_byte;
auto out_ptr = out.compatible_ptr<int8_t>() + out.layout.span().low_byte;
const auto& ly = in.layout;
int8_t zero_point =
(int8_t)ly.dtype.param<dtype::Quantized4Asymm>().zero_point;
auto dim_in = ly.shape[ly.ndim - 1];
auto elems = ly.total_nr_elems();
auto dim_out = elems / dim_in;
auto stride_out = div_ceil(dim_in, 2_z);
for (size_t i = 0; i < dim_out; ++i) {
for (size_t j = 0; j < dim_in; j += 2) {
uint8_t val = in_ptr[j / 2];
out_ptr[j] = (int8_t)(val & 0xF) - zero_point;
if (j + 1 < dim_in)
out_ptr[j + 1] = (int8_t)((val >> 4) & 0xF) - zero_point;
}
in_ptr += stride_out;
out_ptr += dim_in;
}
}
void megdnn::naive::int8_to_uint4(const TensorND& in, const TensorND& out) {
auto in_ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte;
auto out_ptr =
static_cast<uint8_t*>(out.raw_ptr) + out.layout.span().low_byte;
auto zero_point =
out.layout.dtype.param<dtype::Quantized4Asymm>().zero_point;
const auto& ly = in.layout;
auto dim_in = ly.shape[ly.ndim - 1];
auto elems = ly.total_nr_elems();
auto dim_out = elems / dim_in;
auto stride_out = div_ceil(dim_in, 2_z);
for (size_t i = 0; i < dim_out; ++i) {
for (size_t j = 0; j < dim_in; j += 2) {
uint8_t a = (uint8_t)std::max((int32_t)in_ptr[j] + zero_point, 0);
uint8_t b = 0;
if (j + 1 < dim_in)
b = (uint8_t)std::max((int32_t)in_ptr[j + 1] + zero_point, 0);
a = std::min(a, DTypeTrait<dtype::Quantized4Asymm>::max());
b = std::min(b, DTypeTrait<dtype::Quantized4Asymm>::max());
out_ptr[j / 2] = a + (b << 4);
}
in_ptr += dim_in;
out_ptr += stride_out;
}
}
// ==================================qint4======================================
void megdnn::naive::int4_to_int8(const TensorND& in, const TensorND& out) {
auto in_ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte;
......
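The naive converters above fix the packing convention for the 4-bit reference path: two elements per byte, low nibble first, with the Quantized4Asymm zero point subtracted on unpack and re-added (with clamping to [0, 15]) on repack. A tiny standalone round-trip check of that convention, independent of the naive helpers:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const int zero_point = 8;  // e.g. dtype::Quantized4Asymm(scale, 8)
    uint8_t packed = 0xA3;     // holds u4 elements 3 (low nibble) and 10 (high)

    // uint4 -> int8, as in uint4_to_int8
    int8_t lo = int8_t(packed & 0xF) - zero_point;         //  3 - 8 = -5
    int8_t hi = int8_t((packed >> 4) & 0xF) - zero_point;  // 10 - 8 =  2
    assert(lo == -5 && hi == 2);

    // int8 -> uint4, as in int8_to_uint4 (values clamped to [0, 15] there)
    uint8_t repacked =
            uint8_t(lo + zero_point) | uint8_t((hi + zero_point) << 4);
    assert(repacked == packed);
    return 0;
}
```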
......@@ -20,6 +20,10 @@ void uint4_to_uint8(const TensorND& in, const TensorND& out);
void uint8_to_uint4(const TensorND& in, const TensorND& out);
void uint4_to_int8(const TensorND& in, const TensorND& out);
void int8_to_uint4(const TensorND& in, const TensorND& out);
void int4_to_int8(const TensorND& in, const TensorND& out);
void int8_to_int4(const TensorND& in, const TensorND& out);
......
......@@ -733,19 +733,33 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
param::ConvBias::Format format,
const std::vector<TestArg>& args, bool fuse_z,
bool stable_test) {
megdnn_assert(src_dtype.enumv() == filter_dtype.enumv());
megdnn_assert((src_dtype.enumv() == filter_dtype.enumv()) ||
(src_dtype.enumv() == DTypeEnum::Quantized4Asymm &&
filter_dtype.enumv() == DTypeEnum::QuantizedS4));
Checker<ConvBiasForward> checker(handle, !stable_test);
if (algo) {
checker.set_before_exec_callback(
ConvBiasAlgoChecker<ConvBiasForward>(algo));
}
std::unique_ptr<RNG> rng;
std::unique_ptr<RNG> flt_rng;
std::unique_ptr<RNG> bias_rng;
std::unique_ptr<RNG> const_rng;
std::unique_ptr<RNG> zero_rng;
// TODO: check range of rng
if (src_dtype.enumv() == DTypeEnum::QuantizedS8) {
rng = std::make_unique<UniformIntRNG>(-3, 3);
flt_rng = std::make_unique<UniformIntRNG>(-3, 3);
const_rng = std::make_unique<UniformIntRNG>(1, 1);
zero_rng = std::make_unique<UniformIntRNG>(0, 0);
megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32);
bias_rng = std::make_unique<UniformIntRNG>(-50, 50);
checker.set_epsilon(1 + 1e-3)
.set_max_avg_error(1e-1)
.set_max_avg_biased_error(1e-3);
} else if (src_dtype.enumv() == DTypeEnum::Quantized4Asymm) {
rng = std::make_unique<UniformIntRNG>(0, 6);
flt_rng = std::make_unique<UniformIntRNG>(-3, 3);
const_rng = std::make_unique<UniformIntRNG>(1, 1);
zero_rng = std::make_unique<UniformIntRNG>(0, 0);
megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32);
......@@ -755,6 +769,7 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
.set_max_avg_biased_error(1e-3);
} else if (src_dtype.enumv() == DTypeEnum::QuantizedS4) {
rng = std::make_unique<UniformIntRNG>(-3, 3);
flt_rng = std::make_unique<UniformIntRNG>(-3, 3);
const_rng = std::make_unique<UniformIntRNG>(1, 1);
zero_rng = std::make_unique<UniformIntRNG>(0, 0);
megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32);
......@@ -764,11 +779,13 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
.set_max_avg_biased_error(1e-3);
} else if (src_dtype.enumv() == DTypeEnum::Float16) {
rng = std::make_unique<NormalRNG>(2.f);
flt_rng = std::make_unique<NormalRNG>(2.f);
megdnn_assert(bias_dtype.enumv() == DTypeEnum::Float16);
bias_rng = std::make_unique<NormalRNG>(2.f);
checker.set_epsilon(1e-2);
} else if (src_dtype.enumv() == DTypeEnum::Float32) {
rng = std::make_unique<NormalRNG>(2.f);
flt_rng = std::make_unique<NormalRNG>(2.f);
megdnn_assert(bias_dtype.enumv() == DTypeEnum::Float32);
bias_rng = std::make_unique<NormalRNG>(2.f);
}
......@@ -819,9 +836,9 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
}
return z;
};
megdnn_assert(rng != nullptr && bias_rng != nullptr);
megdnn_assert(rng != nullptr && flt_rng != nullptr && bias_rng != nullptr);
checker.set_rng(0, rng.get())
.set_rng(1, rng.get())
.set_rng(1, flt_rng.get())
.set_rng(2, bias_rng.get())
.set_rng(3, rng.get());
if (stable_test) {
......
......@@ -257,7 +257,9 @@ void benchmark_target_algo_with_cudnn_tsc(
param::ConvBias::Format change_cudnn_format,
DType change_cudnn_src_dtype, DType change_cudnn_filter_dtype,
DType change_cudnn_bias_dtype, DType change_cudnn_dst_dtype) {
megdnn_assert(src_dtype.enumv() == filter_dtype.enumv());
megdnn_assert((src_dtype.enumv() == filter_dtype.enumv()) ||
(src_dtype.enumv() == DTypeEnum::Quantized4Asymm &&
filter_dtype.enumv() == DTypeEnum::QuantizedS4));
CUBenchmarker<ConvBiasForward> benchmarker(handle);
CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle);
size_t RUNS = 200;
......@@ -299,30 +301,30 @@ void benchmark_target_algo_with_cudnn_tsc(
using Param = ConvBias::Param;
using Format = Param::Format;
// helper function to change format
auto get_tensor_shape = [](TensorShape shape,
auto get_tensor_shape = [](TensorShape shape, DType dtype,
Format format) -> TensorShape {
TensorShape ret;
if (format == Format::NCHW4) {
ret = static_cast<TensorShape>(
TensorLayout{shape, dtype::Int8()}
TensorLayout{shape, dtype}
.reshape({shape[0], shape[1] / 4, 4, shape[2],
shape[3]})
.dimshuffle({0, 1, 3, 4, 2}));
} else if (format == Format::NCHW32) {
ret = static_cast<TensorShape>(
TensorLayout{shape, dtype::Int8()}
TensorLayout{shape, dtype}
.reshape({shape[0], shape[1] / 32, 32, shape[2],
shape[3]})
.dimshuffle({0, 1, 3, 4, 2}));
} else if (format == Format::NCHW64) {
ret = static_cast<TensorShape>(
TensorLayout{shape, dtype::QuantizedS4(1.f)}
TensorLayout{shape, dtype}
.reshape({shape[0], shape[1] / 64, 64, shape[2],
shape[3]})
.dimshuffle({0, 1, 3, 4, 2}));
} else if (format == Format::CHWN4) {
ret = static_cast<TensorShape>(
TensorLayout{shape, dtype::Int8()}
TensorLayout{shape, dtype}
.reshape({shape[0], shape[1] / 4, 4, shape[2],
shape[3]})
.dimshuffle({1, 3, 4, 0, 2}));
......@@ -370,21 +372,24 @@ void benchmark_target_algo_with_cudnn_tsc(
if (algo) {
time_in_ms =
algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>,
CUTimer>(benchmarker,
{get_tensor_shape(src, format),
get_tensor_shape(filter, format),
get_tensor_shape(bias, format),
{},
{}},
algo) /
CUTimer>(
benchmarker,
{get_tensor_shape(src, src_dtype, format),
get_tensor_shape(filter, filter_dtype, format),
get_tensor_shape(bias, bias_dtype, format),
{},
{}},
algo) /
RUNS;
} else {
time_in_ms = benchmarker.execs({get_tensor_shape(src, format),
get_tensor_shape(filter, format),
get_tensor_shape(bias, format),
{},
{}}) /
RUNS;
time_in_ms =
benchmarker.execs(
{get_tensor_shape(src, src_dtype, format),
get_tensor_shape(filter, filter_dtype, format),
get_tensor_shape(bias, bias_dtype, format),
{},
{}}) /
RUNS;
}
float time_in_ms_cudnn = 0;
if (with_cudnn) {
......@@ -393,9 +398,11 @@ void benchmark_target_algo_with_cudnn_tsc(
algo_benchmark<ConvBiasForward,
OprProxy<ConvBiasForward>, CUTimer>(
benchmarker_cudnn,
{get_tensor_shape(src, format_cudnn),
get_tensor_shape(filter, format_cudnn),
get_tensor_shape(bias, format_cudnn),
{get_tensor_shape(src, src_dtype, format_cudnn),
get_tensor_shape(filter, filter_dtype,
format_cudnn),
get_tensor_shape(bias, bias_dtype,
format_cudnn),
{},
{}},
change_cudnn_algo) /
......@@ -403,9 +410,11 @@ void benchmark_target_algo_with_cudnn_tsc(
} else {
time_in_ms_cudnn =
benchmarker_cudnn.execs(
{get_tensor_shape(src, format_cudnn),
get_tensor_shape(filter, format_cudnn),
get_tensor_shape(bias, format_cudnn),
{get_tensor_shape(src, src_dtype, format_cudnn),
get_tensor_shape(filter, filter_dtype,
format_cudnn),
get_tensor_shape(bias, bias_dtype,
format_cudnn),
{},
{}}) /
RUNS;
......@@ -426,21 +435,24 @@ void benchmark_target_algo_with_cudnn_tsc(
if (algo) {
time_in_ms =
algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>,
CUTimer>(benchmarker,
{get_tensor_shape(src, format),
get_tensor_shape(filter, format),
get_tensor_shape(bias, format),
get_tensor_shape(z, format),
{}},
algo) /
CUTimer>(
benchmarker,
{get_tensor_shape(src, src_dtype, format),
get_tensor_shape(filter, filter_dtype, format),
get_tensor_shape(bias, bias_dtype, format),
get_tensor_shape(z, src_dtype, format),
{}},
algo) /
RUNS;
} else {
time_in_ms = benchmarker.execs({get_tensor_shape(src, format),
get_tensor_shape(filter, format),
get_tensor_shape(bias, format),
get_tensor_shape(z, format),
{}}) /
RUNS;
time_in_ms =
benchmarker.execs(
{get_tensor_shape(src, src_dtype, format),
get_tensor_shape(filter, filter_dtype, format),
get_tensor_shape(bias, bias_dtype, format),
get_tensor_shape(z, src_dtype, format),
{}}) /
RUNS;
}
time_in_ms_cudnn = 0;
if (with_cudnn) {
......@@ -449,20 +461,24 @@ void benchmark_target_algo_with_cudnn_tsc(
algo_benchmark<ConvBiasForward,
OprProxy<ConvBiasForward>, CUTimer>(
benchmarker_cudnn,
{get_tensor_shape(src, format_cudnn),
get_tensor_shape(filter, format_cudnn),
get_tensor_shape(bias, format_cudnn),
get_tensor_shape(z, format_cudnn),
{get_tensor_shape(src, src_dtype, format_cudnn),
get_tensor_shape(filter, filter_dtype,
format_cudnn),
get_tensor_shape(bias, bias_dtype,
format_cudnn),
get_tensor_shape(z, src_dtype, format_cudnn),
{}},
change_cudnn_algo) /
RUNS;
} else {
time_in_ms_cudnn =
benchmarker_cudnn.execs(
{get_tensor_shape(src, format_cudnn),
get_tensor_shape(filter, format_cudnn),
get_tensor_shape(bias, format_cudnn),
get_tensor_shape(z, format_cudnn),
{get_tensor_shape(src, src_dtype, format_cudnn),
get_tensor_shape(filter, filter_dtype,
format_cudnn),
get_tensor_shape(bias, bias_dtype,
format_cudnn),
get_tensor_shape(z, src_dtype, format_cudnn),
{}}) /
RUNS;
}
......
......@@ -746,6 +746,45 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED4) {
checker.set_param(param).exect(Testcase{input, filter, bias, z, {}},
Testcase{{}, {}, {}, {}, output});
// test qu4 x q4
for (size_t i = 0; i < input_values.size(); i++) {
input_values[i] = input_values[i] + 8;
}
for (size_t i = 0; i < z_values.size(); i++) {
z_values[i] = z_values[i] + 8;
}
std::vector<int> output_uint4;
auto dtype_qu4 = dtype::Quantized4Asymm(0.01, 8);
for (size_t i = 0; i < output_values.size(); i++) {
int result =
static_cast<int>(dtype_qu4.param()
.quantize(output_values.at(i) * 0.01)
.as_uint8());
output_uint4.push_back(result);
}
auto input_qu4 = TensorValueLowbit4(
{1, 1, 4, 4}, dtype::Quantized4Asymm(0.1, 8), input_values);
auto filter_q4 = TensorValueLowbit4({3, 1, 3, 3}, dtype::QuantizedS4(0.1),
filter_values);
auto bias_s32 = GenTensorValue({1, 3, 1, 1}, dtype::QuantizedS32(0.01),
bias_values);
auto z_qu4 = TensorValueLowbit4({1, 3, 2, 2},
dtype::Quantized4Asymm(0.01, 8), z_values);
auto output_qu4 = TensorValueLowbit4(
{1, 3, 2, 2}, dtype::Quantized4Asymm(0.01, 8), output_uint4);
checker.set_param(param).exect(
Testcase{input_qu4, filter_q4, bias_s32, z_qu4, {}},
Testcase{{}, {}, {}, {}, output_qu4});
}
TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) {
......@@ -3329,7 +3368,7 @@ TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) {
auto input_64 = TensorValueLowbit4({1, 1, 4, 4, 64},
dtype::QuantizedS4(0.1), input_values);
auto fliter_64 = TensorValueLowbit4({64, 1, 3, 3, 64},
auto filter_64 = TensorValueLowbit4({64, 1, 3, 3, 64},
dtype::QuantizedS4(0.1), filter_values);
auto bias1_64 =
GenTensorValue({1, 1, 1, 1, 64}, dtype::QuantizedS32(0.01), bias_1);
......@@ -3338,7 +3377,31 @@ TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) {
{1, 1, 2, 2, 64}, dtype::QuantizedS4(1), output_values);
checker.set_param(param).exect(
Testcase{input_64, fliter_64, bias1_64, {}, {}},
Testcase{input_64, filter_64, bias1_64, {}, {}},
Testcase{{}, {}, {}, {}, output_64});
// test qu4 x q4
for (size_t i = 0; i < input_values.size(); i++) {
input_values[i] = input_values[i] + 8;
}
for (size_t i = 0; i < output_values.size(); i++) {
output_values[i] = output_values[i] + 8;
}
auto input_qu4_64 = TensorValueLowbit4(
{1, 1, 4, 4, 64}, dtype::Quantized4Asymm(0.1, 8), input_values);
auto filter_q4_64 = TensorValueLowbit4(
{64, 1, 3, 3, 64}, dtype::QuantizedS4(0.1), filter_values);
auto bias_64 =
GenTensorValue({1, 1, 1, 1, 64}, dtype::QuantizedS32(0.01), bias_1);
auto output_q4_64 = TensorValueLowbit4(
{1, 1, 2, 2, 64}, dtype::Quantized4Asymm(1, 8), output_values);
checker.set_param(param).exect(
Testcase{input_qu4_64, filter_q4_64, bias_64, {}, {}},
Testcase{{}, {}, {}, {}, output_q4_64});
}
// vim: syntax=cpp.doxygen