Commit 3bf73ff1 authored by Megvii Engine Team

feat(dnn): add cuda preprocess fusion

GitOrigin-RevId: d789c99e59ce713a075061aacf6acdba78af43d3
Parent 86cf7490
......@@ -201,6 +201,8 @@ class dt_quint8 {
#endif
bool operator<(const dt_quint8& b) const { return _ < b._; }
bool operator>(const dt_quint8& b) const { return _ > b._; }
bool operator==(const dt_quint8& b) const { return _ == b._; }
bool operator!=(const dt_quint8& b) const { return _ != b._; }
} MEGDNN_PACKED;
class dt_qint32 {
......@@ -255,6 +257,8 @@ class dt_qint8 {
#endif
bool operator<(const dt_qint8& b) const { return _ < b._; }
bool operator>(const dt_qint8& b) const { return _ > b._; }
bool operator==(const dt_qint8& b) const { return _ == b._; }
bool operator!=(const dt_qint8& b) const { return _ != b._; }
} MEGDNN_PACKED;
class dt_qint16 {
......
......@@ -877,6 +877,7 @@ when the ``I`` suffix is present.
'NCHW88_NCHW',
'NCHW_NCHW4_IC_SMALL',
'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT',
'NCHW_NCHW4',
)
)
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
......@@ -94,7 +95,9 @@ void RelayoutForward::check_layout_and_canonize(TensorLayout& src,
src = src.collapse_contiguous();
dst = dst.collapse_contiguous();
megdnn_assert(src.dtype == dst.dtype &&
src.total_nr_elems() == dst.total_nr_elems());
src.total_nr_elems() == dst.total_nr_elems(),
"check %s == %s and %zu == %zu", src.dtype.name(),
dst.dtype.name(), src.total_nr_elems(), dst.total_nr_elems());
}
bool relayout::is_transpose(const TensorLayout& src, const TensorLayout& dst,
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
......@@ -207,6 +208,15 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
dst[3] = src[2];
dst[4] = src[4];
break;
case Param::Mode::NCHW_NCHW4:
megdnn_assert(src.ndim == 4);
dst.ndim = 5;
dst[0] = src[0];
dst[1] = div_ceil<size_t>(src[1], 4);
dst[2] = src[2];
dst[3] = src[3];
dst[4] = 4;
break;
default:
megdnn_assert(0, "Invalid RelayoutFormat Mode");
break;
......@@ -214,7 +224,9 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
TensorFormat dst_fmt;
deduce_format(src.format, dst_fmt);
dst.format = dst_fmt;
dst.dtype = src.dtype;
if (!dst.dtype.valid()) {
dst.dtype = src.dtype;
}
dst.init_contiguous_stride();
}
......@@ -245,6 +257,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
CHECK_SRC(DefaultTensorFormat::make());
dst = src;
break;
case Param::Mode::NCHW_NCHW4:
CHECK_SRC(DefaultTensorFormat::make());
dst = src;
break;
case Param::Mode::NCHW_NHWCD4I:
CHECK_SRC(DefaultTensorFormat::make());
dst = Image2DPack4TensorFormat::make_raw(2, align);
......@@ -322,6 +338,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
void RelayoutFormat::check_layout_fwd(const TensorLayout& src,
const TensorLayout& dst) {
TensorLayout dst_expected;
dst_expected.dtype = dst.dtype;
deduce_layout_fwd(src, dst_expected);
megdnn_assert_eq_layout(dst_expected, dst);
}
......@@ -354,6 +371,19 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
exec_dst = dst;
}
break;
case Param::Mode::NCHW_NCHW4:
// nchw to nchw4
{
TensorLayout work_space_layout(
{src[0], round_up(src[1], 4_z), src[2], src[3]},
src.dtype, src.format);
exec_src = work_space_layout
.reshape({src[0], div_ceil(src[1], 4_z), 4,
src[2], src[3]})
.dimshuffle({0, 1, 3, 4, 2});
exec_dst = dst;
}
break;
case Param::Mode::NCHW88_NCHW:
// nchw8c to nchw
exec_src = src;
......@@ -422,7 +452,6 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
}
break;
case Param::Mode::NCHW_NHWCD4:
case Param::Mode::NCHW_NHWCD4I:
// src is {N, C, H, W}
......
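For reference, the NCHW_NCHW4 layouts deduced above move element (n, c, h, w) of an (N, C, H, W) tensor to position (n, c / 4, h, w, c % 4) of an (N, ceil(C / 4), H, W, 4) tensor, with the tail channels padded. A minimal host-side sketch of that mapping (illustrative only; the helper name and the int8 element type are assumptions, not part of this commit):

    // Pad channels up to a multiple of 4 with pad_val and repack NCHW -> NCHW4.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    std::vector<int8_t> nchw_to_nchw4(const std::vector<int8_t>& src, size_t N,
                                      size_t C, size_t H, size_t W,
                                      int8_t pad_val = 0) {
        const size_t C4 = (C + 3) / 4;  // number of 4-channel blocks
        std::vector<int8_t> dst(N * C4 * H * W * 4, pad_val);
        for (size_t n = 0; n < N; ++n)
            for (size_t c = 0; c < C; ++c)
                for (size_t h = 0; h < H; ++h)
                    for (size_t w = 0; w < W; ++w) {
                        const size_t s = ((n * C + c) * H + h) * W + w;
                        const size_t d =
                                (((n * C4 + c / 4) * H + h) * W + w) * 4 + c % 4;
                        dst[d] = src[s];
                    }
        return dst;
    }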
......@@ -6,11 +6,13 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/relayout_format/relayout_format.h"
#include "src/cuda/utils.h"
using namespace megdnn;
......@@ -21,6 +23,7 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
auto src_dtype = src.layout.dtype;
megdnn_assert(
param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
param().mode == param::RelayoutFormat::Mode::NCHW_NCHW4 ||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 ||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
param().mode ==
......@@ -72,12 +75,25 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
return handle()->create_operator<RelayoutForward>()->exec(
{src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout});
}
TensorLayout exec_src, exec_dst;
deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst);
TensorND exec_src_nd{src.raw_ptr, exec_src};
TensorND exec_dst_nd{dst.raw_ptr, exec_dst};
handle()->create_operator<RelayoutForward>()->exec(exec_src_nd,
exec_dst_nd);
if (param().mode == Param::Mode::NCHW_NCHW4) {
bool is_usable = relayout_format::RelayoutFormatFast::usable(
src.layout, dst.layout);
megdnn_assert(is_usable,
"RelayoutFormatNCHW_NCHW4 kernel not usable for %s(%s) "
"to %s(%s)",
src.layout.to_string().c_str(), src.layout.dtype.name(),
dst.layout.to_string().c_str(), dst.layout.dtype.name());
relayout_format::RelayoutFormatFast::exec(src, dst,
cuda_stream(this->handle()));
} else {
TensorLayout exec_src, exec_dst;
deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst);
TensorND exec_src_nd{src.raw_ptr, exec_src};
TensorND exec_dst_nd{dst.raw_ptr, exec_dst};
handle()->create_operator<RelayoutForward>()->exec(exec_src_nd,
exec_dst_nd);
}
}
size_t RelayoutFormatImpl::get_workspace_in_bytes(
......
/**
* \file dnn/src/cuda/relayout_format/relayout_format.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/relayout_format/relayout_format.cuh"
#include "src/cuda/relayout_format/relayout_format.h"
using namespace megdnn;
using namespace cuda;
namespace {
inline void get_scale_zeropoint(const DType& tensor_dtype, float& scale,
uint8_t& zero_point) {
if (tensor_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
zero_point = tensor_dtype.param<dtype::Quantized8Asymm>().zero_point;
scale = tensor_dtype.param<dtype::Quantized8Asymm>().scale;
} else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS8) {
scale = tensor_dtype.param<dtype::QuantizedS8>().scale;
}
}
} // namespace
bool relayout_format::RelayoutFormatFast::usable(
const TensorLayout& src_layout, const TensorLayout& dst_layout) {
return relayout_format_cuda_usable(src_layout, dst_layout);
}
void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
const TensorND& dst,
cudaStream_t stream) {
size_t ih = src.layout[2];
size_t iw = src.layout[3];
size_t hw = ih * iw;
float src_scale = 1.f;
float dst_scale = 1.f;
uint8_t src_zero_point = 0;
uint8_t dst_zero_point = 0;
get_scale_zeropoint(src.layout.dtype, src_scale, src_zero_point);
get_scale_zeropoint(dst.layout.dtype, dst_scale, dst_zero_point);
if (src.layout.dtype.enumv() == DTypeEnum::Uint8) {
src_zero_point = 128;
}
if (hw % 4 == 0) {
relayout_format_cuda_exec<4>(src, dst, stream, src_scale, dst_scale,
src_zero_point, dst_zero_point);
} else {
relayout_format_cuda_exec<1>(src, dst, stream, src_scale, dst_scale,
src_zero_point, dst_zero_point);
}
}
/**
* \file dnn/src/cuda/relayout_format/relayout_format.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/query_blocksize.cuh"
#include "src/cuda/relayout_format/relayout_format.cuh"
using namespace megdnn;
using namespace cuda;
namespace {
template <typename SrcType, typename DstType, bool same_scale>
struct CudaPostProcess;
template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, true> {
CudaPostProcess(float, uint8_t, float, uint8_t){};
inline __device__ int8_t operator()(uint8_t val) { return val - 128; }
};
template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, false> {
CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
CudaPostProcess(float, uint8_t, float dst_scale, uint8_t) {
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
};
inline __device__ int8_t operator()(uint8_t val) {
return m_dst_type_cvt.quantize((float)val - 128.f).as_int8();
}
};
template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, false> {
CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
CudaDTypeParamImpl<dt_quint8> m_src_type_cvt;
CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
uint8_t) {
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
m_src_type_cvt =
CudaDTypeParamImpl<dt_quint8>(src_scale, src_zero_point);
};
inline __device__ int8_t operator()(uint8_t val) {
float med_var = m_src_type_cvt.dequantize(dt_quint8(val));
return m_dst_type_cvt.quantize(med_var).as_int8();
}
};
template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, true> {
uint8_t m_src_zero_point = 0;
CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t) {
m_src_zero_point = src_zero_point;
};
inline __device__ int8_t operator()(uint8_t val) {
return val - m_src_zero_point;
}
};
template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, false> {
CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
CudaDTypeParamImpl<dt_qint8> m_src_type_cvt;
CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
m_src_type_cvt = CudaDTypeParamImpl<dt_qint8>(src_scale);
};
inline __device__ int8_t operator()(int8_t val) {
float med_var = m_src_type_cvt.dequantize(dt_qint8(val));
return m_dst_type_cvt.quantize(med_var).as_int8();
}
};
template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, true> {
CudaPostProcess(float, uint8_t, float, uint8_t){};
inline __device__ int8_t operator()(int8_t val) { return val; }
};
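// Numeric check of the dequantize/re-quantize path above (made-up parameters,
// not from this commit): with a Quantized8Asymm source of scale 0.5 and
// zero_point 128 and a QuantizedS8 destination of scale 0.25, an input byte of
// 130 dequantizes to (130 - 128) * 0.5 = 1.0 and re-quantizes to
// round(1.0 / 0.25) = 4; the same_scale specializations skip the float round
// trip entirely.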
template <typename SrcType, int pack_w>
struct DTypeRWHelper;
template <>
struct DTypeRWHelper<char, 1> {
using InnerDtype = char;
using DstDtype = char4;
};
template <>
struct DTypeRWHelper<char, 4> {
using InnerDtype = char4;
using DstDtype = char4;
};
template <int pack_w, int pack_c, typename SrcType, typename DnnSrcType,
typename DnnDstType, bool same_scale>
struct Translayout {
using InnerDtype = typename DTypeRWHelper<SrcType, pack_w>::InnerDtype;
using DstDtype = typename DTypeRWHelper<SrcType, pack_w>::DstDtype;
static inline __device__ void trans(DstDtype (&dst_width)[pack_w],
InnerDtype (&read_channel)[pack_c],
const char zero_point);
};
template <typename SrcType, typename DnnSrcType, typename DnnDstType,
bool same_scale>
struct Translayout<1, 4, SrcType, DnnSrcType, DnnDstType, same_scale> {
using InnerDtype = typename DTypeRWHelper<SrcType, 1>::InnerDtype;
using DstDtype = typename DTypeRWHelper<SrcType, 1>::DstDtype;
static inline __device__ void trans(
DstDtype (&dst_width)[1], InnerDtype (&read_channel)[4],
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
const char zero_point) {
dst_width[0].x = post_process(read_channel[0]);
dst_width[0].y = post_process(read_channel[1]);
dst_width[0].z = post_process(read_channel[2]);
dst_width[0].w = post_process(read_channel[3]);
}
};
template <typename SrcType, typename DnnSrcType, typename DnnDstType,
bool same_scale>
struct Translayout<4, 4, SrcType, DnnSrcType, DnnDstType, same_scale> {
using InnerDtype = typename DTypeRWHelper<SrcType, 4>::InnerDtype;
using DstDtype = typename DTypeRWHelper<SrcType, 4>::DstDtype;
static inline __device__ void trans(
DstDtype (&dst_width)[4], InnerDtype (&read_channel)[4],
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
const char zero_point) {
dst_width[0].x = post_process(read_channel[0].x);
dst_width[0].y = post_process(read_channel[1].x);
dst_width[0].z = post_process(read_channel[2].x);
dst_width[0].w = post_process(read_channel[3].x);
dst_width[1].x = post_process(read_channel[0].y);
dst_width[1].y = post_process(read_channel[1].y);
dst_width[1].z = post_process(read_channel[2].y);
dst_width[1].w = post_process(read_channel[3].y);
dst_width[2].x = post_process(read_channel[0].z);
dst_width[2].y = post_process(read_channel[1].z);
dst_width[2].z = post_process(read_channel[2].z);
dst_width[2].w = post_process(read_channel[3].z);
dst_width[3].x = post_process(read_channel[0].w);
dst_width[3].y = post_process(read_channel[1].w);
dst_width[3].z = post_process(read_channel[2].w);
dst_width[3].w = post_process(read_channel[3].w);
}
};
template <typename DstType>
inline __device__ DstType make_zero_pad(const char zero_point) {
return zero_point;
}
template <>
inline __device__ char4 make_zero_pad<char4>(const char zero_point) {
return {zero_point, zero_point, zero_point, zero_point};
}
template <typename DstDtype>
inline __device__ void write_helper(DstDtype* ptr, DstDtype val) {
*ptr = val;
}
template <>
inline __device__ void write_helper<char4>(char4* ptr, char4 val) {
int32_t* rel_ptr = (int32_t*)ptr;
*rel_ptr = *(int32_t*)(&val);
}
template <bool with_pad, int pack_w, int pack_c, bool same_scale,
typename SrcType, typename DstType, typename DnnSrcType,
typename DnnDstType>
struct RelayoutKern {
using InnerDtype = typename DTypeRWHelper<SrcType, pack_w>::InnerDtype;
using DstDtype = typename DTypeRWHelper<SrcType, pack_w>::DstDtype;
static inline __device__ void write(DstType* dst_ptr,
char4 (&dst_width)[pack_w]) {
DstDtype* dst_inner_ptr = (DstDtype*)dst_ptr;
#pragma unroll
for (int iw_idx = 0; iw_idx < pack_w; ++iw_idx) {
write_helper(dst_inner_ptr + iw_idx, dst_width[iw_idx]);
}
}
static inline __device__ void read(const SrcType* src_ptr,
InnerDtype (&read_channel)[pack_c],
const int ic_stride) {
#pragma unroll
for (int ic_idx = 0; ic_idx < pack_c; ++ic_idx) {
read_channel[ic_idx] = *(InnerDtype*)(src_ptr + ic_idx * ic_stride);
}
}
static inline __device__ void read_with_pad(
const SrcType* src_ptr, InnerDtype (&read_channel)[pack_c],
const int ic_stride, const int remain_ic,
const InnerDtype zero_point) {
#pragma unroll
for (int ic_idx = 0; ic_idx < pack_c; ++ic_idx) {
read_channel[ic_idx] =
ic_idx < remain_ic
? *(InnerDtype*)(src_ptr + ic_idx * ic_stride)
: zero_point;
}
}
static inline __device__ void core_relayout_kern(
const SrcType* src, DstType* dst, const int src_offset_base,
const int dst_offset_base, const int ic_offset, const int ic_stride,
const int remain_ic,
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
const char zero_point) {
InnerDtype read_channel[pack_c];
if (with_pad) {
const InnerDtype zero_pad = make_zero_pad<InnerDtype>(zero_point);
read_with_pad(src + ic_offset + src_offset_base, read_channel,
ic_stride, remain_ic, zero_pad);
} else {
read(src + ic_offset + src_offset_base, read_channel, ic_stride);
}
DstDtype dst_width[pack_w];
Translayout<pack_w, pack_c, SrcType, DnnSrcType, DnnDstType,
same_scale>::trans(dst_width, read_channel, post_process,
zero_point);
write(dst + ic_offset + dst_offset_base, dst_width);
}
};
template <int pack_w, bool same_scale, typename SrcType, typename DstType,
typename DnnSrcType, typename DnnDstType>
__global__ void kern_nchw_nchw4(
const SrcType* src, DstType* dst, int ic, int ihw, int n_stride_src,
int ic_stride, int n_stride_dst,
CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process,
const char zero_point) {
constexpr int pack_c = 4;
const int n_idx = blockIdx.y;
const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int ihw_offset = ihw_block_idx * pack_w;
if (ihw_offset < ihw) {
const int ic_block = ic / pack_c;
const int remain_ic = ic % pack_c;
const int src_offset_base = n_idx * n_stride_src + ihw_offset;
const int dst_offset_base = n_idx * n_stride_dst + ihw_offset * pack_c;
for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) {
const int ic_offset = ic_blk_idx * pack_c * ic_stride;
RelayoutKern<false, pack_w, pack_c, same_scale, SrcType, DstType,
DnnSrcType,
DnnDstType>::core_relayout_kern(src, dst,
src_offset_base,
dst_offset_base,
ic_offset, ic_stride,
remain_ic,
post_process,
zero_point);
}
if (remain_ic > 0) {
const int ic_offset = ic_block * pack_c * ic_stride;
RelayoutKern<true, pack_w, pack_c, same_scale, SrcType, DstType,
DnnSrcType,
DnnDstType>::core_relayout_kern(src, dst,
src_offset_base,
dst_offset_base,
ic_offset, ic_stride,
remain_ic,
post_process,
zero_point);
}
}
}
} // namespace
template <int pack_w = 1>
void relayout_format::relayout_format_cuda_exec(
const TensorND& src, const TensorND& dst, const cudaStream_t& stream,
const float src_scale, const float dst_scale,
const uint8_t src_zero_point, const uint8_t dst_zero_point) {
constexpr int pack_oc = 4;
const int n = src.layout[0];
const int c = src.layout[1];
const int h = src.layout[2];
const int w = src.layout[3];
const int hw = h * w;
const int oc_block = DIVUP(c, pack_oc);
const int n_stride_src = c * hw;
const int ic_stride = hw;
const int n_stride_dst = oc_block * pack_oc * h * w;
auto& src_layout = src.layout;
auto& dst_layout = dst.layout;
bool same_scale = src_scale == dst_scale;
#define RUN_KERNEL(same_scale, SRC_TYPE, DST_TYPE, SRC_C_TYPE, DST_C_TYPE) \
if (same_scale) { \
int nr_threads = query_blocksize_for_kernel( \
kern_nchw_nchw4<pack_w, true, SRC_C_TYPE, DST_C_TYPE, \
SRC_TYPE, DST_TYPE>); \
const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), n); \
const dim3 thread_dim(nr_threads); \
kern_nchw_nchw4<pack_w, true><<<block_dim, thread_dim, 0, stream>>>( \
(SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, c, hw, \
n_stride_src, ic_stride, n_stride_dst, \
CudaPostProcess<SRC_TYPE, DST_TYPE, true>( \
src_scale, src_zero_point, dst_scale, dst_zero_point), \
src_zero_point); \
} else { \
int nr_threads = query_blocksize_for_kernel( \
kern_nchw_nchw4<pack_w, false, SRC_C_TYPE, DST_C_TYPE, \
SRC_TYPE, DST_TYPE>); \
const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), n); \
const dim3 thread_dim(nr_threads); \
kern_nchw_nchw4<pack_w, false><<<block_dim, thread_dim, 0, stream>>>( \
(SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, c, hw, \
n_stride_src, ic_stride, n_stride_dst, \
CudaPostProcess<SRC_TYPE, DST_TYPE, false>( \
src_scale, src_zero_point, dst_scale, dst_zero_point), \
src_zero_point); \
}
if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Uint8 &&
dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) {
RUN_KERNEL(same_scale, dtype::Uint8, dtype::QuantizedS8, char, char);
} else if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Quantized8Asymm &&
dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) {
RUN_KERNEL(same_scale, dtype::Quantized8Asymm, dtype::QuantizedS8, char,
char);
} else if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8 &&
dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) {
RUN_KERNEL(same_scale, dtype::QuantizedS8, dtype::QuantizedS8, char,
char);
} else {
megdnn_assert(0, "not support dtype %s %s", src_layout.dtype.name(),
dst_layout.dtype.name());
}
}
bool relayout_format::relayout_format_cuda_usable(
const TensorLayout& src_layout, const TensorLayout& dst_layout) {
bool is_all_continue =
src_layout.is_contiguous() && dst_layout.is_contiguous();
bool is_all_int8 =
(src_layout.dtype.enumv().ev == DTypeEnum::Ev::Uint8 &&
dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) ||
(src_layout.dtype.enumv().ev == DTypeEnum::Ev::Quantized8Asymm &&
dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) ||
(src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8 &&
dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8);
return is_all_continue && is_all_int8;
}
template void relayout_format::relayout_format_cuda_exec<1>(
const TensorND& src, const TensorND& dst, const cudaStream_t& stream,
const float src_scale, const float dst_scale,
const uint8_t src_zero_point, const uint8_t dst_zero_point);
template void relayout_format::relayout_format_cuda_exec<4>(
const TensorND& src, const TensorND& dst, const cudaStream_t& stream,
const float src_scale, const float dst_scale,
const uint8_t src_zero_point, const uint8_t dst_zero_point);
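As a concrete check of the launch geometry above (assumed numbers, not taken from the commit):

    // Worked example: NCHW (8, 3, 224, 224), Uint8 -> QuantizedS8, pack_w = 4.
    // hw = 224 * 224 = 50176; if query_blocksize_for_kernel returns 256 threads,
    //   block_dim = (DIVUP(50176, 256 * 4), 8) = (49, 8),  thread_dim = 256,
    // and each thread converts 4 consecutive pixels; since ic = 3 < pack_c = 4,
    // the fourth output channel is filled from the zero point via read_with_pad.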
/**
* \file dnn/src/cuda/relayout_format/relayout_format.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/basic_types.h"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace relayout_format {
template <int pack_w = 1>
void relayout_format_cuda_exec(const TensorND& src, const TensorND& dst,
const cudaStream_t& stream,
const float src_scale = 1.f,
const float dst_scale = 1.f,
const uint8_t src_zero_point = 0,
const uint8_t dst_zero_point = 0);
bool relayout_format_cuda_usable(const TensorLayout& src_layout,
const TensorLayout& dst_layout);
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/cuda/relayout_format/relayout_format.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/basic_types.h"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace relayout_format {
struct RelayoutFormatFast {
static bool usable(const TensorLayout& src_layout,
const TensorLayout& dst_layout);
static void exec(const TensorND& src, const TensorND& dst,
cudaStream_t stream);
};
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -6,11 +6,12 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/relayout_format/opr_impl.h"
#include "src/naive/handle.h"
#include "src/naive/relayout_format/opr_impl.h"
#include "megdnn/tensor_iter.h"
......@@ -44,7 +45,7 @@ void padding_src_to_workspace(dtype* dptr, const dtype* sptr, size_t N,
template <typename dtype>
void padding_to_workspace(dtype* dptr, const dtype* sptr,
const TensorLayout& src_layout, const size_t pad_axis,
const size_t align_size) {
const size_t align_size, const int pad_val = 0) {
megdnn_assert(pad_axis < src_layout.ndim);
const size_t axis_dim = src_layout[pad_axis];
const size_t axis_dim_padded = round_up(axis_dim, align_size);
......@@ -64,14 +65,16 @@ void padding_to_workspace(dtype* dptr, const dtype* sptr,
sptr[src_inner_offset + inner_idx_offset];
} else {
dptr[dst_outer_offset + inner_idx_offset] =
static_cast<dtype>(0);
static_cast<dtype>(pad_val);
}
}
}
}
}
void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
const size_t pad_axis, const size_t align_size) {
const size_t pad_axis, const size_t align_size,
DType exec_dst_dtype) {
switch (src.layout.dtype.enumv()) {
#define cb(name, ctype) \
case (DTypeEnum::name): { \
......@@ -84,8 +87,27 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
cb(Float32, dt_float32);
cb(QuantizedS8, dt_qint8);
case (DTypeEnum::Quantized8Asymm): {
dt_quint8* sptr = src.compatible_ptr<dt_quint8>();
dt_quint8* dptr = dst.compatible_ptr<dt_quint8>();
padding_to_workspace<dt_quint8>(
dptr, sptr, src.layout, pad_axis, align_size,
src.layout.dtype.param<dtype::Quantized8Asymm>()
.zero_point);
break;
}
case (DTypeEnum::Uint8): {
uint8_t* sptr = src.compatible_ptr<uint8_t>();
uint8_t* dptr = dst.compatible_ptr<uint8_t>();
uint8_t zero_point =
exec_dst_dtype.enumv() == DTypeEnum::QuantizedS8 ? 128 : 0;
padding_to_workspace<uint8_t>(dptr, sptr, src.layout, pad_axis,
align_size, zero_point);
break;
}
default:
megdnn_assert(0);
megdnn_assert(0, "not support dtype %s", src.layout.dtype.name());
#undef cb
}
}
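// Why the Uint8 case above pads with 128 when the exec dst dtype is
// QuantizedS8: the subsequent u8 -> q8 copy (do_copy_diff_u8_q8 below)
// computes quantize((float)u8 - 128.f), so channels padded with 128 map to
// exactly 0 in the QuantizedS8 output instead of a spurious -128.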
......@@ -108,6 +130,57 @@ void padding_filter_to_workspace(dtype* dptr, const dtype* sptr, size_t OC,
}
}
}
void do_copy_diff_qu8_q8(const TensorND& dst, const TensorND& src) {
auto isrc =
tensor_iter_valonly<DTypeTrait<dtype::Quantized8Asymm>::ctype>(src)
.begin();
auto idst = tensor_iter_valonly<DTypeTrait<dtype::QuantizedS8>::ctype>(dst)
.begin();
auto src_dt_parm = src.layout.dtype.param<dtype::Quantized8Asymm>();
auto dst_dt_parm = dst.layout.dtype.param<dtype::QuantizedS8>();
for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) {
*idst = dst_dt_parm.quantize(src_dt_parm.dequantize(*isrc));
++idst;
++isrc;
}
}
void do_copy_diff_q8_q8(const TensorND& dst, const TensorND& src) {
auto isrc = tensor_iter_valonly<DTypeTrait<dtype::QuantizedS8>::ctype>(src)
.begin();
auto idst = tensor_iter_valonly<DTypeTrait<dtype::QuantizedS8>::ctype>(dst)
.begin();
auto src_dt_parm = src.layout.dtype.param<dtype::QuantizedS8>();
auto dst_dt_parm = dst.layout.dtype.param<dtype::QuantizedS8>();
for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) {
*idst = dst_dt_parm.quantize(src_dt_parm.dequantize(*isrc));
++idst;
++isrc;
}
}
void do_copy_diff_u8_q8(const TensorND& dst, const TensorND& src) {
auto isrc =
tensor_iter_valonly<DTypeTrait<dtype::Uint8>::ctype>(src).begin();
auto idst = tensor_iter_valonly<DTypeTrait<dtype::QuantizedS8>::ctype>(dst)
.begin();
auto dst_dt_parm = dst.layout.dtype.param<dtype::QuantizedS8>();
for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) {
*idst = dst_dt_parm.quantize((float)(*isrc) - 128.f);
++idst;
++isrc;
}
}
void check_layout_and_canonize(TensorLayout& src, TensorLayout& dst) {
megdnn_assert(dst.is_non_overlapping_strong());
src = src.collapse_contiguous();
dst = dst.collapse_contiguous();
megdnn_assert(dst.dtype.valid() &&
src.total_nr_elems() == dst.total_nr_elems());
}
} // anonymous namespace
size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
......@@ -189,6 +262,13 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
size_t w = src[3];
return n * c * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW4: {
size_t n = src[0];
size_t c = round_up(src[1], 4_z);
size_t h = src[2];
size_t w = src[3];
return n * c * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: {
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
if (src[1] % 4 == 0)
......@@ -208,6 +288,8 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
megdnn_assert(src.layout.dtype.category() == DTypeCategory::FLOAT ||
(src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) ||
src.layout.dtype.category() == DTypeCategory::QUANTIZED);
check_exec(src.layout, dst.layout, workspace.size);
HandleImpl* m_handle = static_cast<HandleImpl*>(handle());
......@@ -284,7 +366,7 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
size_t val = src.layout[_idx]; \
if (val % _pack_size != 0) { \
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \
_pack_size); \
_pack_size, exec_dst.dtype); \
exec_src_nd.raw_ptr = workspace.raw_ptr; \
} \
} \
......@@ -301,11 +383,43 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
cb(2, 8, NCHW_NCHW88_CONV_GROUP_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) {
cb(1, 4, NCHW_NCHW4_IC_SMALL);
} else if (param().mode == Param::Mode::NCHW_NCHW4) {
cb(1, 4, NCHW_NCHW4);
} else if (param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
}
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
check_layout_and_canonize(src0.layout, src0.layout);
auto func = [](const TensorND& dst, const TensorND& src) {
do_copy_diff_qu8_q8(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
check_layout_and_canonize(src0.layout, src0.layout);
auto func = [](const TensorND& dst, const TensorND& src) {
do_copy_diff_u8_q8(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
check_layout_and_canonize(src0.layout, src0.layout);
auto func = [](const TensorND& dst, const TensorND& src) {
do_copy_diff_q8_q8(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else {
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
}
#undef cb
}
......
......@@ -6,10 +6,12 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/cuda/fixture.h"
......@@ -24,6 +26,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT) {
param.mode = param::RelayoutFormat::Mode::NCHW4_CHWN4;
checker.set_dtype(0, dtype::QuantizedS8{0.1f})
.set_dtype(1, dtype::QuantizedS8{0.1f})
.set_rng(0, &rng)
.set_param(param)
.execs({{22, 23, 24, 25, 4}, {}});
......@@ -31,6 +34,164 @@ TEST_F(CUDA, RELAYOUT_FORMAT) {
checker.execs({{22, 23, 24, 25, 4}, {}});
}
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4) {
Checker<RelayoutFormat> checker(handle_cuda());
UniformIntRNG rng{0, 50};
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4;
for (size_t n : {1, 3}) {
for (size_t c : {1, 2, 3, 4, 8, 9, 11, 16}) {
for (size_t h : {3, 7, 12, 16, 22, 59, 83}) {
for (size_t w : {3, 22, 63, 128, 256}) {
checker.set_dtype(0, dtype::QuantizedS8{1.f})
.set_dtype(1, dtype::QuantizedS8{1.f})
.set_rng(0, &rng)
.set_param(param)
.execs({{n, c, h, w}, {}});
checker.set_dtype(0, dtype::QuantizedS8{1.f})
.set_dtype(1, dtype::QuantizedS8{2.f})
.set_rng(0, &rng)
.set_param(param)
.execs({{n, c, h, w}, {}});
}
}
}
}
checker.set_dtype(0, dtype::QuantizedS8{1.f})
.set_dtype(1, dtype::QuantizedS8{1.f})
.set_rng(0, &rng)
.set_param(param)
.execs({{8, 3, 224, 224}, {}});
checker.set_dtype(0, dtype::QuantizedS8{1.f})
.set_dtype(1, dtype::QuantizedS8{1.f})
.set_rng(0, &rng)
.set_param(param)
.execs({{8, 3, 600, 600}, {}});
checker.set_dtype(0, dtype::QuantizedS8{1.f})
.set_dtype(1, dtype::QuantizedS8{1.f})
.set_rng(0, &rng)
.set_param(param)
.execs({{1, 6, 768, 1280}, {}});
}
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_DEFAULT) {
Checker<RelayoutFormat> checker(handle_cuda());
UniformIntRNG rng{0, 50};
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4;
for (size_t n : {1, 3}) {
for (size_t c : {1, 2, 3, 4, 8, 9, 11, 16}) {
for (size_t h : {3, 7, 12, 16, 59, 83}) {
for (size_t w : {3, 63, 128, 256}) {
checker.set_dtype(0, dtype::Quantized8Asymm{1.f, 128})
.set_dtype(1, dtype::QuantizedS8{1.f})
.set_rng(0, &rng)
.set_param(param)
.execs({{n, c, h, w}, {}});
}
}
}
}
}
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_U8) {
Checker<RelayoutFormat> checker(handle_cuda());
UniformIntRNG rng{0, 255};
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4;
for (size_t n : {1, 3}) {
for (size_t c : {1, 2, 3, 4, 8, 9, 11, 16}) {
for (size_t h : {3, 7, 12, 16, 59, 83}) {
for (size_t w : {3, 13, 3 * 4, 63 * 4, 128 * 4, 256 * 4}) {
checker.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::QuantizedS8{1.f})
.set_rng(0, &rng)
.set_param(param)
.execs({{n, c, h, w}, {}});
checker.set_dtype(0, dtype::Quantized8Asymm{1.f, 128})
.set_dtype(1, dtype::QuantizedS8{1.f})
.set_rng(0, &rng)
.set_param(param)
.execs({{n, c, h, w}, {}});
checker.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::QuantizedS8{2.5f})
.set_rng(0, &rng)
.set_param(param)
.execs({{n, c, h, w}, {}});
}
}
}
}
}
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_IC_SMALL) {
Checker<RelayoutFormat> checker(handle_cuda());
UniformIntRNG rng{0, 50};
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL;
checker.set_dtype(0, dtype::QuantizedS8{1.f})
.set_dtype(1, dtype::QuantizedS8{1.f})
.set_rng(0, &rng)
.set_param(param)
.execs({{8, 3, 768, 1280}, {}});
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT) {
using Param = RelayoutFormat::Param;
auto run = [&](const TensorShapeArray& shapes, Param param,
Param default_param) {
Benchmarker<RelayoutFormat> benchmarker(handle_cuda());
benchmarker.set_param(param);
benchmarker.set_dtype(0, dtype::QuantizedS8{1.f})
.set_dtype(1, dtype::QuantizedS8{1.f});
Benchmarker<RelayoutFormat> benchmarker_default(handle_cuda());
benchmarker_default.set_param(default_param);
benchmarker_default.set_dtype(0, dtype::QuantizedS8{1.f})
.set_dtype(1, dtype::QuantizedS8{1.f});
for (auto&& shape : shapes) {
double memaccess = (double(shape.total_nr_elems()) +
double(shape[0]) * ((shape[1] + 3) / 4 * 4) *
shape[2] * shape[3]) *
1e-6;
auto time_ms = benchmarker.execs({shape, {}});
if (shape[1] <= 4) {
auto time_default_ms = benchmarker_default.execs({shape, {}});
printf("execute %s, time %.4f ms, %.4f GB/s, default %.4f "
"GB/s\n",
shape.to_string().c_str(), time_ms, memaccess / time_ms,
memaccess / time_default_ms);
} else {
printf("execute %s, time %.4f ms, %.4f GB/s\n",
shape.to_string().c_str(), time_ms, memaccess / time_ms);
}
}
};
TensorShapeArray shapes = {
{8, 1, 768, 1280}, {8, 3, 768, 1280}, {8, 3, 224, 224},
{8, 4, 768, 1280}, {64, 3, 768, 1280},
};
{
Param param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4;
Param default_param;
default_param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL;
run(shapes, param, default_param);
}
}
#endif
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) {
Checker<RelayoutFormat> checker(handle_cuda());
UniformIntRNG rng{-50, 50};
......@@ -39,7 +200,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) {
for (DType dtype :
std::vector<DType>({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) {
checker.set_dtype(0, dtype).set_rng(0, &rng);
checker.set_dtype(0, dtype).set_dtype(1, dtype).set_rng(0, &rng);
checker.set_param(param).execs({{2, 4, 35, 36}, {}});
checker.set_param(param).execs({{2, 3, 35, 36}, {}});
......
......@@ -219,7 +219,10 @@ R"__usage__(
Execute operators with weight preprocess, which can optimize the operator execution time with
algos such as winograd, im2col, etc., but it may consume more memory.
)__usage__"
R"__usage__(
--enable-fuse-preprocess
Fuse astype/pad_channel/dimshuffle and other preprocess oprs after the h2d opr
)__usage__"
;
struct DataParser {
......@@ -1141,6 +1144,11 @@ Args Args::from_argv(int argc, char **argv) {
graph_opt.graph_opt.enable_nchw44_dot();
continue;
}
if (!strcmp(argv[i], "--enable-fuse-preprocess")) {
mgb_log_warn("enable-fuse-preprocess optimization");
graph_opt.graph_opt.enable_fuse_preprocess();
continue;
}
if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) {
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization");
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();
......
......@@ -101,6 +101,8 @@ struct GraphCommonOptimizeOptions {
//! memory, default disable now, when weight preprocess is enabled, the
//! input shape should no change
bool weight_preprocess = false;
//! fuse preprocess pattern, like astype + pad_channel + dimshuffle
bool fuse_preprocess = false;
enum LayoutTransform : uint32_t {
DEFAULT,
NCHW4, ///< compute using NCHW4 tensor format
......@@ -130,6 +132,7 @@ struct GraphCommonOptimizeOptions {
SET(f16_io_comp);
SET(fuse_conv_bias_nonlinearity);
SET(fuse_conv_bias_with_z);
SET(fuse_preprocess);
SET(weight_winograd_transform);
SET(weight_preprocess);
#undef SET
......
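User code toggles the new flag like any other option above (this mirrors the PreProcessCase0/1 tests below and the --enable-fuse-preprocess flag added to load_and_run):

    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_fuse_preprocess();
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);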
......@@ -724,6 +724,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
options.disable_##_option(); \
} \
}
cb(fuse_preprocess, {add_pass(FuseNCHW4Int8Preprocess::make());});
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
......@@ -761,6 +763,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
add_pass(EnableTensorCorePass::make_tensorcore_converter());
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
add_pass(FuseNCHW4Int8Preprocess::make());
});
cb(chwn4, {
add_pass<FuseConvBiasNonlinPass>();
......
/**
* \file src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/cond.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"
using namespace mgb;
using namespace gopt;
namespace {
#define RETURN_IF_FALSE(ok) \
{ \
if (!ok) \
return ok; \
}
struct SubGraphMatcher {
struct Node {
using CallBack = std::function<bool(OperatorNodeBase* opr)>;
Node(Typeinfo* in_op_type) : op_type(in_op_type){};
Node(Typeinfo* in_op_type, CallBack func)
: op_type(in_op_type), cbk(func){};
Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node)
: op_type(in_op_type), pre_node(in_pre_node){};
Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node, CallBack func)
: op_type(in_op_type), pre_node(in_pre_node), cbk(func){};
Typeinfo* op_type{nullptr};
std::vector<Node> pre_node;
//! cbk used to check param and gather args for creating fusion op
CallBack cbk;
};
bool match(Node& root, OperatorNodeBase* opr) {
if (opr == nullptr) {
return false;
}
//! match nullptr node always
if (root.op_type == nullptr || root.op_type == opr->dyn_typeinfo()) {
bool match_ok = true;
if (root.cbk)
match_ok &= root.cbk(opr);
RETURN_IF_FALSE(match_ok);
auto& inp = opr->input();
for (size_t node_idx = 0; node_idx < root.pre_node.size();
++node_idx) {
bool valid_node_idx = node_idx < inp.size();
RETURN_IF_FALSE(valid_node_idx);
match_ok &= match(root.pre_node[node_idx],
inp[node_idx]->owner_opr());
RETURN_IF_FALSE(match_ok);
}
return match_ok;
} else {
return false;
}
}
};
#undef RETURN_IF_FALSE
struct SubGraphChecker {
using DepType = cg::OperatorNodeProp::DepType;
using ReaderType =
ThinHashMap<OperatorNodeBase*,
SmallVector<std::pair<OperatorNodeBase*, DepType>>>;
SubGraphChecker() {}
bool check(ThinHashSet<OperatorNodeBase*> used_input,
OperatorNodeBase* start_opr, OperatorNodeBase* stop_opr,
ReaderType& readers, bool ignore_immutable = true) {
bool is_all_inp_used = check_all_inp_used(used_input, start_opr,
stop_opr, ignore_immutable);
bool is_all_dep_inside =
check_all_dep_inside_node(start_opr, stop_opr, readers);
return is_all_inp_used && is_all_dep_inside;
}
bool check_all_inp_used(ThinHashSet<OperatorNodeBase*>& used_input,
OperatorNodeBase* start_opr,
OperatorNodeBase* stop_opr,
bool ignore_immutable = true) {
ThinHashSet<OperatorNodeBase*> leaf_set;
get_leaf_node(start_opr, stop_opr, leaf_set);
for (auto in_opr : leaf_set) {
bool skip = in_opr->same_type<opr::ImmutableTensor>() &&
ignore_immutable;
if (used_input.find(in_opr) == used_input.end() && !skip) {
return false;
}
}
return true;
}
bool check_all_dep_inside_node(OperatorNodeBase* start_opr,
OperatorNodeBase* stop_opr,
ReaderType& readers) {
ThinHashSet<OperatorNodeBase*> mid_set;
get_mid_node(start_opr, start_opr, stop_opr, mid_set);
for (auto inner_opr : mid_set) {
if (readers.find(inner_opr) != readers.end()) {
for (auto& out_node : readers[inner_opr]) {
if (mid_set.find(out_node.first) == mid_set.end() &&
out_node.first != start_opr &&
out_node.second ==
cg::OperatorNodeProp::DepType::DEV_VALUE) {
return false;
}
}
}
}
return true;
}
void get_mid_node(OperatorNodeBase* opr, OperatorNodeBase* start_opr,
OperatorNodeBase* stop_opr,
ThinHashSet<OperatorNodeBase*>& mid_set) {
if (opr == nullptr) {
return;
}
if (opr != start_opr) {
mid_set.insert(opr);
}
if (opr == stop_opr) {
return;
}
for (auto& tensor : opr->input()) {
auto pre_opr = tensor->owner_opr();
get_mid_node(pre_opr, start_opr, stop_opr, mid_set);
}
}
void get_leaf_node(OperatorNodeBase* opr, OperatorNodeBase* stop_opr,
ThinHashSet<OperatorNodeBase*>& leaf_set) {
if (opr == nullptr) {
return;
}
if (opr == stop_opr || opr->input().size() == 0) {
leaf_set.insert(opr);
}
if (opr == stop_opr) {
return;
}
for (auto& tensor : opr->input()) {
auto pre_opr = tensor->owner_opr();
get_leaf_node(pre_opr, stop_opr, leaf_set);
}
}
};
static inline bool is_shape_nchw(const TensorShape& shape) {
return shape.ndim == 4;
}
static inline bool is_shape_before_nchw4(const TensorShape& shape) {
return shape.ndim == 5 && shape[2] == 4;
}
static inline bool is_nchw_nchw4_shuffle_vec(
const opr::Dimshuffle::Param param) {
return param.ndim == 5 && param.pattern[0] == 0 && param.pattern[1] == 1 &&
param.pattern[2] == 3 && param.pattern[3] == 4 &&
param.pattern[4] == 2;
}
template <typename T>
static inline bool is_immutable_equal(OperatorNodeBase* opr, T val,
DTypeEnum dtype_enum) {
auto const_opr = opr->try_cast_final<opr::ImmutableTensor>();
if (!const_opr) {
return false;
}
auto& host_value = const_opr->host_value();
bool ok_value = host_value.layout().total_nr_elems() == 1 &&
host_value.dtype().enumv() == dtype_enum &&
host_value.ptr<T>()[0] == val;
return ok_value;
}
template <typename T>
static inline bool is_immutable_all_equal(OperatorNodeBase* opr,
typename DTypeTrait<T>::ctype val) {
auto const_opr = opr->try_cast_final<opr::ImmutableTensor>();
if (!const_opr) {
return false;
}
auto& host_value = const_opr->host_value();
bool ok_value = host_value.dtype().enumv() == DTypeTrait<T>::enumv;
if (!ok_value) {
return false;
}
size_t nr_elem = host_value.layout().total_nr_elems();
for (size_t i = 0; i < nr_elem; ++i) {
if (host_value.ptr<typename DTypeTrait<T>::ctype>()[i] != val) {
ok_value = false;
break;
}
}
return ok_value;
}
} // namespace
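// Schematic of the subgraph fused by replace_shuffle_opr below (this is the
// pattern built in the PreProcessCase0 test of this commit):
//   x : Quantized8Asymm, NCHW
//     -> TypeCvt to QuantizedS8
//     -> Concat(axis = 1) with a broadcast/immutable zero    (pad C up to 4k)
//     -> Reshape to (N, C', 4, H, W)
//     -> Dimshuffle {0, 1, 3, 4, 2}                          (NCHW4 layout)
// The whole chain is replaced by a single RelayoutFormat opr with mode
// NCHW_NCHW4; replace_astype_opr matches the analogous Uint8 chain (TypeCvt to
// Float32, subtract 128, pad, reshape, dimshuffle, TypeCvt to QuantizedS8) and
// rewrites it the same way.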
const char* FuseNCHW4Int8Preprocess::name() const {
return "fuse_pre_process_pass";
}
std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
using SGM = SubGraphMatcher;
auto gen_pad_dimshuffle_graph = [&](SGM::Node& in_node,
SGM::Node::CallBack& pad_cbk,
SGM::Node::CallBack& shape_cbk) {
SGM::Node::CallBack check_pad = [&](OperatorNodeBase* opr) {
SGM sub_matcher;
SGM::Node immu_node{opr::ImmutableTensor::typeinfo(), pad_cbk};
if (opr->same_type<opr::ImmutableTensor>()) {
return sub_matcher.match(immu_node, opr);
} else if (opr->same_type<opr::Broadcast>()) {
return sub_matcher.match(immu_node,
opr->input()[0]->owner_opr());
} else {
return false;
}
};
SGM::Node broadcast_or_immutable{nullptr, check_pad};
SGM::Node broadcast_concat{
opr::Concat::typeinfo(),
{in_node, broadcast_or_immutable},
[](OperatorNodeBase* opr) {
auto concat_pad = opr->try_cast_final<opr::Concat>();
return concat_pad->axis() == 1;
}};
SGM::Node nchwx_reshape{opr::Reshape::typeinfo(),
{broadcast_concat, SGM::Node(nullptr)},
[](OperatorNodeBase* opr) {
auto inp0 = opr->input()[0];
return is_shape_nchw(inp0->shape());
}};
SGM::Node shuffle_root{
opr::Dimshuffle::typeinfo(),
{nchwx_reshape},
[](OperatorNodeBase* opr) {
auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>();
auto& input_vec = shuffle_opr.input();
return is_shape_before_nchw4(input_vec[0]->shape()) &&
is_nchw_nchw4_shuffle_vec(shuffle_opr.param());
}};
return shuffle_root;
};
auto replace_shuffle_opr = [&](OperatorNodeBase* opr,
const VarNodeArray& new_inp,
SubGraph::Rewriter& rewriter,
ReaderType& reader) {
SGM matcher;
OperatorNodeBase* src_node = nullptr;
SGM::Node input_data_cp{
nullptr, [&](OperatorNodeBase* opr) {
auto src_dtype = opr->output()[0]->dtype();
if (src_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
src_node = opr;
return true;
} else {
return false;
}
}};
SGM::Node type_cvt{opr::TypeCvt::typeinfo(), {input_data_cp}};
SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
bool is_i32_pad = is_immutable_all_equal<dtype::Int32>(opr, 0);
bool is_q8_pad = is_immutable_all_equal<dtype::QuantizedS8>(
opr, dt_qint8(0));
return is_fp32_pad || is_i32_pad || is_q8_pad;
};
SGM::Node::CallBack const_reshape_cbk = [](OperatorNodeBase* opr) {
return true;
};
auto&& shuffle_root = gen_pad_dimshuffle_graph(type_cvt, const_pad_cbk,
const_reshape_cbk);
bool match = matcher.match(shuffle_root, opr);
bool check_ok = false;
if (match) {
check_ok =
SubGraphChecker().check({src_node}, opr, src_node, reader);
}
if (match && check_ok) {
opr::RelayoutFormat::Param param;
param.mode = opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
OperatorNodeConfig config(opr->output()[0]->dtype());
auto out_node = opr::RelayoutFormat::make(
rewriter.get_var(src_node->output()[0]), param.mode,
config);
return out_node.node()->owner_opr();
} else {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
};
auto replace_astype_opr = [&](OperatorNodeBase* opr,
const VarNodeArray& new_inp,
SubGraph::Rewriter& rewriter,
ReaderType& reader) {
SGM matcher;
OperatorNodeBase* src_node = nullptr;
OperatorNodeBase* neg_128_immu_node = nullptr;
OperatorNodeBase* pad0_immu_node = nullptr;
OperatorNodeBase* const_reshape_last_dim_node = nullptr;
SGM::Node input_data_cp{nullptr, [&](OperatorNodeBase* opr) {
auto src_dtype = opr->output()[0]->dtype();
if (src_dtype.enumv() == DTypeEnum::Uint8) {
src_node = opr;
return true;
} else {
return false;
}
}};
SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(),
{input_data_cp},
[](OperatorNodeBase* opr) {
auto cvt_op =
opr->try_cast_final<opr::TypeCvt>();
bool is_fp32 = cvt_op->param().enumv() ==
DTypeEnum::Float32;
return is_fp32;
}};
SGM::Node sub_128{
opr::Elemwise::typeinfo(),
{cvt_fp32},
[&](OperatorNodeBase* opr) {
auto elem_op = opr->try_cast_final<opr::Elemwise>();
bool is_add_op = elem_op->param().mode ==
opr::Elemwise::Param::Mode::ADD;
auto neg_128_op = elem_op->input()[1]->owner_opr();
bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f,
DTypeEnum::Float32);
neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
return is_add_op && is_neg_128;
}};
SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
pad0_immu_node = opr;
bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
bool is_i32_pad = is_immutable_all_equal<dtype::Int32>(opr, 0);
return is_fp32_pad || is_i32_pad;
};
SGM::Node::CallBack const_reshape_cbk = [&](OperatorNodeBase* opr) {
const_reshape_last_dim_node = opr;
return true;
};
auto&& shuffle_root = gen_pad_dimshuffle_graph(sub_128, const_pad_cbk,
const_reshape_cbk);
SGM::Node astype_root{opr::TypeCvt::typeinfo(), {shuffle_root}};
bool match = matcher.match(astype_root, opr);
bool check_ok = false;
if (match) {
check_ok = SubGraphChecker().check(
{src_node, neg_128_immu_node, pad0_immu_node,
const_reshape_last_dim_node},
opr, src_node, reader);
}
if (match && check_ok) {
opr::RelayoutFormat::Param param;
param.mode = opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
OperatorNodeConfig config(opr->output()[0]->dtype());
auto out_node = opr::RelayoutFormat::make(
rewriter.get_var(src_node->output()[0]), param.mode,
config);
return out_node.node()->owner_opr();
} else {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
};
auto ret = std::make_unique<FuseNCHW4Int8Preprocess>();
auto&& replace_func = ret->m_opr_replace_func;
MGB_MARK_USED_VAR(replace_astype_opr);
MGB_MARK_USED_VAR(replace_shuffle_opr);
replace_func[opr::Dimshuffle::typeinfo()] = replace_shuffle_opr;
replace_func[opr::TypeCvt::typeinfo()] = replace_astype_opr;
return ret;
}
void FuseNCHW4Int8Preprocess::apply(OptState& state) const {
state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_DTYPE |
VarReplaceCheckFlag::CHECK_SHAPE);
auto rewriter = state.graph().make_rewriter();
VarNodeArray new_inp_cache;
ReaderType readers;
state.graph().iter([&readers](OperatorNodeBase* opr) {
for (auto&& i : opr->node_prop().dep_map()) {
readers[i.first->owner_opr()].emplace_back(opr, i.second);
}
});
auto on_opr = [this, &rewriter, &new_inp_cache,
&readers](OperatorNodeBase* opr) {
auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
if (it != m_opr_replace_func.end()) {
auto&& new_inp = new_inp_cache;
new_inp.clear();
new_inp.reserve(opr->input().size());
for (auto i : opr->input()) {
new_inp.push_back(rewriter.get_var(i));
}
auto new_opr = (it->second)(opr, new_inp, rewriter, readers);
if (new_opr->try_cast_final<opr::RelayoutFormat>()) {
auto &&origin_out = opr->output(),
&&cur_out = new_opr->output();
rewriter.replace_var(origin_out[0], cur_out[0], nullptr);
} else {
auto &&origin_out = opr->output(),
&&cur_out = new_opr->output();
mgb_assert(origin_out.size() == cur_out.size(),
"bad opr replace: src=%s{%s} dst=%s{%s}, %zu != %zu",
opr->cname(), opr->dyn_typeinfo()->name,
new_opr->cname(), new_opr->dyn_typeinfo()->name,
origin_out.size(), cur_out.size());
for (size_t i = 0; i < origin_out.size(); i++) {
rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
}
}
} else {
rewriter.auto_replace_outputs(opr);
}
};
state.graph().iter(on_opr);
rewriter.apply_inplace();
}
\ No newline at end of file
......@@ -152,6 +152,26 @@ namespace gopt {
void apply(OptState& opt) const override;
};
/*!
* \brief fuse preprocess, like pad channel, quint8 to qint8
*/
class FuseNCHW4Int8Preprocess : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
static std::unique_ptr<FuseNCHW4Int8Preprocess> make();
using DepType = cg::OperatorNodeProp::DepType;
using ReaderType =
ThinHashMap<OperatorNodeBase*,
SmallVector<std::pair<OperatorNodeBase*, DepType>>>;
private:
ThinHashMap<Typeinfo*, thin_function<OperatorNodeBase*(
OperatorNodeBase*, const VarNodeArray&,
SubGraph::Rewriter&, ReaderType&)>>
m_opr_replace_func;
};
/*!
* \brief fuse deconv and typecvt to a deconv opr
*/
......
......@@ -719,15 +719,15 @@ TEST(TestGoptInference, Float16IOFloat32ComputeDeConv) {
};
graph->options().graph_opt_level = 0;
auto s0 = mkvar("s0", {5, 5, 3, 3}),
s1 = mkvar("s1", {1, 5, INP_H, INP_W});
auto s0 = mkvar("s0", {5, 5, 3, 3}), s1 = mkvar("s1", {1, 5, INP_H, INP_W});
auto y = opr::ConvolutionBackwardData::make(s0, s1, {}, {});
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_f16_io_f32_comp();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(find_opr<opr::ConvolutionBackwardData>(y_opt).param().compute_mode,
opr::ConvBias::Param::ConvBias::ComputeMode::FLOAT32);
ASSERT_EQ(
find_opr<opr::ConvolutionBackwardData>(y_opt).param().compute_mode,
opr::ConvBias::Param::ConvBias::ComputeMode::FLOAT32);
ASSERT_EQ(y_opt.dtype(), dtype::Float32());
HostTensorND host_y, host_y_opt;
......@@ -1603,7 +1603,6 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass_FullBias) {
}
}
TEST(TestGoptInference, ParamMerge) {
auto cns = load_multiple_xpus(2);
HostTensorGenerator<> gen;
......@@ -3364,14 +3363,14 @@ TEST(TestGoptInference, ConvertFormatNCHW44MultiInput) {
auto b = mkvar("b", {1, 1, 16, 16}),
elem0 = opr::Elemwise::make({conv1 + b + b},
opr::Elemwise::Param::Mode::RELU);
opr::Elemwise::Param::Mode::RELU);
auto w2 = mkcvar("w2", {8, 8, 3, 3}),
conv2 = opr::Convolution::make(elem0, w2, param_conv);
auto b1 = mkvar("b1", {1}),
y = opr::Elemwise::make({conv2 + b1 + b},
opr::Elemwise::Param::Mode::RELU);
opr::Elemwise::Param::Mode::RELU);
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
......@@ -3631,4 +3630,97 @@ TEST(TestGoptInference, ConvertFormatCD4GroupOneConv) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
#if MGB_CUDA
TEST(TestGoptInference, PreProcessCase0) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Quantized8Asymm, RandomDistribution::UNIFORM>
gen(dt_quint8(0), dt_quint8(50), 1, 128, 1234);
auto cn = CompNode::load("gpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
size_t n = 1;
size_t c = 3;
size_t h = 16;
size_t w = 16;
auto host_x1 = gen({n, c, h, w}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
auto x_q8 = opr::TypeCvt::make(x, dtype::QuantizedS8(1.f), cn);
auto zero = DTypeScalar(dtype::QuantizedS8(1.f));
auto zero_tensor = opr::ImmutableTensor::make(*graph, zero, cn);
auto pad_channel_tensor =
opr::Broadcast::make(zero_tensor, {n, 1, h, w}, cn);
auto paded_x = opr::Concat::make({x_q8, pad_channel_tensor}, 1, cn)
.reshape({n, 1, 4, h, w});
auto result = opr::Dimshuffle::make(paded_x, {0, 1, 3, 4, 2}, 5, cn);
auto y = result;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_preprocess();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.PreProcessCase0.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::RelayoutFormat>());
}
TEST(TestGoptInference, PreProcessCase1) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
auto cn = CompNode::load("gpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
size_t n = 1;
size_t c = 3;
size_t h = 16;
size_t w = 16;
auto host_x1 = gen({n, c, h, w}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
auto x_u8 = opr::TypeCvt::make(x, dtype::Float32(), cn);
auto x_s8 = x_u8 - 128;
auto zero = DTypeScalar(dtype::Float32());
auto zero_tensor = opr::ImmutableTensor::make(*graph, zero, cn);
auto pad_channel_tensor =
opr::Broadcast::make(zero_tensor, {n, 1, h, w}, cn);
auto paded_x = opr::Concat::make({x_s8, pad_channel_tensor}, 1, cn)
.reshape({n, 1, 4, h, w});
auto nchw4_out = opr::Dimshuffle::make(paded_x, {0, 1, 3, 4, 2}, 5, cn);
auto result = opr::TypeCvt::make(nchw4_out, dtype::QuantizedS8(1.f));
auto y = result;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_preprocess();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.PreProcessCase1.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::RelayoutFormat>());
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -198,7 +198,7 @@ Elemwise::Elemwise(
param.mode == Param::Mode::MAX ||
param.mode == Param::Mode::MIN,
"Only ADD, SUB, NEGATE, RELU, MAX and MIN is guaranteed "
"to be supported on Elemwise for quantized DType");
"to be supported on Elemwise for quantized DType, no support %d", (int)param.mode);
}
}
......
......@@ -1578,6 +1578,23 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) {
// f}}}
/* f{{{ ======================= RelayoutFormat ======================= */
namespace mgb {
namespace opr {
namespace intl {
template <>
struct MegDNNOprInitPostCtor<RelayoutFormat> {
static void apply(cg::OperatorNodeBase& opr) {
if (opr.config().output_dtype().valid()) {
opr.output(0)->dtype(opr.config().output_dtype());
} else {
opr.output(0)->dtype(opr.input(0)->dtype());
}
}
};
} // namespace intl
} // namespace opr
} // namespace mgb
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RelayoutFormat);
MEGDNN_OPR_INIT1(RelayoutFormat, "relayout_format")
......
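The MegDNNOprInitPostCtor specialization above is what lets a caller pin the RelayoutFormat output dtype through OperatorNodeConfig; the fusion pass relies on this when rewriting the matched subgraphs:

    opr::RelayoutFormat::Param param;
    param.mode = opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
    OperatorNodeConfig config(opr->output()[0]->dtype());
    auto out_node = opr::RelayoutFormat::make(
            rewriter.get_var(src_node->output()[0]), param.mode, config);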
......@@ -190,6 +190,24 @@ namespace mgb {
}
return ret;
}
std::shared_ptr<HostTensorND>
HostTensorGenerator<dtype::Quantized8Asymm, RandomDistribution::UNIFORM>::
operator()(const TensorShape& shape, CompNode cn) {
if (!cn.valid())
cn = CompNode::load("xpu0");
auto dtype = dtype::Quantized8Asymm(m_scale, m_zero_point);
auto param = dtype.param();
std::shared_ptr<HostTensorND> ret =
std::make_shared<HostTensorND>(cn, shape, dtype);
auto ptr = ret->ptr<dt_quint8>();
double scale = (param.dequantize(m_hi) - param.dequantize(m_lo)) /
(m_rng.max() + 1.0);
for (size_t i = 0, it = shape.total_nr_elems(); i < it; ++i) {
ptr[i] = param.quantize(m_rng() * scale + param.dequantize(m_lo));
}
return ret;
}
}
::testing::AssertionResult mgb::__assert_float_equal(
......
......@@ -264,6 +264,10 @@ struct UniformRNGDefaultRange<dtype::QuantizedS8> {
static const dt_qint8 LO, HI;
};
template<>
struct UniformRNGDefaultRange<dtype::Quantized8Asymm> {
static const dt_quint8 LO, HI;
};
//! gaussian
template<class dtype>
class HostTensorGenerator<dtype, RandomDistribution::GAUSSIAN> final:
......@@ -404,6 +408,33 @@ class HostTensorGenerator<dtype::QuantizedS8, RandomDistribution::UNIFORM> final
ctype m_lo, m_hi;
};
template <>
class HostTensorGenerator<dtype::Quantized8Asymm, RandomDistribution::UNIFORM>
final : public HostTensorGeneratorBase {
public:
using ctype = typename DTypeTrait<dtype::Quantized8Asymm>::ctype;
HostTensorGenerator(
ctype lo = UniformRNGDefaultRange<dtype::Quantized8Asymm>::LO,
ctype hi = UniformRNGDefaultRange<dtype::Quantized8Asymm>::HI,
float scale = 1.f, uint8_t zero_point = 0,
uint64_t seed = next_rand_seed())
: HostTensorGeneratorBase{seed},
m_scale{scale},
m_zero_point(zero_point),
m_lo{lo},
m_hi{hi} {}
std::shared_ptr<HostTensorND> operator()(const TensorShape& shape,
CompNode cn = {}) override;
using HostTensorGeneratorBase::operator();
private:
float m_scale;
uint8_t m_zero_point;
ctype m_lo, m_hi;
};
/*!
* \brief get output file name in test output dir
* \param check_writable whether to ensure the file is writable
......