Commit 12a0e615 authored by Megvii Engine Team

feat(dnn/cuda): add cuda elemwise int4

GitOrigin-RevId: 8a9aaec3281c154470b15ed79ba5be288d7a752a
Parent df1af59b
......@@ -6,6 +6,7 @@ dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary
dnn/src/cuda/matrix_mul/fp32_simt/kimpl/* binary
dnn/src/cuda/sass/prebuilt/map_defs.cpp binary
dnn/src/cuda/convolution/backward_data/int8/kimpl/* binary
dnn/src/cuda/elemwise_multi_type/kimpl/* binary
tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text
imperative/python/test/integration/data/*.mge filter=lfs diff=lfs merge=lfs -text
ci/resource/models/float/mobilenet_v2.pkl filter=lfs diff=lfs merge=lfs -text
......
......@@ -382,6 +382,9 @@ struct TensorLayout : public TensorShape {
//! get lowest and highest offset reachable from this layout
Span span() const;
//! total number of bytes accessed by this layout (including low-bit alignment padding)
size_t access_bytes() const;
};
/**
......
......@@ -308,6 +308,8 @@ class dt_qulowbit {
return _;
}
MEGDNN_DEVICE uint8_t as_storage() const { return _; }
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qulowbit(uint8_t val):_(val) {}
#ifdef MEGDNN_CC_HOST
explicit operator uint8_t() { return _; }
......@@ -332,6 +334,8 @@ class dt_qlowbit {
return _;
}
MEGDNN_DEVICE int8_t as_storage() const { return _; }
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qlowbit(int8_t val):_(val) {}
#ifdef MEGDNN_CC_HOST
explicit operator int8_t() { return _; }
......
# As CUDA currently does not support quint8, we just ignore it.
SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')]
SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32')]
SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'),
('dt_qint4', 'dt_qint32'), ('dt_quint4', 'dt_qint32')]
SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')]
SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')]
MODES = {
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
......@@ -16,6 +20,15 @@ MODES = {
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
}
QINT4_MODES = {
1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID',
'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'],
2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0',
'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH',
'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'],
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
}
QINT32_MODES = {
1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'],
2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID',
......
......@@ -212,7 +212,7 @@ TensorLayout::TensorLayout(const TensorShape& shape, DType dtype,
TensorLayout::TensorLayout(const TensorShape& shape,
const std::vector<ptrdiff_t>& stride, DType dtype)
: TensorLayout(shape, stride, dtype, DefaultTensorFormat::make()) {}
: TensorLayout(shape, stride, dtype, Format(dtype)) {}
TensorLayout::TensorLayout(const TensorShape& shape,
const std::vector<ptrdiff_t>& stride, DType dtype,
......@@ -412,6 +412,27 @@ TensorLayout::Span TensorLayout::span() const {
return format.impl()->span_spec(*this);
}
size_t TensorLayout::access_bytes() const {
megdnn_assert(dtype.valid());
auto contig = collapse_contiguous();
size_t ret = 0;
if (dtype.is_low_bit()) {
ret = 1;
int align_size_in_elements = 8 / dtype.low_bit();
for (size_t i = 0; i < contig.ndim; ++i) {
if (contig.stride[i] == 1) {
ret *= round_up((int)contig.shape[i], align_size_in_elements);
} else {
ret *= contig.shape[i];
}
}
ret /= align_size_in_elements;
} else {
ret = dtype.size(total_nr_elems());
}
return ret;
}
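A minimal standalone sketch of the rounding rule that access_bytes() applies to 4-bit dtypes (illustration only; the helper name and worked numbers are made up, not MegDNN API):
#include <cstddef>
#include <vector>
// For low-bit dtypes, the innermost contiguous dim is rounded up to a whole
// byte (two nibbles for 4-bit); other dims multiply as-is, and the nibble
// count is finally converted to bytes.
static size_t access_bytes_4bit(const std::vector<size_t>& shape,
                                const std::vector<ptrdiff_t>& stride) {
    const size_t align = 2;  // 8 / 4: two 4-bit elements per byte
    size_t nibbles = 1;
    for (size_t i = 0; i < shape.size(); ++i) {
        if (stride[i] == 1)
            nibbles *= (shape[i] + align - 1) / align * align;  // round_up
        else
            nibbles *= shape[i];
    }
    return nibbles / align;
}
// e.g. a contiguous qint4 tensor of shape (3, 5) collapses to shape (15) with
// stride (1): round_up(15, 2) = 16 nibbles, i.e. 8 bytes of actual access.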
TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
megdnn_throw_if(!ndim || !tshape.ndim, tensor_reshape_error,
"broadcast involves empty tensor");
......
......@@ -236,33 +236,66 @@ INST(dt_qint8);
INST(dt_quint8);
#undef dt_ibyte
template <int ndim>
void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init(
const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr);
for (size_t i = 0; i < rv.layout.ndim; ++i) {
m_stride[i] = rv.layout.stride[i];
m_shape[i] = rv.layout.shape[i];
if (i + 1 < rv.layout.ndim) {
m_shape_highdim[i] = rv.layout.shape[i + 1];
if (rv.layout.stride[i + 1] == 1)
m_align_shape_highdim[i] =
(uint32_t)round_up((int)rv.layout.shape[i + 1], 2);
else
m_align_shape_highdim[i] = rv.layout.shape[i + 1];
}
}
for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
m_shape_highdim[i] = 1;
m_align_shape_highdim[i] = 1;
}
for (size_t i = rv.layout.ndim; i < ndim; ++i) {
m_stride[i] = 0;
m_shape[i] = 1;
}
m_is_physical_contiguous = rv.layout.is_physical_contiguous();
}
#define ndim_cb(_ndim) \
template class ParamElemVisitor4bitBase<_ndim, BCAST_OTHER>;
MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb)
#undef ndim_cb
} // namespace elemwise_intl
void elemwise_intl::get_launch_spec(const void* kern, size_t size,
int* grid_size, int* block_size) {
safe_size_in_kern(size);
auto config = query_launch_config_for_kernel(kern);
*block_size = config.block_size;
int a = size / (config.block_size * 2),
b = (size - 1) / (config.block_size * 3) + 1;
if (current_device_prop().major <= 3) {
// for Kepler, fewer blocks (more work per thread) is faster
*grid_size = b;
} else {
*grid_size = std::max(a, b);
safe_size_in_kern(size);
auto config = query_launch_config_for_kernel(kern);
*block_size = config.block_size;
int a = size / (config.block_size * 2),
b = (size - 1) / (config.block_size * 3) + 1;
if (current_device_prop().major <= 3) {
// for Kepler, fewer blocks (more work per thread) is faster
*grid_size = b;
} else {
*grid_size = std::max(a, b);
}
if (!*grid_size) {
*block_size = std::min<int>(std::max<int>(size / 64, 1) * 32, 1024);
*grid_size = std::max<int>(size / *block_size, 1);
}
// because we unroll 3 times in the kernel
megdnn_assert(static_cast<size_t>(*block_size) * *grid_size * 3 >=
size);
}
if (!*grid_size) {
*block_size = std::min<int>(std::max<int>(size / 64, 1) * 32, 1024);
*grid_size = std::max<int>(size / *block_size, 1);
}
// because we unroll 3 times in the kernel
megdnn_assert(static_cast<size_t>(*block_size) * *grid_size * 3 >= size);
}
void elemwise_intl::on_bad_ndim(int ndim) {
megdnn_throw(ssprintf("invalid ndim: %d", ndim));
MEGDNN_MARK_USED_VAR(ndim);
}
void elemwise_intl::on_bad_ndim(int ndim) {
megdnn_throw(ssprintf("invalid ndim: %d", ndim));
MEGDNN_MARK_USED_VAR(ndim);
}
} // namespace cuda
} // namespace megdnn
......
......@@ -115,6 +115,34 @@ INST(dt_qint32, int4);
#undef as_raw
#undef INST
struct int4bx2 {
int8_t x;
};
struct uint4bx2 {
uint8_t x;
};
#define INST(_ctype, _Storage, _vect_type) \
template <> \
class VectTypeTrait<_ctype> { \
public: \
using Storage = _Storage; \
static const Storage kMask = 0xf; \
static const Storage kBits = 4; \
using vect_type = _vect_type; \
static const size_t packed_size = 2; \
static __device__ __forceinline__ vect_type make_vector(Storage x, \
Storage y) { \
vect_type t; \
t.x = (x & kMask) | (y << kBits); \
return t; \
} \
}
INST(dt_qint4, int8_t, int4bx2);
INST(dt_quint4, uint8_t, uint4bx2);
#undef INST
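A standalone sketch of the nibble packing convention encoded by make_vector() above; the unpack helper below mirrors what unpack_integer_4bits<> is used for later in this file (its real implementation lives elsewhere in the codebase, so this is only an illustration):
#include <cstdint>
// low nibble holds element x, high nibble holds element y
static inline int8_t pack_int4x2(int8_t x, int8_t y) {
    return static_cast<int8_t>((x & 0xf) | (y << 4));
}
template <bool signedness>  // true for dt_qint4, false for dt_quint4
static inline int8_t unpack_int4(uint8_t storage, int bit_offset) {
    uint8_t nibble = (storage >> bit_offset) & 0xf;
    if (signedness && (nibble & 0x8))  // sign-extend negative qint4 values
        return static_cast<int8_t>(nibble | 0xf0);
    return static_cast<int8_t>(nibble);
}
// pack_int4x2(-3, 5) yields 0x5d; unpack_int4<true>(0x5d, 0) == -3 and
// unpack_int4<true>(0x5d, 4) == 5.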
/*!
* \brief visitor to access an element in a tensor at a given logical index
* \tparam ctype plain element ctype (i.e. ctype in DTypeTrait)
......@@ -217,6 +245,7 @@ template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, BCAST_OTHER>
: public ParamVisitorBase<ndim, ctype, BCAST_OTHER> {
public:
using CType = ctype;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
......@@ -500,6 +529,177 @@ public:
#endif
};
template <int ndim, BcastType brd_type>
class ParamElemVisitor4bitBase;
template <int ndim>
class ParamElemVisitor4bitBase<ndim, BCAST_OTHER> {
using Storage = int8_t;
protected:
Storage* __restrict m_ptr;
int m_stride[ndim];
int m_shape[ndim];
bool m_is_physical_contiguous;
//! m_shape_highdim[i] = original_shape[i + 1]
#ifdef _MSC_VER
Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
Uint32Fastdiv m_align_shape_highdim[ndim > 1 ? ndim - 1 : 1];
#else
Uint32Fastdiv m_shape_highdim[ndim];
Uint32Fastdiv m_align_shape_highdim[ndim];
#endif
public:
static const Storage kMask = 0xf;
static const Storage kBits = 4;
static const int NDIM = ndim;
void host_init(const TensorND& rv, int grid_size, int block_size);
#if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t) {}
devfunc void next() {}
devfunc void get_shape_from_access(uint32_t access_idx,
int (&shape_idx)[ndim]) {
#pragma unroll
for (int i = ndim - 1; i >= 1; --i) {
Uint32Fastdiv& align_shp = m_align_shape_highdim[i - 1];
uint32_t access_idx_div = access_idx / align_shp;
shape_idx[i] = access_idx - access_idx_div * align_shp.divisor();
access_idx = access_idx_div;
}
shape_idx[0] = access_idx;
}
devfunc int offset(uint32_t idx) {
int offset = 0;
#pragma unroll
for (int i = ndim - 1; i >= 1; --i) {
Uint32Fastdiv& shp = m_shape_highdim[i - 1];
uint32_t idx_div = idx / shp;
offset += (idx - idx_div * shp.divisor()) * m_stride[i];
idx = idx_div;
}
offset += idx * m_stride[0];
return offset;
}
devfunc int idx(uint32_t access_idx) {
int idx = 0;
if (m_is_physical_contiguous) {
idx = access_idx;
} else {
int shape_idx[ndim];
bool valid = true;
get_shape_from_access(access_idx, shape_idx);
#pragma unroll
for (int i = 0; i < ndim; ++i) {
valid &= (shape_idx[i] < m_shape[i]);
}
#pragma unroll
for (int i = 0; i < ndim - 1; ++i) {
idx = (idx + shape_idx[i]) * m_shape[i + 1];
}
idx = valid ? idx + shape_idx[ndim - 1] : -1;
}
return idx;
}
devfunc Storage* ptr() { return m_ptr; }
#endif
};
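A host-side sketch (illustration only, with made-up shapes) of the access_idx to logical-index mapping that get_shape_from_access() and idx() perform, for a hypothetical 3x5 qint4 tensor whose rows are padded to whole bytes (aligned inner extent of 6 nibbles):
#include <cstdint>
static int logical_idx_2d(uint32_t access_idx) {
    const int shape[2] = {3, 5};  // m_shape: the logical extents
    const int aligned_inner = 6;  // m_align_shape_highdim[0]: 5 rounded up to 2
    int row = access_idx / aligned_inner;
    int col = access_idx - row * aligned_inner;
    bool valid = row < shape[0] && col < shape[1];
    return valid ? row * shape[1] + col : -1;  // -1 marks a padding nibble
}
// logical_idx_2d(7) == 6 (row 1, col 1); logical_idx_2d(5) == -1, the padding
// nibble at the end of row 0, which the OpCaller code below substitutes with 0.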
template <int ndim>
class ParamElemVisitor<ndim, dt_qint4, BCAST_OTHER>
: public ParamElemVisitor4bitBase<ndim, BCAST_OTHER> {
using CType = dt_qint4;
using Storage = int8_t;
public:
static const int packed_size = 1;
using Super = ParamElemVisitor4bitBase<ndim, BCAST_OTHER>;
void host_init(const TensorND& rv, int grid_size, int block_size) {
Super::host_init(rv, grid_size, block_size);
}
#if MEGDNN_CC_CUDA
// cannot be used as an l-value; only reads are supported
devfunc dt_qint4 at(uint32_t idx) {
int offset_ = Super::offset(idx);
int vec_idx = offset_ >> 1;
int lane_idx = offset_ & 0x1;
Storage item = Storage(unpack_integer_4bits<true>(
*(Storage*)&Super::m_ptr[vec_idx], lane_idx * 4));
dt_qint4 result(item);
return result;
}
#endif
};
template <int ndim>
class ParamElemVisitor<ndim, dt_quint4, BCAST_OTHER>
: public ParamElemVisitor4bitBase<ndim, BCAST_OTHER> {
using CType = dt_quint4;
using Storage = uint8_t;
using Super = ParamElemVisitor4bitBase<ndim, BCAST_OTHER>;
public:
static const int packed_size = 1;
void host_init(const TensorND& rv, int grid_size, int block_size) {
Super::host_init(rv, grid_size, block_size);
}
#if MEGDNN_CC_CUDA
// cannot be used as an l-value; only reads are supported
devfunc dt_quint4 at(uint32_t idx) {
int offset_ = Super::offset(idx);
int vec_idx = offset_ >> 1;
int lane_idx = offset_ & 0x1;
Storage item = Storage(unpack_integer_4bits<false>(
*(Storage*)&Super::m_ptr[vec_idx], lane_idx * 4));
dt_quint4 result(item);
return result;
}
#endif
};
#if MEGDNN_CC_CUDA
#define DEVICE_WRAPPER(x) x
#else
#define DEVICE_WRAPPER(x)
#endif
#define INST_DT_IBYTE(ctype) \
template <int ndim> \
class ParamVectVisitor<ndim, ctype, BCAST_OTHER> \
: public ParamElemVisitor4bitBase<ndim, BCAST_OTHER> { \
public: \
using Super = ParamElemVisitor4bitBase<ndim, BCAST_OTHER>; \
void host_init(const TensorND& rv, int grid_size, int block_size) { \
Super::host_init(rv, grid_size, block_size); \
} \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \
static const int packed_size = VectTypeTrait<ctype>::packed_size; \
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t access_idx) { \
return *(rwtype*)(&Super::m_ptr[access_idx]); \
}) \
};
INST_DT_IBYTE(dt_qint4);
INST_DT_IBYTE(dt_quint4);
#undef DEVICE_WRAPPER
#undef INST_DT_IBYTE
/* f}}} */
#if MEGDNN_CC_CUDA
......@@ -507,7 +707,8 @@ public:
/* f{{{ user operator callers */
/*
* OpCaller is used to invoke user operator with loaded element arguments.
* OpCaller is used to invoke user operator with loaded element
* arguments.
*
* device interface:
* void thread_init(uint32_t idx);
......@@ -518,8 +719,8 @@ public:
*/
/*!
* \brief call user op directly without visiting any params (i.e. arity ==
* 0)
* \brief call user op directly without visiting any params (i.e. arity
* == 0)
*/
template <class Op>
struct OpCallerNull {
......@@ -1151,6 +1352,20 @@ public:
}
};
#define INST_DT_TYPE(ctype) \
template <class Op> \
class UserOpInvoker<Op, ctype, 2> \
: public UserOpInvokerToSameNdim<Op, ctype, 2> { \
public: \
UserOpInvoker(const ElemwiseOpParamN<2>& param, cudaStream_t stream, \
const Op& op) \
: UserOpInvokerToSameNdim<Op, ctype, 2>(param, stream, op) {} \
}
INST_DT_TYPE(dt_qint4);
INST_DT_TYPE(dt_quint4);
#undef INST_DT_TYPE
#define DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, \
_stride) \
DEFINE_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \
......@@ -1404,7 +1619,6 @@ void run_elemwise(const ElemwiseOpParamN<arity>& param, cudaStream_t stream,
#define INST_RUN_ELEMWISE(Op, ctype, arity) \
template void run_elemwise<Op, ctype, arity>( \
const ElemwiseOpParamN<arity>&, cudaStream_t, const Op&)
#endif
} // namespace cuda
......
/**
* \file dnn/src/cuda/elemwise_helper_q4.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/elemwise_helper.cuh"
/*
* please note that all arithmetic on the GPU is 32-bit for best performance;
* this limits the maximum possible size
*/
namespace megdnn {
namespace cuda {
template <typename ctype>
struct IsNotTypeQ4 {
static constexpr bool value = !(std::is_same<ctype, dt_qint4>::value ||
std::is_same<ctype, dt_quint4>::value);
};
template <typename ctype>
struct IsTypeQ4 {
static constexpr bool value = (std::is_same<ctype, dt_qint4>::value ||
std::is_same<ctype, dt_quint4>::value);
};
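These two traits only gate the partial specializations below; a trivial usage sketch (illustration only):
// compile-time checks
static_assert(IsTypeQ4<dt_qint4>::value, "dt_qint4 is a 4-bit type");
static_assert(IsTypeQ4<dt_quint4>::value, "dt_quint4 is a 4-bit type");
static_assert(IsNotTypeQ4<dt_qint8>::value, "dt_qint8 is not a 4-bit type");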
//! internals for element-wise
namespace elemwise_intl {
#define devfunc __device__ __forceinline__
#if MEGDNN_CC_CUDA
/*!
* \brief call an operator whose params are each promoted to the same ndim and
* brdcast_mask
* \tparam PVis ParamElemVisitor class
*/
template <class Op, int arity, class PVisSrc, class PVisDst, bool BetweenQ4>
struct OpCallerToQ4;
//! specialization for arity == 1
template <class Op, class PVisSrc, class PVisDst>
struct OpCallerToQ4<Op, 1, PVisSrc, PVisDst, false> {
Op op;
PVisSrc par_src[1];
PVisDst par_dst[1];
using src_ctype = typename PVisSrc::CType;
devfunc void on(uint32_t access_idx) {
int32_t idx0 = par_dst[0].idx(access_idx * 2);
int32_t idx1 = par_dst[0].idx(access_idx * 2 + 1);
src_ctype src0 = (idx0 >= 0) ? par_src[0].at(idx0) : (src_ctype)0;
src_ctype src1 = (idx1 >= 0) ? par_src[0].at(idx1) : (src_ctype)0;
op(access_idx, src0, src1);
}
};
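Each destination access index addresses one packed output byte, i.e. two logical elements; a minimal sketch of the lane mapping shared by the OpCallerToQ4 specializations in this file (hypothetical helper, illustration only):
#include <cstdint>
static void dst_lanes(uint32_t access_idx, int32_t& lane0, int32_t& lane1) {
    lane0 = access_idx * 2;      // low nibble of the output byte
    lane1 = access_idx * 2 + 1;  // high nibble; may map to padding (idx() < 0)
}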
//! specialization for arity == 2
template <class Op, class PVisSrc, class PVisDst>
struct OpCallerToQ4<Op, 2, PVisSrc, PVisDst, false> {
Op op;
PVisSrc par_src[2];
PVisDst par_dst[1];
using src_ctype = typename PVisSrc::CType;
devfunc void on(uint32_t access_idx) {
int32_t idx0 = par_dst[0].idx(access_idx * 2);
int32_t idx1 = par_dst[0].idx(access_idx * 2 + 1);
src_ctype src00 = (idx0 >= 0) ? par_src[0].at(idx0) : (src_ctype)0;
src_ctype src10 = (idx0 >= 0) ? par_src[1].at(idx0) : (src_ctype)0;
src_ctype src01 = (idx1 >= 0) ? par_src[0].at(idx1) : (src_ctype)0;
src_ctype src11 = (idx1 >= 0) ? par_src[1].at(idx1) : (src_ctype)0;
op(access_idx, src00, src10, src01, src11);
}
};
template <class Op, class PVisSrc, class PVisDst>
struct OpCallerToQ4<Op, 3, PVisSrc, PVisDst, false> {
Op op;
PVisSrc par_src[3];
PVisDst par_dst[1];
using src_ctype = typename PVisSrc::CType;
devfunc void on(uint32_t access_idx) {
int32_t idx0 = par_dst[0].idx(access_idx * 2);
int32_t idx1 = par_dst[0].idx(access_idx * 2 + 1);
src_ctype src00 = (idx0 >= 0) ? par_src[0].at(idx0) : (src_ctype)0;
src_ctype src10 = (idx0 >= 0) ? par_src[1].at(idx0) : (src_ctype)0;
src_ctype src20 = (idx0 >= 0) ? par_src[2].at(idx0) : (src_ctype)0;
src_ctype src01 = (idx1 >= 0) ? par_src[0].at(idx1) : (src_ctype)0;
src_ctype src11 = (idx1 >= 0) ? par_src[1].at(idx1) : (src_ctype)0;
src_ctype src21 = (idx1 >= 0) ? par_src[2].at(idx1) : (src_ctype)0;
op(access_idx, src00, src10, src20, src01, src11, src21);
}
};
//! specialization for arity == 1
template <class Op, class PVisSrc, class PVisDst>
struct OpCallerToQ4<Op, 1, PVisSrc, PVisDst, true> {
Op op;
PVisSrc par_src[1];
PVisDst par_dst[1];
devfunc void on(uint32_t access_idx) {
op(access_idx, par_src[0].at(access_idx));
}
};
//! specialization for arity == 2
template <class Op, class PVisSrc, class PVisDst>
struct OpCallerToQ4<Op, 2, PVisSrc, PVisDst, true> {
Op op;
PVisSrc par_src[2];
PVisDst par_dst[1];
devfunc void on(uint32_t access_idx) {
op(access_idx, par_src[0].at(access_idx), par_src[1].at(access_idx));
}
};
template <class Op, class PVisSrc, class PVisDst>
struct OpCallerToQ4<Op, 3, PVisSrc, PVisDst, true> {
Op op;
PVisSrc par_src[3];
PVisDst par_dst[1];
devfunc void on(uint32_t access_idx) {
op(access_idx, par_src[0].at(access_idx), par_src[1].at(access_idx),
par_src[2].at(access_idx));
}
};
/* f}}} */
template <class OpCaller>
__global__ void cuda_kern_q4(OpCaller op_caller, uint32_t size) {
uint32_t access_idx = blockIdx.x * blockDim.x + threadIdx.x,
delta = blockDim.x * gridDim.x;
if (access_idx < size) {
op_caller.on(access_idx);
access_idx += delta;
if (access_idx < size) {
op_caller.on(access_idx);
access_idx += delta;
if (access_idx < size) {
op_caller.on(access_idx);
}
}
}
}
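The kernel above lets each thread handle at most three access indices (access_idx, access_idx + delta, access_idx + 2*delta), which is why get_launch_spec() asserts block_size * grid_size * 3 >= size; a host-side sketch of that invariant (made-up helper, illustration only):
#include <cassert>
#include <cstddef>
static void check_q4_launch_coverage(size_t size, int grid_size, int block_size) {
    // the largest index any thread can touch is 3 * grid_size * block_size - 1,
    // so this product must cover every access index in [0, size)
    assert(static_cast<size_t>(block_size) * grid_size * 3 >= size);
}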
/* f{{{ UserOpInvoker specializations */
//! run op by promoting all params to same ndim
template <class Op, typename src_ctype, typename dst_ctype, int arity,
bool BetweenQ4>
class UserOpInvokerQ4 {
const ElemwiseOpParamN<arity>& m_src_param;
const ElemwiseOpParamN<1>& m_dst_param;
cudaStream_t m_stream;
const Op& m_op;
void dispatch0() {
switch (m_dst_param.max_ndim) {
#define cb(ndim) \
case ndim: \
return dispatch1<ndim>();
MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
}
on_bad_ndim(m_dst_param.max_ndim);
}
template <int ndim>
void dispatch1() {
using PVisSrc = typename std::conditional<
BetweenQ4, ParamVectVisitor<ndim, src_ctype, BCAST_OTHER>,
ParamElemVisitor<ndim, src_ctype, BCAST_OTHER>>::type;
typedef OpCallerToQ4<Op, arity, PVisSrc,
ParamVectVisitor<ndim, dst_ctype, BCAST_OTHER>,
BetweenQ4>
Caller;
size_t size = m_dst_param[0].layout.access_bytes();
int grid_size, block_size;
void (*fptr)(Caller, uint32_t) = cuda_kern_q4<Caller>;
get_launch_spec(reinterpret_cast<const void*>(fptr), size, &grid_size,
&block_size);
Caller caller;
caller.op = m_op;
for (int i = 0; i < arity; ++i)
caller.par_src[i].host_init(m_src_param[i], grid_size, block_size);
caller.par_dst[0].host_init(m_dst_param[0], grid_size, block_size);
(*fptr)<<<grid_size, block_size, 0, m_stream>>>(caller, size);
after_kernel_launch();
}
public:
UserOpInvokerQ4(const ElemwiseOpParamN<arity>& src_param,
const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream,
const Op& op)
: m_src_param(src_param),
m_dst_param(dst_param),
m_stream(stream),
m_op(op) {
dispatch0();
}
};
#endif
/* f}}} */
#undef devfunc
} // namespace elemwise_intl
template <class Op, typename src_ctype, typename dst_ctype, int arity>
void run_elemwise(const ElemwiseOpParamN<arity>& src_param,
const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream,
const Op& op = Op());
#if MEGDNN_CC_CUDA
template <class Op, typename src_ctype, typename dst_ctype, int arity>
void run_elemwise(const ElemwiseOpParamN<arity>& src_param,
const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream,
const Op& op) {
src_param.assert_initialized();
dst_param.assert_initialized();
// TODO: Maybe 2bit?
megdnn_assert(dst_param[0].layout.dtype.is_low_bit());
megdnn_assert(dst_param[0].layout.is_contiguous());
elemwise_intl::UserOpInvokerQ4<Op, src_ctype, dst_ctype, arity,
IsTypeQ4<src_ctype>::value>(
src_param, dst_param, stream, op);
}
#define INST_RUN_ELEMWISE_LOWBIT(Op, src_ctype, dst_ctype, arity) \
template void run_elemwise<Op, src_ctype, dst_ctype, arity>( \
const ElemwiseOpParamN<arity>&, const ElemwiseOpParamN<1>&, \
cudaStream_t, const Op&)
#endif
} // namespace cuda
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/cuda/elemwise_multi_type/kern_impl_q4.inl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#ifndef KERN_IMPL_MODE
#error "KERN_IMPL_MODE, KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE must be defined"
#endif
#include "src/cuda/elemwise_multi_type/kern_ops.cuh"
namespace megdnn {
namespace cuda {
#define cb(_m) \
typedef ElemwiseKern<megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, \
float> \
KernImpl; \
typedef kern_ops_quantized::QuantizedMultiTypeOp< \
KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \
Op; \
INST_RUN_ELEMWISE_LOWBIT(Op, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, \
KERN_IMPL_ARITY);
KERN_IMPL_MODE(cb)
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -6,11 +6,13 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/elemwise_helper.cuh"
#include "src/cuda/elemwise_helper_q4.cuh"
#include "src/cuda/elemwise_multi_type/kern.cuh"
#include "src/cuda/utils.cuh"
......@@ -127,10 +129,10 @@ struct QuantizedMultiTypeOp;
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
1, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value>::type> {
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) &&
IsNotTypeQ4<ctype_dst>::value>::type> {
ctype_dst* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a;
......@@ -173,10 +175,10 @@ struct QuantizedMultiTypeOp<
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
2, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value>::type> {
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) &&
IsNotTypeQ4<ctype_dst>::value>::type> {
ctype_dst* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a, param_b;
......@@ -224,10 +226,10 @@ struct QuantizedMultiTypeOp<
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
3, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value>::type> {
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) &&
IsNotTypeQ4<ctype_dst>::value>::type> {
ctype_dst* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a, param_b, param_c;
......@@ -277,6 +279,367 @@ struct QuantizedMultiTypeOp<
#endif
};
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
1, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<IsTypeQ4<ctype_src>::value &&
IsNotTypeQ4<ctype_dst>::value>::type> {
ctype_dst* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a;
#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(
const SmallVector<CudaDTypeParam<ctype_src>>& src_params,
ctype_dst* dst, const CudaDTypeParam<ctype_dst>& dst_param)
: dst{dst}, dst_param{dst_param} {
param_a = src_params[0];
}
#endif
#if MEGDNN_CC_CUDA
__device__ __forceinline__ ctype_dst apply(ctype_src v1) {
float fv1 = param_a.dequantize(v1);
float rv = KernImpl::apply(fv1);
return dst_param.quantize(rv);
}
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a) {
dst[idx] = dst_param.quantize(KernImpl::apply(param_a.dequantize(a)));
}
#endif
};
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
2, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<IsTypeQ4<ctype_src>::value &&
IsNotTypeQ4<ctype_dst>::value>::type> {
ctype_dst* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a, param_b;
#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(
const SmallVector<CudaDTypeParam<ctype_src>>& src_params,
ctype_dst* dst, const CudaDTypeParam<ctype_dst>& dst_param)
: dst{dst}, dst_param{dst_param} {
param_a = src_params[0];
param_b = src_params[1];
}
#endif
#if MEGDNN_CC_CUDA
__device__ __forceinline__ ctype_dst apply(ctype_src v1, ctype_src v2) {
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2);
float rv = KernImpl::apply(fv1, fv2);
return dst_param.quantize(rv);
}
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a,
ctype_src b) {
dst[idx] = dst_param.quantize(
KernImpl::apply(param_a.dequantize(a), param_b.dequantize(b)));
}
#endif
};
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
1, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<IsTypeQ4<ctype_src>::value &&
IsTypeQ4<ctype_dst>::value>::type> {
using src_storage =
typename elemwise_intl::VectTypeTrait<ctype_src>::Storage;
using dst_storage =
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a;
static constexpr bool src_signedness =
std::is_same<ctype_src, dt_qint4>::value;
typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type
src_vect_type;
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type
dst_vect_type;
#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(
const SmallVector<CudaDTypeParam<ctype_src>>& src_params,
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param)
: dst{dst}, dst_param{dst_param} {
param_a = src_params[0];
}
#endif
#if MEGDNN_CC_CUDA
__device__ __forceinline__ dst_storage apply(src_storage v1) {
float fv1 = param_a.dequantize(v1);
float rv = KernImpl::apply(fv1);
return dst_param.quantize(rv).as_storage();
}
__device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a) {
dst_storage x = apply(
src_storage(unpack_integer_4bits<src_signedness>(a.x, 0)));
dst_storage y = apply(
src_storage(unpack_integer_4bits<src_signedness>(a.x, 4)));
*(dst_vect_type*)(&dst[idx]) =
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y);
}
#endif
};
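A host reference (illustration only) of what the Q4 -> Q4 arity-1 operator() above computes for one packed byte, taking RELU as the kernel and assuming simple scale-only quantization with round-to-nearest and clamping to the qint4 range; the exact rounding/clamping behaviour of CudaDTypeParam is not spelled out here:
#include <algorithm>
#include <cmath>
#include <cstdint>
static int8_t relu_qint4_byte(int8_t packed, float src_scale, float dst_scale) {
    auto unpack = [](int8_t v, int off) -> int8_t {   // sign-extending unpack
        int8_t n = static_cast<int8_t>((v >> off) & 0xf);
        return static_cast<int8_t>((n & 0x8) ? (n | 0xf0) : n);
    };
    auto quantize = [&](float f) -> int8_t {          // round and clamp to [-8, 7]
        int q = static_cast<int>(std::nearbyint(f / dst_scale));
        return static_cast<int8_t>(std::min(7, std::max(-8, q)));
    };
    int8_t x = quantize(std::max(0.f, unpack(packed, 0) * src_scale));  // low lane
    int8_t y = quantize(std::max(0.f, unpack(packed, 4) * src_scale));  // high lane
    return static_cast<int8_t>((x & 0xf) | (y << 4));                   // repack
}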
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
1, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) &&
IsTypeQ4<ctype_dst>::value>::type> {
using dst_storage =
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a;
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type
dst_vect_type;
#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(
const SmallVector<CudaDTypeParam<ctype_src>>& src_params,
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param)
: dst{dst}, dst_param{dst_param} {
param_a = src_params[0];
}
#endif
#if MEGDNN_CC_CUDA
__device__ __forceinline__ dst_storage apply(ctype_src v1) {
float fv1 = param_a.dequantize(v1);
float rv = KernImpl::apply(fv1);
return dst_param.quantize(rv).as_storage();
}
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x,
ctype_src a_y) {
dst_storage x = apply(a_x), y = apply(a_y);
*(dst_vect_type*)(&dst[idx]) =
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y);
}
#endif
};
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
2, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<IsTypeQ4<ctype_src>::value &&
IsTypeQ4<ctype_dst>::value>::type> {
using src_storage =
typename elemwise_intl::VectTypeTrait<ctype_src>::Storage;
using dst_storage =
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a, param_b;
static constexpr bool src_signedness =
std::is_same<ctype_src, dt_qint4>::value;
typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type
src_vect_type;
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type
dst_vect_type;
#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(
const SmallVector<CudaDTypeParam<ctype_src>>& src_params,
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param)
: dst{dst}, dst_param{dst_param} {
param_a = src_params[0];
param_b = src_params[1];
}
#endif
#if MEGDNN_CC_CUDA
__device__ __forceinline__ dst_storage apply(src_storage v1,
src_storage v2) {
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2);
float rv = KernImpl::apply(fv1, fv2);
return dst_param.quantize(rv).as_storage();
}
__device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a,
src_vect_type b) {
src_storage a_x =
src_storage(unpack_integer_4bits<src_signedness>(a.x, 0));
src_storage a_y =
src_storage(unpack_integer_4bits<src_signedness>(a.x, 4));
src_storage b_x =
src_storage(unpack_integer_4bits<src_signedness>(b.x, 0));
src_storage b_y =
src_storage(unpack_integer_4bits<src_signedness>(b.x, 4));
dst_storage x = apply(a_x, b_x), y = apply(a_y, b_y);
*(dst_vect_type*)(&dst[idx]) =
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y);
}
#endif
};
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
2, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) &&
IsTypeQ4<ctype_dst>::value>::type> {
using dst_storage =
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a, param_b;
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type
dst_vect_type;
#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(
const SmallVector<CudaDTypeParam<ctype_src>>& src_params,
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param)
: dst{dst}, dst_param{dst_param} {
param_a = src_params[0];
param_b = src_params[1];
}
#endif
#if MEGDNN_CC_CUDA
__device__ __forceinline__ dst_storage apply(ctype_src v1, ctype_src v2) {
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2);
float rv = KernImpl::apply(fv1, fv2);
return dst_param.quantize(rv).as_storage();
}
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x,
ctype_src b_x, ctype_src a_y,
ctype_src b_y) {
dst_storage x = apply(a_x, b_x), y = apply(a_y, b_y);
*(dst_vect_type*)(&dst[idx]) =
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y);
}
#endif
};
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
3, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<IsTypeQ4<ctype_src>::value &&
IsTypeQ4<ctype_dst>::value>::type> {
using src_storage =
typename elemwise_intl::VectTypeTrait<ctype_src>::Storage;
using dst_storage =
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a, param_b, param_c;
static constexpr bool src_signedness =
std::is_same<ctype_src, dt_qint4>::value;
typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type
src_vect_type;
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type
dst_vect_type;
#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(
const SmallVector<CudaDTypeParam<ctype_src>>& src_params,
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param)
: dst{dst}, dst_param{dst_param} {
param_a = src_params[0];
param_b = src_params[1];
param_c = src_params[2];
}
#endif
#if MEGDNN_CC_CUDA
__device__ __forceinline__ dst_storage apply(src_storage v1, src_storage v2,
src_storage v3) {
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2),
fv3 = param_c.dequantize(v3);
float rv = KernImpl::apply(fv1, fv2, fv3);
return dst_param.quantize(rv).as_storage();
}
__device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a,
src_vect_type b,
src_vect_type c) {
src_storage a_x =
src_storage(unpack_integer_4bits<src_signedness>(a.x, 0));
src_storage a_y =
src_storage(unpack_integer_4bits<src_signedness>(a.x, 4));
src_storage b_x =
src_storage(unpack_integer_4bits<src_signedness>(b.x, 0));
src_storage b_y =
src_storage(unpack_integer_4bits<src_signedness>(b.x, 4));
src_storage c_x =
src_storage(unpack_integer_4bits<src_signedness>(c.x, 0));
src_storage c_y =
src_storage(unpack_integer_4bits<src_signedness>(c.x, 4));
dst_storage x = apply(a_x, b_x, c_x), y = apply(a_y, b_y, c_y);
*(dst_vect_type*)(&dst[idx]) =
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y);
}
#endif
};
template <typename ctype_src, typename ctype_dst, typename KernImpl>
struct QuantizedMultiTypeOp<
3, ctype_src, ctype_dst, KernImpl,
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint32>::value ||
std::is_same<ctype_src, dt_quint8>::value) &&
IsTypeQ4<ctype_dst>::value>::type> {
using dst_storage =
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage;
dst_storage* dst;
CudaDTypeParam<ctype_dst> dst_param;
CudaDTypeParam<ctype_src> param_a, param_b, param_c;
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type
dst_vect_type;
#if !MEGDNN_CC_CUDA
QuantizedMultiTypeOp(
const SmallVector<CudaDTypeParam<ctype_src>>& src_params,
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param)
: dst{dst}, dst_param{dst_param} {
param_a = src_params[0];
param_b = src_params[1];
param_c = src_params[2];
}
#endif
#if MEGDNN_CC_CUDA
__device__ __forceinline__ dst_storage apply(ctype_src v1, ctype_src v2,
ctype_src v3) {
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2),
fv3 = param_c.dequantize(v3);
float rv = KernImpl::apply(fv1, fv2, fv3);
return dst_param.quantize(rv).as_storage();
}
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x,
ctype_src b_x, ctype_src c_x,
ctype_src a_y, ctype_src b_y,
ctype_src c_y) {
dst_storage x = apply(a_x, b_x, c_x), y = apply(a_y, b_y, c_y);
*(dst_vect_type*)(&dst[idx]) =
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y);
}
#endif
};
} // namespace kern_ops_quantized
} // namespace cuda
......
(8 more changed files; diffs collapsed)