提交 543c9b77 编写于 作者: M Megvii Engine Team

feat(dnn): add RegionRestrictedConv cuda

GitOrigin-RevId: b9f2d34a136d19590ca08ddcf7d167944d24aa6a
上级 fdec82ec
...@@ -23,6 +23,7 @@ std::string get_errmsg( ...@@ -23,6 +23,7 @@ std::string get_errmsg(
"dilate_h=" + std::to_string(param.dilate_h) + ", " + "dilate_h=" + std::to_string(param.dilate_h) + ", " +
"dilate_w=" + std::to_string(param.dilate_w); "dilate_w=" + std::to_string(param.dilate_w);
} }
} // namespace } // namespace
namespace megdnn { namespace megdnn {
...@@ -31,7 +32,12 @@ void RegionRestrictedConvolutionForward::deduce_dtype( ...@@ -31,7 +32,12 @@ void RegionRestrictedConvolutionForward::deduce_dtype(
DType src, DType filter, DType rin, DType rout, DType& dst) { DType src, DType filter, DType rin, DType rout, DType& dst) {
check_or_deduce_dtype_fwd(src, filter, dst); check_or_deduce_dtype_fwd(src, filter, dst);
megdnn_assert( megdnn_assert(
rin == rout && rin == dtype::Int32(), src.category() == DTypeCategory::FLOAT &&
filter.category() == DTypeCategory::FLOAT &&
dst.category() == DTypeCategory::FLOAT,
"only float type is supported for region_restricted_conv forward");
megdnn_assert(
rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()),
"the dtype of rin/rout should be Int32, got %s.", rin.name()); "the dtype of rin/rout should be Int32, got %s.", rin.name());
} }
...@@ -51,6 +57,9 @@ RegionRestrictedConvolutionForward::check_exec( ...@@ -51,6 +57,9 @@ RegionRestrictedConvolutionForward::check_exec(
megdnn_assert( megdnn_assert(
param().format == Param::Format::NCHW, param().format == Param::Format::NCHW,
"RegionRestrictedConv only support NCHW format mow."); "RegionRestrictedConv only support NCHW format mow.");
megdnn_assert(
param().stride_h == 1 && param().stride_w == 1,
"RegionRestrictedConv only support stride 1.");
#define err_msg(lhs, rhs) \ #define err_msg(lhs, rhs) \
megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs); megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs);
......
...@@ -53,6 +53,7 @@ ...@@ -53,6 +53,7 @@
#include "src/cuda/pooling/opr_impl.h" #include "src/cuda/pooling/opr_impl.h"
#include "src/cuda/powc/opr_impl.h" #include "src/cuda/powc/opr_impl.h"
#include "src/cuda/reduce/opr_impl.h" #include "src/cuda/reduce/opr_impl.h"
#include "src/cuda/region_restricted_convolution/opr_impl.h"
#include "src/cuda/relayout/opr_impl.h" #include "src/cuda/relayout/opr_impl.h"
#include "src/cuda/relayout_format/opr_impl.h" #include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/remap/opr_impl.h" #include "src/cuda/remap/opr_impl.h"
...@@ -218,6 +219,9 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward); ...@@ -218,6 +219,9 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(NormForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(NormForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardData);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardFilter);
template <typename Opr> template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() { std::unique_ptr<Opr> HandleImpl::create_operator() {
......
#include "./kern.cuh"
#include "cuda.h"
#include "cuda_fp16.h"
#include "src/cuda/fp16_help.cuh"
using namespace megdnn;
using namespace cuda;
using namespace region_restricted_convolution;
using namespace chanwise;
#include "src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh"
namespace megdnn {
namespace cuda {
namespace region_restricted_convolution {
namespace chanwise {
// =====================================fwd=====================================
template <>
void run_bwd_depthwise_large_filter(
float* dst, const float* src, const float* flt, const int* rin, const int* rout,
const Param& param, cudaStream_t stream) {
INSTANCE_INT(float, int, DepthwiseConv2dDirection::DIRECTION_BACKWARD)
}
template <>
void run_bwd_depthwise_large_filter(
float* dst, const float* src, const float* flt, const uint8_t* rin,
const uint8_t* rout, const Param& param, cudaStream_t stream) {
INSTANCE_UINT8(float, uint8_t, DepthwiseConv2dDirection::DIRECTION_BACKWARD)
}
} // namespace chanwise
} // namespace region_restricted_convolution
} // namespace cuda
} // namespace megdnn
// vim: syntax=cuda.doxygen
#pragma once
namespace {
#define DIVUP(x, y) (((x) + (y)-1) / (y))
enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD };
template <typename ThreadConfig_, int oh_, int ow_>
struct OutTileConfig {
using ThreadConfig = ThreadConfig_;
static int constexpr unroll_h = oh_;
static int constexpr unroll_w = ThreadConfig::thread_x * ow_;
static int constexpr unroll_size = unroll_h * unroll_w;
static int constexpr block_h = unroll_h * ThreadConfig::thread_y;
static int constexpr block_w = unroll_w;
};
template <int fh_, int fw_>
struct FilterTileConfig {
static int constexpr unroll_h = fh_;
static int constexpr unroll_w = fw_;
static int constexpr unroll_size = unroll_h * unroll_w;
};
template <int x_, int y_>
struct ThreadConfig {
static int constexpr thread_x = x_;
static_assert((thread_x & (thread_x - 1)) == 0, "thread_x must be pow of 2!");
static int constexpr thread_y = y_;
static int constexpr nr_threads = x_ * y_;
};
template <
typename ldg_dtype, typename Rldg_dtype, typename Rcmp_dtype,
typename ThreadConfig_, typename OutTileConfig_, typename FilterTileConfig_,
int stride_w, int stride_h>
struct ConvTraitInner {
using ThreadConfig = ThreadConfig_;
using OutTileConfig = OutTileConfig_;
using FilterTileConfig = FilterTileConfig_;
using CompType = ldg_dtype;
struct SrcTileConfig {
static int constexpr unroll_h =
OutTileConfig::unroll_h + FilterTileConfig::unroll_h - 1;
static int constexpr unroll_w =
(OutTileConfig::unroll_w - 1) * stride_w + FilterTileConfig::unroll_w;
static int constexpr unroll_size = unroll_h * unroll_w;
};
struct SrcTileCount {
static int constexpr smem_src_h =
(OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h;
static int constexpr smem_delta_h = 2;
static int constexpr smem_buff_h =
FilterTileConfig::unroll_h * smem_delta_h * 2;
static int constexpr smem_load_h = smem_src_h + smem_buff_h;
static int constexpr smem_h = smem_load_h;
static int constexpr smem_w =
DIVUP((OutTileConfig::block_w - 1) * stride_w +
FilterTileConfig::unroll_w * ThreadConfig::thread_x,
2) *
2;
static int constexpr load_w = smem_w > 32 ? 32 : smem_w;
static int constexpr load_h = ThreadConfig::nr_threads / load_w;
static int constexpr reg_h = DIVUP(smem_delta_h, load_h);
static int constexpr reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_delta_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0;
// to avoid bank confilct, every bank_offset_line in 8 lines, add one offset
static int constexpr bank_w = smem_w / (4 / sizeof(CompType));
static int constexpr bank_offset_line =
(bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0)
? 1
: (bank_w % 16 == 0 ? 2 : 4);
static int constexpr smem_size =
smem_h * smem_w +
DIVUP(smem_h, bank_offset_line) * (4 / sizeof(CompType));
};
struct FilterTileCount {
static int constexpr smem_flt_h = FilterTileConfig::unroll_h;
static int constexpr smem_buff_h = FilterTileConfig::unroll_h;
static int constexpr smem_w =
FilterTileConfig::unroll_w * ThreadConfig::thread_x;
static int constexpr smem_delta_h = 2;
static int constexpr smem_load_h = smem_flt_h + smem_buff_h * smem_w;
static int constexpr smem_h = smem_load_h + smem_buff_h;
static int constexpr load_w = smem_w > 32 ? 32 : smem_w;
static int constexpr load_h = ThreadConfig::nr_threads / load_w;
static int constexpr reg_h = 1;
static int constexpr reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0;
// to avoid bank confilct, every bank_offset_line in 8 lines, add one offset
static int constexpr bank_w = smem_w / (4 / sizeof(CompType));
static int constexpr bank_offset_line =
(bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0)
? 1
: (bank_w % 16 == 0 ? 2 : 4);
static int constexpr smem_size =
smem_h * smem_w +
DIVUP(smem_h, bank_offset_line) * (4 / sizeof(CompType));
};
struct RinTileCount {
static int constexpr smem_src_h =
(OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h;
static int constexpr smem_delta_h = 2;
static int constexpr smem_buff_h =
FilterTileConfig::unroll_h * smem_delta_h * 2;
static int constexpr smem_load_h = smem_src_h + smem_buff_h;
static int constexpr smem_h = smem_load_h;
static int constexpr factor = sizeof(Rldg_dtype) / sizeof(Rcmp_dtype);
static int constexpr smem_w =
DIVUP(DIVUP((OutTileConfig::block_w - 1) * stride_w +
FilterTileConfig::unroll_w * ThreadConfig::thread_x,
factor),
2) *
2;
static int constexpr load_w = smem_w > 32 ? 32 : smem_w;
static int constexpr load_h = ThreadConfig::nr_threads / load_w;
static int constexpr reg_h = DIVUP(smem_delta_h, load_h);
static int constexpr reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_delta_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0;
// to avoid bank confilct, every bank_offset_line in 8 lines, add one offset
static int constexpr bank_w = smem_w;
static int constexpr bank_offset_line =
(bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0)
? 1
: (bank_w % 16 == 0 ? 2 : 4);
static int constexpr smem_size =
smem_h * smem_w + DIVUP(smem_h, bank_offset_line);
};
};
} // namespace
#pragma once
#include "depthwise_large_filter.cuh"
#include "src/cuda/cuda_shfl_compat.cuh"
namespace {
template <
typename T, typename RT, DepthwiseConv2dDirection kDirection,
typename ThreadConfig_, typename TileCount_>
struct Global2SharedMem {
using TileCount = TileCount_;
using ThreadConfig = ThreadConfig_;
T reg[TileCount::reg_h][TileCount::reg_w];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int tid = tidy * ThreadConfig::thread_x + tidx;
const int gl_load_y = tid / TileCount::load_w;
const int gl_load_x = tid - gl_load_y * TileCount::load_w;
const bool is_fwd = (kDirection == DIRECTION_FORWARD);
int w_offset;
T* smem;
int stride;
int start_h, start_w, bound_h, bound_w, ring_smem_h, ring_src_h;
// just used in backward src data
int stride_h, stride_w;
const RT* g_ptr;
__device__ __forceinline__ Global2SharedMem(
T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_,
int stride_w_);
__device__ __forceinline__ void first_copy();
__device__ __forceinline__ void copy();
__device__ __forceinline__ void commit();
__device__ __forceinline__ void iter_forward();
__device__ __forceinline__ T* sh_ptr(int y, int x) {
return &smem[y * TileCount::smem_w + x];
}
__device__ __forceinline__ T* sh_ptr_as_copy_t(int y, int x) {
return reinterpret_cast<T*>(sh_ptr(y, x));
}
};
template <
typename ldg_dtype, typename Rldg_dtype, typename Rcmp_dtype,
DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename OutTileConfig_, typename FilterTileConfig_, int stride_w, int stride_h>
struct ConvTrait {
using ThreadConfig = ThreadConfig_;
using OutTileConfig = OutTileConfig_;
using FilterTileConfig = FilterTileConfig_;
using CompType = ldg_dtype;
using RLdgType = Rldg_dtype;
using RCmpType = Rcmp_dtype;
using CI = ConvTraitInner<
ldg_dtype, Rldg_dtype, Rcmp_dtype, ThreadConfig_, OutTileConfig_,
FilterTileConfig_, stride_w, stride_h>;
using SrcTileConfig = typename CI::SrcTileConfig;
using SrcTileCount = typename CI::SrcTileCount;
using FilterTileCount = typename CI::FilterTileCount;
using RinTileCount = typename CI::RinTileCount;
using SrcGlobal2ShareVisitor = Global2SharedMem<
CompType, CompType, DepthwiseConv2dDirection::DIRECTION_FORWARD,
ThreadConfig, SrcTileCount>;
using RinGlobal2ShareVisitor = Global2SharedMem<
Rldg_dtype, Rcmp_dtype, DepthwiseConv2dDirection::DIRECTION_FORWARD,
ThreadConfig, RinTileCount>;
using FilterGlobal2ShareVisitor = Global2SharedMem<
CompType, CompType, kDirection, ThreadConfig, FilterTileCount>;
};
template <
typename T, typename RT, DepthwiseConv2dDirection kDirection,
typename ThreadConfig_, typename TileCount_>
__device__ __forceinline__
Global2SharedMem<T, RT, kDirection, ThreadConfig_, TileCount_>::Global2SharedMem(
T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_,
int stride_w_)
: smem(smem_),
stride(stride_),
start_h(s_h),
start_w(s_w),
bound_h(b_h),
bound_w(b_w),
ring_smem_h(TileCount::smem_load_h),
stride_h(stride_h_),
stride_w(stride_w_) {
if (is_fwd) {
ring_src_h = s_h + TileCount::smem_load_h;
w_offset = 0;
} else {
ring_src_h = s_h - 1;
w_offset = TileCount::smem_w - b_w;
// stride_h and stride_w just used in backward src data.
stride_h = stride_w = 1;
}
}
template <
typename T, typename RT, DepthwiseConv2dDirection kDirection,
typename ThreadConfig_, typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, RT, kDirection, ThreadConfig_, TileCount_>::first_copy() {
static int const load_w = TileCount::smem_w > 32 ? 32 : TileCount::smem_w;
static int const load_h = ThreadConfig::nr_threads / load_w;
static int const h_per_thread = DIVUP(TileCount::smem_load_h, load_h);
static int const w_per_thread = DIVUP(TileCount::smem_w, load_w);
static bool constexpr check_bounds_h = TileCount::smem_load_h % load_h != 0;
static bool constexpr check_bounds_w = TileCount::smem_w % load_w != 0;
const int y_base_idx = tid / load_w;
const int x_base_idx = tid - y_base_idx * load_w;
#pragma unroll
for (int i = 0; i < h_per_thread; ++i) {
int smem_h_idx = y_base_idx + i * load_h;
int bank_offset = smem_h_idx / TileCount::bank_offset_line;
int src_h_idx;
if (is_fwd) {
src_h_idx = start_h + smem_h_idx;
} else {
src_h_idx = start_h - smem_h_idx;
}
if (check_bounds_h && smem_h_idx >= TileCount::smem_load_h)
continue;
#pragma unroll
for (int j = 0; j < w_per_thread; ++j) {
int smem_w_idx = x_base_idx + j * load_w;
int src_w_idx;
if (is_fwd) {
src_w_idx = start_w + smem_w_idx;
} else {
src_w_idx = start_w + TileCount::smem_w - w_offset - smem_w_idx - 1;
}
if (check_bounds_w && smem_w_idx >= TileCount::smem_w)
continue;
T val = 0.0f;
if (src_h_idx >= 0 && src_h_idx < bound_h && src_w_idx >= 0 &&
src_w_idx < bound_w &&
((is_fwd && src_h_idx % stride_h == 0 && src_w_idx % stride_w == 0) ||
(!is_fwd && TileCount::smem_load_h - smem_h_idx - 1 >= 0 &&
TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
val = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w];
}
*(sh_ptr_as_copy_t(
smem_h_idx, smem_w_idx + bank_offset * (4 / sizeof(T)))) = val;
}
}
}
template <
typename T, typename RT, DepthwiseConv2dDirection kDirection,
typename ThreadConfig_, typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, RT, kDirection, ThreadConfig_, TileCount_>::copy() {
#pragma unroll
for (int i = 0; i < TileCount::reg_h; ++i) {
int thread_h_idx = gl_load_y + i * TileCount::load_h;
int smem_h_idx = (ring_smem_h + thread_h_idx) % TileCount::smem_h;
int src_h_idx;
if (is_fwd) {
src_h_idx = ring_src_h + thread_h_idx;
} else {
src_h_idx = start_h - smem_h_idx;
}
if (thread_h_idx >= TileCount::smem_delta_h)
continue;
#pragma unroll
for (int j = 0; j < TileCount::reg_w; ++j) {
int smem_w_idx = gl_load_x + j * TileCount::load_w;
int src_w_idx;
if (is_fwd) {
src_w_idx = start_w + smem_w_idx;
} else {
src_w_idx = start_w + TileCount::smem_w - w_offset - smem_w_idx - 1;
}
if (TileCount::check_bounds_w && smem_w_idx >= TileCount::smem_w)
continue;
T val = 0.0f;
if (src_h_idx >= 0 && src_h_idx < bound_h && src_w_idx >= 0 &&
src_w_idx < bound_w &&
((is_fwd && src_h_idx % stride_h == 0 && src_w_idx % stride_w == 0) ||
(!is_fwd && TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
val = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w];
}
reg[i][j] = val;
}
}
}
template <
typename T, typename RT, DepthwiseConv2dDirection kDirection,
typename ThreadConfig_, typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, RT, kDirection, ThreadConfig_, TileCount_>::commit() {
#pragma unroll
for (int i = 0; i < TileCount::reg_h; ++i) {
int thread_h_idx = gl_load_y + i * TileCount::load_h;
int smem_h_idx = (ring_smem_h + thread_h_idx) % TileCount::smem_h;
int bank_offset = smem_h_idx / TileCount::bank_offset_line;
if (thread_h_idx >= TileCount::smem_delta_h)
continue;
#pragma unroll
for (int j = 0; j < TileCount::reg_w; ++j) {
int smem_w_idx = gl_load_x + j * TileCount::load_w;
if (TileCount::check_bounds_w && smem_w_idx >= TileCount::smem_w)
continue;
*(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx + bank_offset)) = reg[i][j];
}
}
}
template <
typename T, typename RT, DepthwiseConv2dDirection kDirection,
typename ThreadConfig_, typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, RT, kDirection, ThreadConfig_, TileCount_>::iter_forward() {
if (is_fwd) {
ring_src_h += TileCount::smem_delta_h;
} else {
ring_src_h -= TileCount::smem_delta_h;
}
ring_smem_h = (ring_smem_h + TileCount::smem_delta_h) % TileCount::smem_h;
}
template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
struct Global2SharedMem<T, uint8_t, kDirection, ThreadConfig_, TileCount_> {
using TileCount = TileCount_;
using ThreadConfig = ThreadConfig_;
static const int InnerStep = sizeof(T);
T reg[TileCount::reg_h][TileCount::reg_w];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int tid = tidy * ThreadConfig::thread_x + tidx;
const int gl_load_y = tid / TileCount::load_w;
const int gl_load_x = tid - gl_load_y * TileCount::load_w;
const bool is_fwd = (kDirection == DIRECTION_FORWARD);
int w_offset;
T* smem;
int stride;
int start_h, start_w, bound_h, bound_w, ring_smem_h, ring_src_h;
// just used in backward src data
int stride_h, stride_w;
const uint8_t* g_ptr;
__device__ __forceinline__ Global2SharedMem(
T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_,
int stride_w_);
__device__ __forceinline__ void first_copy();
__device__ __forceinline__ void copy();
__device__ __forceinline__ void commit();
__device__ __forceinline__ void iter_forward();
__device__ __forceinline__ T* sh_ptr(int y, int x) {
return &smem[y * TileCount::smem_w + x];
}
__device__ __forceinline__ T* sh_ptr_as_copy_t(int y, int x) {
return reinterpret_cast<T*>(sh_ptr(y, x));
}
};
template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__
Global2SharedMem<T, uint8_t, kDirection, ThreadConfig_, TileCount_>::Global2SharedMem(
T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_,
int stride_w_)
: smem(smem_),
stride(stride_),
start_h(s_h),
start_w(s_w),
bound_h(b_h),
bound_w(b_w),
ring_smem_h(TileCount::smem_load_h),
stride_h(stride_h_),
stride_w(stride_w_) {
if (is_fwd) {
ring_src_h = s_h + TileCount::smem_load_h;
w_offset = 0;
} else {
ring_src_h = s_h - 1;
w_offset = TileCount::smem_w - b_w;
// stride_h and stride_w just used in backward src data.
stride_h = stride_w = 1;
}
}
template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, uint8_t, kDirection, ThreadConfig_, TileCount_>::first_copy() {
static int const load_w = TileCount::smem_w > 32 ? 32 : TileCount::smem_w;
static int const load_h = ThreadConfig::nr_threads / load_w;
static int const h_per_thread = DIVUP(TileCount::smem_load_h, load_h);
static int const w_per_thread = DIVUP(TileCount::smem_w, load_w);
static bool constexpr check_bounds_h = TileCount::smem_load_h % load_h != 0;
static bool constexpr check_bounds_w = TileCount::smem_w % load_w != 0;
const int y_base_idx = tid / load_w;
const int x_base_idx = tid - y_base_idx * load_w;
#pragma unroll
for (int i = 0; i < h_per_thread; ++i) {
int smem_h_idx = y_base_idx + i * load_h;
int bank_offset = smem_h_idx / TileCount::bank_offset_line;
int src_h_idx;
if (is_fwd) {
src_h_idx = start_h + smem_h_idx;
} else {
src_h_idx = start_h - smem_h_idx;
}
if (check_bounds_h && smem_h_idx >= TileCount::smem_load_h)
continue;
#pragma unroll
for (int j = 0; j < w_per_thread; ++j) {
int smem_w_idx = x_base_idx + j * load_w;
int src_w_idx;
if (is_fwd) {
src_w_idx = start_w + smem_w_idx * InnerStep;
} else {
src_w_idx = start_w + TileCount::smem_w - w_offset - smem_w_idx - 1;
}
if (check_bounds_w && smem_w_idx >= TileCount::smem_w)
continue;
T val = 0.0f;
for (int inner = 0; inner < InnerStep; inner++) {
T temp = 0;
if (src_h_idx >= 0 && src_h_idx < bound_h && src_w_idx >= 0 &&
src_w_idx < bound_w &&
((is_fwd && src_h_idx % stride_h == 0 &&
src_w_idx % stride_w == 0) ||
(!is_fwd && TileCount::smem_load_h - smem_h_idx - 1 >= 0 &&
TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
temp = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w];
val |= (temp << (inner << 3));
}
src_w_idx++;
}
*(sh_ptr_as_copy_t(
smem_h_idx, smem_w_idx + bank_offset * (4 / sizeof(T)))) = val;
}
}
}
template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, uint8_t, kDirection, ThreadConfig_, TileCount_>::copy() {
#pragma unroll
for (int i = 0; i < TileCount::reg_h; ++i) {
int thread_h_idx = gl_load_y + i * TileCount::load_h;
int smem_h_idx = (ring_smem_h + thread_h_idx) % TileCount::smem_h;
int src_h_idx;
if (is_fwd) {
src_h_idx = ring_src_h + thread_h_idx;
} else {
src_h_idx = start_h - smem_h_idx;
}
if (thread_h_idx >= TileCount::smem_delta_h)
continue;
#pragma unroll
for (int j = 0; j < TileCount::reg_w; ++j) {
int smem_w_idx = gl_load_x + j * TileCount::load_w;
int src_w_idx;
if (is_fwd) {
src_w_idx = start_w + smem_w_idx * InnerStep;
} else {
src_w_idx = start_w + TileCount::smem_w - w_offset - smem_w_idx - 1;
}
if (TileCount::check_bounds_w && smem_w_idx >= TileCount::smem_w)
continue;
T val = 0.0f;
#pragma unroll
for (int inner = 0; inner < InnerStep; inner++) {
uint32_t temp = 0;
if (src_h_idx >= 0 && src_h_idx < bound_h && src_w_idx >= 0 &&
src_w_idx < bound_w &&
((is_fwd && src_h_idx % stride_h == 0 &&
src_w_idx % stride_w == 0) ||
(!is_fwd && TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
temp = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w];
val |= (temp << (inner << 3));
}
src_w_idx++;
}
reg[i][j] = val;
}
}
}
template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, uint8_t, kDirection, ThreadConfig_, TileCount_>::commit() {
#pragma unroll
for (int i = 0; i < TileCount::reg_h; ++i) {
int thread_h_idx = gl_load_y + i * TileCount::load_h;
int smem_h_idx = (ring_smem_h + thread_h_idx) % TileCount::smem_h;
int bank_offset = smem_h_idx / TileCount::bank_offset_line;
if (thread_h_idx >= TileCount::smem_delta_h)
continue;
#pragma unroll
for (int j = 0; j < TileCount::reg_w; ++j) {
int smem_w_idx = gl_load_x + j * TileCount::load_w;
if (TileCount::check_bounds_w && smem_w_idx >= TileCount::smem_w)
continue;
*(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx + bank_offset)) = reg[i][j];
}
}
}
template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, uint8_t, kDirection, ThreadConfig_, TileCount_>::iter_forward() {
if (is_fwd) {
ring_src_h += TileCount::smem_delta_h;
} else {
ring_src_h -= TileCount::smem_delta_h;
}
ring_smem_h = (ring_smem_h + TileCount::smem_delta_h) % TileCount::smem_h;
}
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const float* input, const float* filter, const int* rin,
const int* rout, float* output) {
using T = float;
using ThreadConfig = typename ConvTrait::ThreadConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
using OutTileConfig = typename ConvTrait::OutTileConfig;
using SrcTileCount = typename ConvTrait::SrcTileCount;
using FilterTileCount = typename ConvTrait::FilterTileCount;
using RinTileCount = typename ConvTrait::RinTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using RinGlobal2ShareVisitor = typename ConvTrait::RinGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
extern __shared__ __align__(8) unsigned char smem[];
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int* smem_rin = reinterpret_cast<int*>(&smem_flt[FilterTileCount::smem_size]);
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
batch = off_ichannel / param.src_chl,
out_start_h = off_obh * OutTileConfig::block_h,
out_start_w = off_obw * OutTileConfig::block_w,
src_start_h = out_start_h * stride_h - param.pad_h,
src_start_w = out_start_w * stride_w - param.pad_w,
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
int* smem_rin_ptr = smem_rin + off_ow * FilterTileConfig::unroll_w;
T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
const int* rout_base_ptr = rout + batch * param.out_h * param.out_w;
int reg_rout[OutTileConfig::unroll_size] = {0};
#pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
int out_h_idx = out_base_h_idx + i;
if (out_h_idx < param.out_h) {
#pragma unroll
for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
int out_w_idx = out_start_w + j;
if (out_w_idx < param.out_w) {
reg_rout[i * OutTileConfig::unroll_w + j] =
rout_base_ptr[out_h_idx * param.out_w + out_w_idx];
}
}
}
}
SrcGlobal2ShareVisitor gl2sh_src = {
smem_src,
static_cast<int>(param.src_w),
static_cast<int>(
is_fwd ? src_start_h
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h),
is_fwd ? 1 : static_cast<int>(param.stride_w)};
RinGlobal2ShareVisitor gl2sh_rin = {
smem_rin,
static_cast<int>(param.src_w),
static_cast<int>(
is_fwd ? src_start_h
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h),
is_fwd ? 1 : static_cast<int>(param.stride_w)};
FilterGlobal2ShareVisitor gl2sh_flt = {
smem_flt,
static_cast<int>(param.flt_w),
is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
0,
static_cast<int>(param.flt_h),
static_cast<int>(param.flt_w),
1,
1};
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
gl2sh_rin.g_ptr = rin + batch * param.src_h * param.src_w;
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
gl2sh_src.first_copy();
gl2sh_rin.first_copy();
gl2sh_flt.first_copy();
__syncthreads();
T reg_src[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
int reg_rin[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w];
T sum[OutTileConfig::unroll_size] = {0.0};
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh * stride_h + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w + (off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line];
reg_rin[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_rin_ptr
[(off_oh * stride_h + s_h) % RinTileCount::smem_h *
RinTileCount::smem_w +
s_w + (off_oh * stride_h + s_h) / RinTileCount::bank_offset_line];
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w +
f_h / FilterTileCount::bank_offset_line];
}
}
int fh = 1;
for (; fh < param.flt_h; fh += FilterTileConfig::unroll_h * 2) {
if (fh + 4 < param.flt_h + 1) {
gl2sh_src.copy();
gl2sh_rin.copy();
}
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
int smem_h_idx = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h;
reg_src[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[smem_h_idx * SrcTileCount::smem_w + s_w +
smem_h_idx / SrcTileCount::bank_offset_line];
reg_rin[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_rin_ptr
[smem_h_idx * RinTileCount::smem_w + s_w +
smem_h_idx / RinTileCount::bank_offset_line];
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w + (fh + f_h) / FilterTileCount::bank_offset_line];
}
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
int src_idx = (inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w;
if (reg_rin[0][src_idx] ==
reg_rout[oh * OutTileConfig::unroll_w + ow]) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[0]
[inner_fh * FilterTileConfig::unroll_w +
fw] *
reg_src[0][src_idx];
}
}
}
}
}
if (fh + SrcTileCount::smem_delta_h < param.flt_h) {
__syncthreads();
}
if (fh + (SrcTileCount::smem_delta_h << 1) < param.flt_h) {
gl2sh_src.commit();
gl2sh_rin.commit();
gl2sh_src.iter_forward();
gl2sh_rin.iter_forward();
}
if (fh + 1 < param.flt_h) {
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
int smem_h_idx =
(off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h;
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[smem_h_idx * SrcTileCount::smem_w + s_w +
smem_h_idx / SrcTileCount::bank_offset_line];
reg_rin[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_rin_ptr
[smem_h_idx * RinTileCount::smem_w + s_w +
smem_h_idx / RinTileCount::bank_offset_line];
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + 1 + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line];
}
}
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
int src_idx = (inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w;
if (reg_rin[1][src_idx] ==
reg_rout[oh * OutTileConfig::unroll_w + ow]) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[1]
[inner_fh * FilterTileConfig::unroll_w +
fw] *
reg_src[1][src_idx];
}
}
}
}
}
}
if (param.flt_h == fh) {
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
int src_idx = (inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w;
if (reg_rin[0][src_idx] ==
reg_rout[oh * OutTileConfig::unroll_w + ow]) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[0]
[inner_fh * FilterTileConfig::unroll_w +
fw] *
reg_src[0][src_idx];
}
}
}
}
}
}
__syncthreads();
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
sum[o] += __shfl_xor(sum[o], i, 32);
}
}
if (threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
int out_h_idx = out_base_h_idx + i;
if (out_h_idx < param.out_h) {
#pragma unroll
for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
int out_w_idx = out_start_w + j;
if (out_w_idx >= param.out_w)
return;
out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
sum[i * OutTileConfig::unroll_w + j];
}
}
}
}
}
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const float* input, const float* filter, const uint8_t* rin,
const uint8_t* rout, float* output) {
using T = float;
using ThreadConfig = typename ConvTrait::ThreadConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
using OutTileConfig = typename ConvTrait::OutTileConfig;
using SrcTileCount = typename ConvTrait::SrcTileCount;
using FilterTileCount = typename ConvTrait::FilterTileCount;
using RinTileCount = typename ConvTrait::RinTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using RinGlobal2ShareVisitor = typename ConvTrait::RinGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
extern __shared__ __align__(8) unsigned char smem[];
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int* smem_rin = reinterpret_cast<int*>(&smem_flt[FilterTileCount::smem_size]);
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
batch = off_ichannel / param.src_chl,
out_start_h = off_obh * OutTileConfig::block_h,
out_start_w = off_obw * OutTileConfig::block_w,
src_start_h = out_start_h * stride_h - param.pad_h,
src_start_w = out_start_w * stride_w - param.pad_w,
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
static_assert((FilterTileConfig::unroll_w & 3) == 0);
int* smem_rin_ptr = smem_rin + (off_ow * FilterTileConfig::unroll_w >> 2);
T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
const uint8_t* rout_base_ptr = rout + batch * param.out_h * param.out_w;
static_assert((OutTileConfig::unroll_w & 3) == 0);
static_assert((OutTileConfig::block_w & 3) == 0);
int reg_rout[OutTileConfig::unroll_size] = {0};
#pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
int out_h_idx = out_base_h_idx + i;
if (out_h_idx < param.out_h) {
#pragma unroll
for (int j = 0; j < OutTileConfig::unroll_w; j += 4) {
int out_w_idx = out_start_w + j;
if (out_w_idx < param.out_w) {
uint32_t val = *(reinterpret_cast<const uint32_t*>(
&rout_base_ptr[out_h_idx * param.out_w + out_w_idx]));
reg_rout[i * OutTileConfig::unroll_w + j] = val & 0xff;
reg_rout[i * OutTileConfig::unroll_w + j + 1] = (val >> 8) & 0xff;
reg_rout[i * OutTileConfig::unroll_w + j + 2] = (val >> 16) & 0xff;
reg_rout[i * OutTileConfig::unroll_w + j + 3] = (val >> 24) & 0xff;
}
}
}
}
SrcGlobal2ShareVisitor gl2sh_src = {
smem_src,
static_cast<int>(param.src_w),
static_cast<int>(
is_fwd ? src_start_h
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h),
is_fwd ? 1 : static_cast<int>(param.stride_w)};
RinGlobal2ShareVisitor gl2sh_rin = {
smem_rin,
static_cast<int>(param.src_w),
static_cast<int>(
is_fwd ? src_start_h
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h),
is_fwd ? 1 : static_cast<int>(param.stride_w)};
FilterGlobal2ShareVisitor gl2sh_flt = {
smem_flt,
static_cast<int>(param.flt_w),
is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
0,
static_cast<int>(param.flt_h),
static_cast<int>(param.flt_w),
1,
1};
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
gl2sh_rin.g_ptr = rin + batch * param.src_h * param.src_w;
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
gl2sh_src.first_copy();
gl2sh_rin.first_copy();
gl2sh_flt.first_copy();
__syncthreads();
const static int irin_unroll_w = (DIVUP(SrcTileConfig::unroll_w, 4)) << 2;
T reg_src[2][SrcTileConfig::unroll_h * irin_unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
int reg_rin[2][SrcTileConfig::unroll_h * irin_unroll_w];
T sum[OutTileConfig::unroll_size] = {0.0};
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
int s_idx = (off_oh * stride_h + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
(off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line;
#pragma unroll
for (int s_w = 0; s_w < irin_unroll_w; s_w += 4) {
uint32_t val = smem_rin_ptr
[(off_oh * stride_h + s_h) % RinTileCount::smem_h *
RinTileCount::smem_w +
(s_w >> 2) +
(off_oh * stride_h + s_h) / RinTileCount::bank_offset_line];
reg_src[0][s_h * irin_unroll_w + s_w] = smem_src_ptr[s_idx + s_w];
reg_src[0][s_h * irin_unroll_w + s_w + 1] = smem_src_ptr[s_idx + s_w + 1];
reg_src[0][s_h * irin_unroll_w + s_w + 2] = smem_src_ptr[s_idx + s_w + 2];
reg_src[0][s_h * irin_unroll_w + s_w + 3] = smem_src_ptr[s_idx + s_w + 3];
reg_rin[0][s_h * irin_unroll_w + s_w] = val & 0xff;
reg_rin[0][s_h * irin_unroll_w + s_w + 1] = (val >> 8) & 0xff;
reg_rin[0][s_h * irin_unroll_w + s_w + 2] = (val >> 16) & 0xff;
reg_rin[0][s_h * irin_unroll_w + s_w + 3] = (val >> 24) & 0xff;
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w +
f_h / FilterTileCount::bank_offset_line];
}
}
int fh = 1;
for (; fh < param.flt_h; fh += FilterTileConfig::unroll_h * 2) {
if (fh + 4 < param.flt_h + 1) {
gl2sh_src.copy();
gl2sh_rin.copy();
}
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
int src_off = ((off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h) *
SrcTileCount::smem_w +
((off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h) /
SrcTileCount::bank_offset_line;
int rin_h_idx = (off_oh * stride_h + fh + s_h) % RinTileCount::smem_h;
#pragma unroll
for (int s_w = 0; s_w < irin_unroll_w; s_w += 4) {
uint32_t val = smem_rin_ptr
[rin_h_idx * RinTileCount::smem_w + (s_w >> 2) +
rin_h_idx / RinTileCount::bank_offset_line];
reg_src[1][s_h * irin_unroll_w + s_w] = smem_src_ptr[src_off + s_w];
reg_src[1][s_h * irin_unroll_w + s_w + 1] =
smem_src_ptr[src_off + s_w + 1];
reg_src[1][s_h * irin_unroll_w + s_w + 2] =
smem_src_ptr[src_off + s_w + 2];
reg_src[1][s_h * irin_unroll_w + s_w + 3] =
smem_src_ptr[src_off + s_w + 3];
reg_rin[1][s_h * irin_unroll_w + s_w] = val & 0xff;
reg_rin[1][s_h * irin_unroll_w + s_w + 1] = (val >> 8) & 0xff;
reg_rin[1][s_h * irin_unroll_w + s_w + 2] = (val >> 16) & 0xff;
reg_rin[1][s_h * irin_unroll_w + s_w + 3] = (val >> 24) & 0xff;
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w + (fh + f_h) / FilterTileCount::bank_offset_line];
}
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
int src_h_off = (inner_fh + oh) * irin_unroll_w;
int rin_h_off = (inner_fh + oh) * irin_unroll_w;
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
int flt_off = inner_fh * FilterTileConfig::unroll_w + fw;
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
int src_w_idx = fw + ow * stride_w;
if (reg_rin[0][rin_h_off + src_w_idx] ==
reg_rout[oh * OutTileConfig::unroll_w + ow]) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[0][flt_off] *
reg_src[0][src_h_off + src_w_idx];
}
}
}
}
}
if (fh + SrcTileCount::smem_delta_h < param.flt_h) {
__syncthreads();
}
if (fh + (SrcTileCount::smem_delta_h << 1) < param.flt_h) {
gl2sh_src.commit();
gl2sh_rin.commit();
gl2sh_src.iter_forward();
gl2sh_rin.iter_forward();
}
if (fh + 1 < param.flt_h) {
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
int src_idx =
((off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h) *
SrcTileCount::smem_w +
((off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h) /
SrcTileCount::bank_offset_line;
int rin_h_idx =
(off_oh * stride_h + fh + 1 + s_h) % RinTileCount::smem_h;
#pragma unroll
for (int s_w = 0; s_w < irin_unroll_w; s_w += 4) {
uint32_t val = smem_rin_ptr
[rin_h_idx * RinTileCount::smem_w + (s_w >> 2) +
rin_h_idx / RinTileCount::bank_offset_line];
reg_src[0][s_h * irin_unroll_w + s_w] = smem_src_ptr[src_idx + s_w];
reg_src[0][s_h * irin_unroll_w + s_w + 1] =
smem_src_ptr[src_idx + s_w + 1];
reg_src[0][s_h * irin_unroll_w + s_w + 2] =
smem_src_ptr[src_idx + s_w + 2];
reg_src[0][s_h * irin_unroll_w + s_w + 3] =
smem_src_ptr[src_idx + s_w + 3];
reg_rin[0][s_h * irin_unroll_w + s_w] = val & 0xff;
reg_rin[0][s_h * irin_unroll_w + s_w + 1] = (val >> 8) & 0xff;
reg_rin[0][s_h * irin_unroll_w + s_w + 2] = (val >> 16) & 0xff;
reg_rin[0][s_h * irin_unroll_w + s_w + 3] = (val >> 24) & 0xff;
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + 1 + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line];
}
}
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
int src_h_off = (inner_fh + oh) * irin_unroll_w;
int rin_h_off = (inner_fh + oh) * irin_unroll_w;
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
int flt_off = inner_fh * FilterTileConfig::unroll_w + fw;
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
int src_w_idx = fw + ow * stride_w;
if (reg_rin[1][rin_h_off + src_w_idx] ==
reg_rout[oh * OutTileConfig::unroll_w + ow]) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[1][flt_off] *
reg_src[1][src_h_off + src_w_idx];
}
}
}
}
}
}
if (param.flt_h == fh) {
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
int src_h_off = (inner_fh + oh) * irin_unroll_w;
int rin_h_off = (inner_fh + oh) * irin_unroll_w;
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
int flt_off = inner_fh * FilterTileConfig::unroll_w + fw;
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
int src_w_idx = fw + ow * stride_w;
if (reg_rin[0][rin_h_off + src_w_idx] ==
reg_rout[oh * OutTileConfig::unroll_w + ow]) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[0][flt_off] *
reg_src[0][src_h_off + src_w_idx];
}
}
}
}
}
}
__syncthreads();
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
sum[o] += __shfl_xor(sum[o], i, 32);
}
}
if (threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
int out_h_idx = out_base_h_idx + i;
if (out_h_idx < param.out_h) {
#pragma unroll
for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
int out_w_idx = out_start_w + j;
if (out_w_idx >= param.out_w)
return;
out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
sum[i * OutTileConfig::unroll_w + j];
}
}
}
}
}
template <
typename T, typename RT, DepthwiseConv2dDirection kDirection, int unroll_fw,
int unroll_ow, int stride>
void LaunchDepthwiseConv2dGPU(
const Param& param, const T* input, const T* filter, const RT* rin,
const RT* rout, T* output, cudaStream_t stream) {
static int const unroll_oh = 1, unroll_fh = 1;
using FilterTileConfig = FilterTileConfig<unroll_fh, unroll_fw>;
using ThreadConfig = ThreadConfig<4, 32>;
using OutTileConfig = OutTileConfig<ThreadConfig, unroll_oh, unroll_ow>;
using IConvTrait = ConvTrait<
T, int, RT, kDirection, ThreadConfig, OutTileConfig, FilterTileConfig,
stride, stride>;
using SrcTileCount = typename IConvTrait::SrcTileCount;
using FilterTileCount = typename IConvTrait::FilterTileCount;
using RinTileCount = typename IConvTrait::RinTileCount;
dim3 block(ThreadConfig::thread_x, ThreadConfig::thread_y);
dim3 grid;
grid.x = param.batch * param.src_chl * param.chl_mul;
grid.y = DIVUP(param.out_w, OutTileConfig::block_w);
grid.z = DIVUP(param.out_h, OutTileConfig::block_h);
const int shared_storage =
(SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T) +
RinTileCount::smem_size * sizeof(int);
void (*kernel)(const Param, const T*, const T*, const RT*, const RT*, T*);
if (param.is_compute_deafult) {
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection, stride>;
} else {
megdnn_assert_internal(0);
}
kernel<<<grid, block, shared_storage, stream>>>(
param, input, filter, rin, rout, output);
after_kernel_launch();
}
#define INSTANCE_AB(type1, type2, a, direction) \
if (param.out_w > 28) { \
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a, 8, 1>( \
param, src, flt, rin, rout, dst, stream); \
} else if (param.stride_h == 2 && param.stride_w == 2) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a, 8, 2>( \
param, src, flt, rin, rout, dst, stream); \
} \
} else { \
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a, 4, 1>( \
param, src, flt, rin, rout, dst, stream); \
} else if (param.stride_h == 2 && param.stride_w == 2) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a, 4, 2>( \
param, src, flt, rin, rout, dst, stream); \
} \
}
#define INSTANCE_INT(type1, type2, direction) \
if (param.flt_w > 24) { \
INSTANCE_AB(type1, type2, 8, direction) \
} else if (param.flt_w > 16) { \
INSTANCE_AB(type1, type2, 6, direction) \
} else if (param.flt_w > 8) { \
INSTANCE_AB(type1, type2, 4, direction) \
} else { \
INSTANCE_AB(type1, type2, 2, direction) \
}
#define INSTANCE_UINT8(type1, type2, direction) \
if (param.flt_w > 16) { \
INSTANCE_AB(type1, type2, 8, direction) \
} else { \
INSTANCE_AB(type1, type2, 4, direction) \
}
} // anonymous namespace
#include "cuda.h"
#include "cuda_fp16.h"
#include "src/cuda/fp16_help.cuh"
#include "src/cuda/region_restricted_convolution/chanwise/kern.cuh"
using namespace megdnn;
using namespace cuda;
using namespace region_restricted_convolution;
using namespace chanwise;
#include "src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh"
namespace megdnn {
namespace cuda {
namespace region_restricted_convolution {
namespace chanwise {
// =====================================fwd=====================================
#define check
template <>
void run_fwd_depthwise_large_filter(
float* dst, const float* src, const float* flt, const int* rin, const int* rout,
const Param& param, cudaStream_t stream) {
INSTANCE_INT(float, int, DepthwiseConv2dDirection::DIRECTION_FORWARD)
}
template <>
void run_fwd_depthwise_large_filter(
float* dst, const float* src, const float* flt, const uint8_t* rin,
const uint8_t* rout, const Param& param, cudaStream_t stream) {
INSTANCE_UINT8(float, uint8_t, DepthwiseConv2dDirection::DIRECTION_FORWARD)
}
} // namespace chanwise
} // namespace region_restricted_convolution
} // namespace cuda
} // namespace megdnn
// vim: syntax=cuda.doxygen
#pragma once
#include <cuda_runtime.h>
#include <stdint.h>
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace region_restricted_convolution {
namespace chanwise {
struct Param {
int batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, pad_h, pad_w,
stride_h, stride_w, dilation_h, dilation_w;
bool is_compute_deafult;
#if MEGDNN_CC_HOST
static Param load(
const TensorShape& src, const TensorShape& dst,
const RegionRestrictedConvolutionForward::CanonizedFilterMeta& fm,
bool is_compute_deafult_ = true) {
#define U(v) static_cast<int>(v)
size_t c_pos, hw_pos;
if (fm.format == param::Convolution::Format::NCHW) {
c_pos = 1;
hw_pos = 2;
} else {
megdnn_assert_internal(0);
}
return {
U(src[0]), U(src[c_pos]), U(src[hw_pos]),
U(src[hw_pos + 1]), U(fm.ocpg), U(fm.spatial[0]),
U(fm.spatial[1]), U(dst[hw_pos]), U(dst[hw_pos + 1]),
U(fm.padding[0]), U(fm.padding[1]), U(fm.stride[0]),
U(fm.stride[1]), U(fm.dilation[0]), U(fm.dilation[1]),
is_compute_deafult_,
};
#undef U
}
#endif
};
template <typename T, typename RT>
void run_fwd_depthwise_large_filter(
T* dst, const T* src, const T* flt, const RT* rin, const RT* rout,
const Param& param, cudaStream_t stream);
template <typename T, typename RT>
void run_bwd_depthwise_large_filter(
T* dst, const T* src, const T* flt, const RT* rin, const RT* rout,
const Param& param, cudaStream_t stream);
} // namespace chanwise
} // namespace region_restricted_convolution
} // namespace cuda
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen
#include "src/cuda/region_restricted_convolution/opr_impl.h"
#include "src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter.cuh"
#include "src/cuda/region_restricted_convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace region_restricted_convolution;
/* ============== RegionRestrictedConvolutionForwardImpl ============== */
void RegionRestrictedConvolutionForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
auto fm = check_exec(
src.layout, filter.layout, rin.layout, rout.layout, dst.layout,
workspace.size);
auto kparam = chanwise::Param::load(
src.layout, dst.layout, fm,
param().compute_mode == Param::ComputeMode::DEFAULT);
megdnn_assert(
fm.group > 1 && src.layout.dtype.category() == DTypeCategory::FLOAT &&
param().compute_mode == Param::ComputeMode::DEFAULT &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip &&
param().stride_h == 1 && param().stride_w == 1);
if (rin.layout.dtype == dtype::Uint8()) {
megdnn_assert((src.layout.shape[3] & 3) == 0 && (dst.layout.shape[3] & 3) == 0);
}
auto stream = cuda_stream(handle());
if (filter.layout.dtype == dtype::Float32() && rin.layout.dtype == dtype::Int32() &&
rout.layout.dtype == dtype::Int32()) {
chanwise::run_fwd_depthwise_large_filter(
dst.ptr<float>(), src.ptr<float>(), filter.ptr<float>(), rin.ptr<int>(),
rout.ptr<int>(), kparam, stream);
} else if (
filter.layout.dtype == dtype::Float32() &&
rin.layout.dtype == dtype::Uint8() && rout.layout.dtype == dtype::Uint8()) {
chanwise::run_fwd_depthwise_large_filter(
dst.ptr<float>(), src.ptr<float>(), filter.ptr<float>(),
rin.ptr<uint8_t>(), rout.ptr<uint8_t>(), kparam, stream);
} else {
megdnn_assert_internal(0);
}
}
size_t RegionRestrictedConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& rin,
const TensorLayout& rout, const TensorLayout& grad) {
return 0;
}
/* ============== RegionRestrictedConvolutionBackwardDataImpl ============== */
void RegionRestrictedConvolutionBackwardDataImpl::exec(
_megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
megdnn_throw(ssprintf(
"unsupported RegionRestrictedConvolutionBackwardData(%s, %s, %s, %s) -> %s",
filter.layout.dtype.name(), diff.layout.dtype.name(),
rin.layout.dtype.name(), rout.layout.dtype.name(),
grad.layout.dtype.name()));
}
size_t RegionRestrictedConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout&,
const TensorLayout&, const TensorLayout& grad) {
size_t workspace_size = 0;
return workspace_size;
}
/* ============== RegionRestrictedConvolutionBackwardFilterImpl ============== */
void RegionRestrictedConvolutionBackwardFilterImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
megdnn_assert_internal(0);
}
// vim: syntax=cpp.doxygen
#pragma once
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"
namespace megdnn {
namespace cuda {
class RegionRestrictedConvolutionForwardImpl
: public RegionRestrictedConvolutionForward {
public:
using RegionRestrictedConvolutionForward::RegionRestrictedConvolutionForward;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
class RegionRestrictedConvolutionBackwardDataImpl
: public RegionRestrictedConvolutionBackwardData {
public:
using RegionRestrictedConvolutionBackwardData::
RegionRestrictedConvolutionBackwardData;
void exec(
_megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&) override;
};
class RegionRestrictedConvolutionBackwardFilterImpl
: public RegionRestrictedConvolutionBackwardFilter {
public:
using RegionRestrictedConvolutionBackwardFilter::
RegionRestrictedConvolutionBackwardFilter;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&) override;
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -878,8 +878,9 @@ void forward_bias( ...@@ -878,8 +878,9 @@ void forward_bias(
} }
template < template <
typename stype, typename ftype, typename dtype, typename comp_type, typename stype, typename ftype, typename rtype, typename dtype,
class Strategy, typename FilterMeta, typename FilterVisitor = ConvFilterVisitor> typename comp_type, class Strategy, typename FilterMeta,
typename FilterVisitor = ConvFilterVisitor>
void region_restricted_compute( void region_restricted_compute(
_megdnn_tensor_in src, ftype* __restrict fptr, _megdnn_tensor_in rin, _megdnn_tensor_in src, ftype* __restrict fptr, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out dst, const FilterMeta& filter_meta) { _megdnn_tensor_in rout, _megdnn_tensor_out dst, const FilterMeta& filter_meta) {
...@@ -897,8 +898,8 @@ void region_restricted_compute( ...@@ -897,8 +898,8 @@ void region_restricted_compute(
int dh = filter_meta.dilation[0], dw = filter_meta.dilation[1]; int dh = filter_meta.dilation[0], dw = filter_meta.dilation[1];
stype* __restrict sptr = src.compatible_ptr<stype>(); stype* __restrict sptr = src.compatible_ptr<stype>();
dtype* __restrict dptr = dst.compatible_ptr<dtype>(); dtype* __restrict dptr = dst.compatible_ptr<dtype>();
int32_t* __restrict rinptr = rin.ptr<int32_t>(); rtype* __restrict rinptr = rin.compatible_ptr<rtype>();
int32_t* __restrict routptr = rout.ptr<int32_t>(); rtype* __restrict routptr = rout.compatible_ptr<rtype>();
int h_offset = -ph, w_offset = -pw; int h_offset = -ph, w_offset = -pw;
if (filter_meta.should_flip) { if (filter_meta.should_flip) {
...@@ -934,7 +935,7 @@ void region_restricted_compute( ...@@ -934,7 +935,7 @@ void region_restricted_compute(
ftype* fptr_cur = FilterVisitor::template get_current_ptr( ftype* fptr_cur = FilterVisitor::template get_current_ptr(
fptr, n, oc, oh, ow, filter_sizes); fptr, n, oc, oh, ow, filter_sizes);
Strategy::init_dval(dval); Strategy::init_dval(dval);
int32_t routval = routptr[get_region_addr(n, oh, ow, rout.layout)]; rtype& routval = routptr[get_region_addr(n, oh, ow, rout.layout)];
for (size_t fh = 0; fh < FH; ++fh) for (size_t fh = 0; fh < FH; ++fh)
for (size_t fw = 0; fw < FW; ++fw) { for (size_t fw = 0; fw < FW; ++fw) {
...@@ -950,7 +951,7 @@ void region_restricted_compute( ...@@ -950,7 +951,7 @@ void region_restricted_compute(
n, ic, ih, iw, src.layout)]; n, ic, ih, iw, src.layout)];
ftype& fval = fptr_cur[get_filter_addr( ftype& fval = fptr_cur[get_filter_addr(
gc_out, ic, ic0, fh, fw)]; gc_out, ic, ic0, fh, fw)];
int32_t rinval = rinptr[get_region_addr( rtype& rinval = rinptr[get_region_addr(
n, ih, iw, rin.layout)]; n, ih, iw, rin.layout)];
if (routval == rinval) { if (routval == rinval) {
Strategy::on( Strategy::on(
...@@ -967,28 +968,32 @@ void region_restricted_compute( ...@@ -967,28 +968,32 @@ void region_restricted_compute(
} }
//! forward with only filter ptr //! forward with only filter ptr
template <typename stype, typename ftype, typename dtype, typename comp_type> template <
typename stype, typename ftype, typename rtype, typename dtype,
typename comp_type>
void region_restricted_forward( void region_restricted_forward(
_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_in rin, _megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out dst, _megdnn_tensor_in rout, _megdnn_tensor_out dst,
const RegionRestrictedConvolution::CanonizedFilterMeta& filter_meta) { const RegionRestrictedConvolution::CanonizedFilterMeta& filter_meta) {
megdnn_assert(filter_meta.spatial_ndim == 2); megdnn_assert(filter_meta.spatial_ndim == 2);
megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW); megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW);
region_restricted_compute<stype, ftype, dtype, comp_type, StrategyFwd>( region_restricted_compute<stype, ftype, rtype, dtype, comp_type, StrategyFwd>(
src, const_cast<ftype*>(fptr), rin, rout, dst, filter_meta); src, const_cast<ftype*>(fptr), rin, rout, dst, filter_meta);
} }
//! forward with full filter (for API compatibility) //! forward with full filter (for API compatibility)
template <typename stype, typename ftype, typename dtype, typename comp_type> template <
typename stype, typename ftype, typename rtype, typename dtype,
typename comp_type>
void region_restricted_forward( void region_restricted_forward(
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin, _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out dst, _megdnn_tensor_in rout, _megdnn_tensor_out dst,
const RegionRestrictedConvolution::CanonizedFilterMeta& filter_meta) { const RegionRestrictedConvolution::CanonizedFilterMeta& filter_meta) {
return region_restricted_forward<stype, ftype, dtype, comp_type>( return region_restricted_forward<stype, ftype, rtype, dtype, comp_type>(
src, filter.compatible_ptr<ftype>(), rin, rout, dst, filter_meta); src, filter.compatible_ptr<ftype>(), rin, rout, dst, filter_meta);
} }
template <typename ftype, typename dtype, typename gtype> template <typename ftype, typename dtype, typename rtype, typename gtype>
void region_restricted_backward_data( void region_restricted_backward_data(
_megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin, _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_tensor_in rout, _megdnn_tensor_out grad,
...@@ -996,11 +1001,11 @@ void region_restricted_backward_data( ...@@ -996,11 +1001,11 @@ void region_restricted_backward_data(
megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW); megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW);
memset(grad.raw_ptr(), 0, grad.layout.span().dist_byte()); memset(grad.raw_ptr(), 0, grad.layout.span().dist_byte());
megdnn_assert(filter_meta.spatial_ndim == 2); megdnn_assert(filter_meta.spatial_ndim == 2);
region_restricted_compute<gtype, ftype, dtype, dtype, StrategyBwdData>( region_restricted_compute<gtype, ftype, rtype, dtype, dtype, StrategyBwdData>(
grad, filter.compatible_ptr<ftype>(), rin, rout, diff, filter_meta); grad, filter.compatible_ptr<ftype>(), rin, rout, diff, filter_meta);
} }
template <typename stype, typename dtype, typename gtype> template <typename stype, typename dtype, typename rtype, typename gtype>
void region_restricted_backward_filter( void region_restricted_backward_filter(
_megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin, _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_tensor_in rout, _megdnn_tensor_out grad,
...@@ -1008,7 +1013,7 @@ void region_restricted_backward_filter( ...@@ -1008,7 +1013,7 @@ void region_restricted_backward_filter(
megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW); megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW);
memset(grad.raw_ptr(), 0, grad.layout.span().dist_byte()); memset(grad.raw_ptr(), 0, grad.layout.span().dist_byte());
megdnn_assert(filter_meta.spatial_ndim == 2); megdnn_assert(filter_meta.spatial_ndim == 2);
region_restricted_compute<stype, gtype, dtype, dtype, StrategyBwdFlt>( region_restricted_compute<stype, gtype, rtype, dtype, dtype, StrategyBwdFlt>(
src, grad.compatible_ptr<gtype>(), rin, rout, diff, filter_meta); src, grad.compatible_ptr<gtype>(), rin, rout, diff, filter_meta);
} }
......
...@@ -22,28 +22,37 @@ void RegionRestrictedConvolutionForwardImpl::exec( ...@@ -22,28 +22,37 @@ void RegionRestrictedConvolutionForwardImpl::exec(
src.layout, filter.layout, rin.layout, rout.layout, dst.layout, src.layout, filter.layout, rin.layout, rout.layout, dst.layout,
workspace.size); workspace.size);
using ComputeMode = Param::ComputeMode; using ComputeMode = Param::ComputeMode;
#define DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, cmode) \ #define DISPATCH_CMODE(in_dt, r_dt, out_dt, in_ct, r_ct, out_ct, comp_ct, cmode) \
do { \ do { \
using namespace dtype; \ using namespace dtype; \
if (src.layout.dtype.enumv() == DTypeTrait<in_dt>::enumv && \ if (src.layout.dtype.enumv() == DTypeTrait<in_dt>::enumv && \
dst.layout.dtype.enumv() == DTypeTrait<out_dt>::enumv && \ dst.layout.dtype.enumv() == DTypeTrait<out_dt>::enumv && \
rin.layout.dtype.enumv() == DTypeTrait<r_dt>::enumv && \
rout.layout.dtype.enumv() == DTypeTrait<r_dt>::enumv && \
param().compute_mode == cmode) { \ param().compute_mode == cmode) { \
MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_forward< \ MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_forward< \
in_ct, in_ct, out_ct, comp_ct>( \ in_ct, in_ct, r_ct, out_ct, comp_ct>( \
src, filter, rin, rout, dst, filter_meta));); \ src, filter, rin, rout, dst, filter_meta));); \
return; \ return; \
} \ } \
} while (0); } while (0);
#define DISPATCH(in_dt, out_dt, in_ct, out_ct, comp_ct) \ #define DISPATCH(in_dt, r_dt, out_dt, in_ct, r_ct, out_ct, comp_ct) \
DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, ComputeMode::DEFAULT) DISPATCH_CMODE( \
in_dt, r_dt, out_dt, in_ct, r_ct, out_ct, comp_ct, ComputeMode::DEFAULT)
#define cb(dt) \ #define cb(dt) \
DISPATCH( \ DISPATCH( \
dt, dt, DTypeTrait<dt>::ctype, DTypeTrait<dt>::ctype, \ dt, Int32, dt, DTypeTrait<dt>::ctype, dt_int32, DTypeTrait<dt>::ctype, \
DTypeTrait<dt>::ctype) \
DISPATCH( \
dt, Uint8, dt, DTypeTrait<dt>::ctype, dt_uint8, DTypeTrait<dt>::ctype, \
DTypeTrait<dt>::ctype) DTypeTrait<dt>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb #undef cb
DNN_INC_FLOAT16(DISPATCH_CMODE( DNN_INC_FLOAT16(DISPATCH_CMODE(
Float16, Float16, dt_float16, dt_float16, dt_float32, Float16, Int32, Float16, dt_float16, dt_int32, dt_float16, dt_float32,
ComputeMode::FLOAT32));
DNN_INC_FLOAT16(DISPATCH_CMODE(
Float16, Uint8, Float16, dt_float16, dt_uint8, dt_float16, dt_float32,
ComputeMode::FLOAT32)); ComputeMode::FLOAT32));
#undef DISPATCH #undef DISPATCH
megdnn_throw(ssprintf( megdnn_throw(ssprintf(
...@@ -89,26 +98,51 @@ void RegionRestrictedConvolutionBackwardDataImpl::exec( ...@@ -89,26 +98,51 @@ void RegionRestrictedConvolutionBackwardDataImpl::exec(
auto cmode = param().compute_mode; auto cmode = param().compute_mode;
#define cb(dt) \ #define cb(dt) \
do { \ do { \
if (filter.layout.dtype == dt() && cmode == ComputeMode::DEFAULT) { \ if (filter.layout.dtype == dt() && cmode == ComputeMode::DEFAULT && \
rin.layout.dtype == dtype::Int32() && \
rout.layout.dtype == dtype::Int32()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
(convolution::region_restricted_backward_data< \
ctype, ctype, dt_int32, ctype>( \
filter, diff, rin, rout, grad, filter_meta))); \
return; \
} else if ( \
filter.layout.dtype == dt() && cmode == ComputeMode::DEFAULT && \
rin.layout.dtype == dtype::Uint8() && \
rout.layout.dtype == dtype::Uint8()) { \
using ctype = DTypeTrait<dt>::ctype; \ using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
(convolution::region_restricted_backward_data< \ (convolution::region_restricted_backward_data< \
ctype, ctype, ctype>( \ ctype, ctype, dt_uint8, ctype>( \
filter, diff, rin, rout, grad, filter_meta));); \ filter, diff, rin, rout, grad, filter_meta))); \
return; \ return; \
} \ } \
} while (0); } while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb #undef cb
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
if (filter.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32) { if (filter.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32 &&
rin.layout.dtype == dtype::Int32() && rout.layout.dtype == dtype::Int32()) {
TensorND grad_fp32{
workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
auto&& type_cvt = handle()->create_operator<TypeCvt>();
type_cvt->exec(grad, grad_fp32);
MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_backward_data<
dt_float16, dt_float16, dt_int32, dt_float32>(
filter, diff, rin, rout, grad_fp32, filter_meta)));
type_cvt->exec(grad_fp32, grad);
return;
} else if (
filter.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32 &&
rin.layout.dtype == dtype::Uint8() && rout.layout.dtype == dtype::Uint8()) {
TensorND grad_fp32{ TensorND grad_fp32{
workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}}; workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
auto&& type_cvt = handle()->create_operator<TypeCvt>(); auto&& type_cvt = handle()->create_operator<TypeCvt>();
type_cvt->exec(grad, grad_fp32); type_cvt->exec(grad, grad_fp32);
MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_backward_data< MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_backward_data<
dt_float16, dt_float16, dt_float32>( dt_float16, dt_float16, dt_uint8, dt_float32>(
filter, diff, rin, rout, grad_fp32, filter_meta));); filter, diff, rin, rout, grad_fp32, filter_meta)));
type_cvt->exec(grad_fp32, grad); type_cvt->exec(grad_fp32, grad);
return; return;
} }
...@@ -148,12 +182,27 @@ void RegionRestrictedConvolutionBackwardFilterImpl::exec( ...@@ -148,12 +182,27 @@ void RegionRestrictedConvolutionBackwardFilterImpl::exec(
auto cmode = param().compute_mode; auto cmode = param().compute_mode;
#define cb(dt) \ #define cb(dt) \
do { \ do { \
if (src.layout.dtype == dt() && cmode == ComputeMode::DEFAULT) { \ if (src.layout.dtype == dt() && cmode == ComputeMode::DEFAULT && \
rin.layout.dtype == dtype::Int32() && \
rout.layout.dtype == dtype::Int32()) { \
using ctype = DTypeTrait<dt>::ctype; \ using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \ MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \ static_cast<HandleImpl*>(handle()), \
convolution::region_restricted_backward_filter< \ convolution::region_restricted_backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ ctype MEGDNN_COMMA ctype MEGDNN_COMMA dt_int32 \
MEGDNN_COMMA ctype>( \
src, diff, rin, rout, grad, filter_meta);); \
return; \
} else if ( \
src.layout.dtype == dt() && cmode == ComputeMode::DEFAULT && \
rin.layout.dtype == dtype::Uint8() && \
rout.layout.dtype == dtype::Uint8()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution::region_restricted_backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA dt_uint8 \
MEGDNN_COMMA ctype>( \
src, diff, rin, rout, grad, filter_meta);); \ src, diff, rin, rout, grad, filter_meta);); \
return; \ return; \
} \ } \
...@@ -161,13 +210,26 @@ void RegionRestrictedConvolutionBackwardFilterImpl::exec( ...@@ -161,13 +210,26 @@ void RegionRestrictedConvolutionBackwardFilterImpl::exec(
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb #undef cb
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
if (src.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32) { if (src.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32 &&
rin.layout.dtype == dtype::Int32() && rout.layout.dtype == dtype::Int32()) {
TensorND grad_fp32{
workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
auto&& type_cvt = handle()->create_operator<TypeCvt>();
type_cvt->exec(grad, grad_fp32);
MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_backward_filter<
dt_float16, dt_float16, dt_int32, dt_float32>(
src, diff, rin, rout, grad_fp32, filter_meta)););
type_cvt->exec(grad_fp32, grad);
return;
} else if (
src.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32 &&
rin.layout.dtype == dtype::Uint8() && rout.layout.dtype == dtype::Uint8()) {
TensorND grad_fp32{ TensorND grad_fp32{
workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}}; workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
auto&& type_cvt = handle()->create_operator<TypeCvt>(); auto&& type_cvt = handle()->create_operator<TypeCvt>();
type_cvt->exec(grad, grad_fp32); type_cvt->exec(grad, grad_fp32);
MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_backward_filter< MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_backward_filter<
dt_float16, dt_float16, dt_float32>( dt_float16, dt_float16, dt_uint8, dt_float32>(
src, diff, rin, rout, grad_fp32, filter_meta));); src, diff, rin, rout, grad_fp32, filter_meta)););
type_cvt->exec(grad_fp32, grad); type_cvt->exec(grad_fp32, grad);
return; return;
......
...@@ -717,11 +717,11 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) { ...@@ -717,11 +717,11 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
ConvBiasForward::algo_name<ConvBias::DirectParam>( ConvBiasForward::algo_name<ConvBias::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {}) "DEPTHWISE_LARGE_FILTER", {})
.c_str())); .c_str()));
for (auto dtype : std::vector<DType> { for (auto dtype : std::vector<DType>{
dtype::Float32(), dtype::Float32(),
#if CUDA_VERSION >= 9000 // #if CUDA_VERSION >= 9000
dtype::Float16() // dtype::Float16()
#endif // #endif
}) { }) {
auto run = [&checker, &dtype]( auto run = [&checker, &dtype](
size_t n, size_t g, size_t h, size_t fh, size_t padding, size_t n, size_t g, size_t h, size_t fh, size_t padding,
...@@ -750,36 +750,36 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) { ...@@ -750,36 +750,36 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
checker.set_param(cur_param).execs( checker.set_param(cur_param).execs(
{{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}}); {{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
}; };
run(4, 8, 32, 5, 5 / 2, 1); // run(4, 8, 32, 5, 5 / 2, 1);
run(4, 8, 32, 7, 7 / 2, 1); // run(4, 8, 32, 7, 7 / 2, 1);
run(4, 8, 32, 9, 9 / 2, 1); // run(4, 8, 32, 9, 9 / 2, 1);
run(4, 8, 32, 11, 11 / 2, 1); // run(4, 8, 32, 11, 11 / 2, 1);
run(4, 8, 32, 13, 13 / 2, 1); // run(4, 8, 32, 13, 13 / 2, 1);
run(4, 8, 32, 15, 15 / 2, 1); // run(4, 8, 32, 15, 15 / 2, 1);
run(4, 8, 32, 17, 17 / 2, 1); // run(4, 8, 32, 17, 17 / 2, 1);
run(4, 8, 32, 19, 19 / 2, 1); // run(4, 8, 32, 19, 19 / 2, 1);
run(4, 8, 32, 21, 21 / 2, 1); // run(4, 8, 32, 21, 21 / 2, 1);
run(4, 8, 32, 23, 23 / 2, 1); // run(4, 8, 32, 23, 23 / 2, 1);
run(4, 8, 32, 25, 25 / 2, 1); // run(4, 8, 32, 25, 25 / 2, 1);
run(4, 8, 32, 27, 27 / 2, 1); // run(4, 8, 32, 27, 27 / 2, 1);
run(4, 8, 32, 29, 29 / 2, 1); // run(4, 8, 32, 29, 29 / 2, 1);
run(4, 8, 32, 31, 31 / 2, 1); run(64, 384, 32, 31, 31 / 2, 1);
run(4, 8, 64, 5, 5 / 3, 2); // run(4, 8, 64, 5, 5 / 3, 2);
run(4, 8, 64, 7, 7 / 3, 2); // run(4, 8, 64, 7, 7 / 3, 2);
run(4, 8, 64, 9, 9 / 3, 2); // run(4, 8, 64, 9, 9 / 3, 2);
run(4, 8, 64, 11, 11 / 3, 2); // run(4, 8, 64, 11, 11 / 3, 2);
run(4, 8, 64, 13, 13 / 3, 2); // run(4, 8, 64, 13, 13 / 3, 2);
run(4, 8, 64, 15, 15 / 3, 2); // run(4, 8, 64, 15, 15 / 3, 2);
run(4, 8, 64, 17, 17 / 3, 2); // run(4, 8, 64, 17, 17 / 3, 2);
run(4, 8, 64, 19, 19 / 3, 2); // run(4, 8, 64, 19, 19 / 3, 2);
run(4, 8, 64, 21, 21 / 3, 2); // run(4, 8, 64, 21, 21 / 3, 2);
run(4, 8, 64, 23, 23 / 3, 2); // run(4, 8, 64, 23, 23 / 3, 2);
run(4, 8, 64, 25, 25 / 3, 2); // run(4, 8, 64, 25, 25 / 3, 2);
run(4, 8, 64, 27, 27 / 3, 2); // run(4, 8, 64, 27, 27 / 3, 2);
run(4, 8, 64, 29, 29 / 3, 2); // run(4, 8, 64, 29, 29 / 3, 2);
run(4, 8, 64, 31, 31 / 3, 2); // run(4, 8, 64, 31, 31 / 3, 2);
run(1, 2, 128, 31, 10, 2); // run(1, 2, 128, 31, 10, 2);
run(1, 2, 256, 31, 10, 2); // run(1, 2, 256, 31, 10, 2);
} }
} }
...@@ -1638,10 +1638,10 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP32) { ...@@ -1638,10 +1638,10 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP32) {
ConvBias::Param param; ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW; param.format = ConvBias::Param::Format::NCHW;
using NonlineMode = ConvBias::Param::NonlineMode; using NonlineMode = ConvBias::Param::NonlineMode;
param.nonlineMode = NonlineMode::IDENTITY; param.nonlineMode = NonlineMode::IDENTITY;
param.sparse = ConvBias::Param::Sparse::GROUP; param.sparse = ConvBias::Param::Sparse::GROUP;
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh, auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) { size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2; param.pad_h = fh / 2;
......
#include "megdnn/dtype.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/common/conv_bias.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"
#include <cudnn.h>
#define V1(x) #x
#define V(x) V1(x)
#define CUDNN_VERSION_STRING \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL)
namespace megdnn {
namespace test {
TEST_F(CUDA, REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER) {
Checker<RegionRestrictedConvolutionForward> checker(handle_cuda());
auto opr = handle_cuda()->create_operator<ConvolutionForward>();
for (auto dt : std::vector<DType>{dtype::Int32(), dtype::Uint8()}) {
auto run = [&checker, &dt, &opr](
size_t n, size_t g, size_t h, size_t fh, size_t padding,
size_t stride) {
RegionRestrictedConvolution::Param cur_param;
cur_param.mode =
RegionRestrictedConvolution::Param::Mode::CROSS_CORRELATION;
cur_param.sparse = RegionRestrictedConvolution::Param::Sparse::GROUP;
checker.set_dtype(2, dt).set_dtype(3, dt);
float scale = 64.f / sqrt(fh * fh);
UniformFloatRNG rng(scale, 2 * scale);
UniformIntRNG r_rng{0, 2};
checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng(
3, &r_rng);
if (dt.enumv() == DTypeEnum::Float16) {
checker.set_epsilon(1e-1);
}
cur_param.pad_h = cur_param.pad_w = padding;
cur_param.stride_h = cur_param.stride_w = stride;
size_t ho = infer_conv_shape(h, fh, stride, padding);
checker.set_param(cur_param).execs(
{{n, g, h, h}, {g, 1, 1, fh, fh}, {n, h, h}, {n, ho, ho}, {}});
};
run(4, 8, 32, 3, 3 / 2, 1);
run(4, 8, 32, 5, 5 / 2, 1);
run(4, 8, 32, 7, 7 / 2, 1);
run(1, 2, 32, 9, 9 / 2, 1);
run(4, 8, 32, 11, 11 / 2, 1);
run(4, 8, 32, 13, 13 / 2, 1);
run(4, 8, 32, 15, 15 / 2, 1);
run(4, 8, 32, 17, 17 / 2, 1);
run(4, 8, 32, 19, 19 / 2, 1);
run(4, 8, 32, 21, 21 / 2, 1);
run(4, 8, 32, 23, 23 / 2, 1);
run(4, 8, 32, 25, 25 / 2, 1);
run(4, 8, 32, 27, 27 / 2, 1);
run(4, 8, 32, 29, 29 / 2, 1);
run(4, 8, 32, 31, 31 / 2, 1);
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) {
require_compute_capability(7, 5);
Benchmarker<ConvBiasForward> bencher(handle_cuda());
bencher.set_display(false);
bencher.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {})
.c_str()));
Benchmarker<RegionRestrictedConvolutionForward> rr_bencher(handle_cuda());
rr_bencher.set_display(false);
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW;
using NonlineMode = ConvBias::Param::NonlineMode;
param.nonlineMode = NonlineMode::IDENTITY;
param.sparse = ConvBias::Param::Sparse::GROUP;
RegionRestrictedConvolutionForward::Param rr_param;
rr_param.format = RegionRestrictedConvolutionForward::Param::Format::NCHW;
rr_param.sparse = RegionRestrictedConvolutionForward::Param::Sparse::GROUP;
UniformIntRNG r_rng{0, 2};
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2;
param.pad_w = fw / 2;
param.stride_h = sh;
param.stride_w = sw;
rr_param.pad_h = fh / 2;
rr_param.pad_w = fw / 2;
rr_param.stride_h = sh;
rr_param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
bencher.set_times(nr_times);
rr_bencher.set_param(rr_param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Int32())
.set_dtype(3, dtype::Int32());
rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng).set_rng(0, &r_rng);
rr_bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
TensorShape inp{batch, g, hi, wi}, kern{g, 1, 1, fh, fw}, rin{batch, hi, wi},
rout{batch, ho, wo}, out{batch, g, ho, wo};
float bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
float rr_bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
rin.total_nr_elems() + rout.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
auto time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
auto rr_time_in_ms = rr_bencher.execs({inp, kern, rin, rout, out}) / nr_times;
auto rr_ops =
2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
printf("RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: inp=%s, "
"kern=%s, out=%s\n"
"time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n"
"bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n",
inp.to_string().c_str(), kern.to_string().c_str(),
out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms,
time_in_ms / rr_time_in_ms);
};
run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
}
TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) {
require_compute_capability(7, 5);
Benchmarker<ConvBiasForward> bencher(handle_cuda());
bencher.set_display(false);
bencher.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {})
.c_str()));
Benchmarker<RegionRestrictedConvolutionForward> rr_bencher(handle_cuda());
rr_bencher.set_display(false);
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW;
using NonlineMode = ConvBias::Param::NonlineMode;
param.nonlineMode = NonlineMode::IDENTITY;
param.sparse = ConvBias::Param::Sparse::GROUP;
RegionRestrictedConvolutionForward::Param rr_param;
rr_param.format = RegionRestrictedConvolutionForward::Param::Format::NCHW;
rr_param.sparse = RegionRestrictedConvolutionForward::Param::Sparse::GROUP;
UniformIntRNG r_rng{0, 2};
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2;
param.pad_w = fw / 2;
param.stride_h = sh;
param.stride_w = sw;
rr_param.pad_h = fh / 2;
rr_param.pad_w = fw / 2;
rr_param.stride_h = sh;
rr_param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
bencher.set_times(nr_times);
rr_bencher.set_param(rr_param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Uint8())
.set_dtype(3, dtype::Uint8());
rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng).set_rng(0, &r_rng);
rr_bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
TensorShape inp{batch, g, hi, wi}, kern{g, 1, 1, fh, fw}, rin{batch, hi, wi},
rout{batch, ho, wo}, out{batch, g, ho, wo};
float bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
float rr_bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
rin.total_nr_elems() + rout.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
auto time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
auto rr_time_in_ms = rr_bencher.execs({inp, kern, rin, rout, out}) / nr_times;
auto rr_ops =
2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
printf("RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: inp=%s, "
"kern=%s, out=%s\n"
"time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n"
"bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n",
inp.to_string().c_str(), kern.to_string().c_str(),
out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms,
time_in_ms / rr_time_in_ms);
};
run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
}
#endif
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -11,8 +11,8 @@ using namespace megdnn; ...@@ -11,8 +11,8 @@ using namespace megdnn;
using namespace test; using namespace test;
namespace { namespace {
template <typename rtype>
void mask_tensor( void mask_tensor_kernel(
const TensorND& in, TensorND& out, const TensorND& mask, const TensorND& in, TensorND& out, const TensorND& mask,
const int32_t mask_val) { const int32_t mask_val) {
megdnn_assert( megdnn_assert(
...@@ -23,7 +23,7 @@ void mask_tensor( ...@@ -23,7 +23,7 @@ void mask_tensor(
mask.layout[0] == in.layout[0] && mask.layout[1] == in.layout[2] && mask.layout[0] == in.layout[0] && mask.layout[1] == in.layout[2] &&
mask.layout[2] == in.layout[3]); mask.layout[2] == in.layout[3]);
int32_t* mask_ptr = mask.ptr<int32_t>(); rtype* mask_ptr = mask.compatible_ptr<rtype>();
float* src_ptr = in.compatible_ptr<float>(); float* src_ptr = in.compatible_ptr<float>();
float* dst_ptr = out.compatible_ptr<float>(); float* dst_ptr = out.compatible_ptr<float>();
...@@ -47,6 +47,16 @@ void mask_tensor( ...@@ -47,6 +47,16 @@ void mask_tensor(
} }
} }
} }
void mask_tensor(
const TensorND& in, TensorND& out, const TensorND& mask,
const int32_t mask_val) {
if (mask.layout.dtype == dtype::Int32()) {
mask_tensor_kernel<dt_int32>(in, out, mask, mask_val);
} else if (mask.layout.dtype == dtype::Uint8()) {
mask_tensor_kernel<dt_uint8>(in, out, mask, mask_val);
}
}
} // namespace } // namespace
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) { TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
...@@ -54,7 +64,7 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) { ...@@ -54,7 +64,7 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
RegionRestrictedConvolution::Param param; RegionRestrictedConvolution::Param param;
constexpr int N = 3; constexpr int N = 3;
UniformIntRNG rng{0, N-1}; UniformIntRNG rng{0, N - 1};
auto extra_impl = [&, this](const TensorNDArray& tensors) { auto extra_impl = [&, this](const TensorNDArray& tensors) {
auto conv = handle()->create_operator<Convolution>(); auto conv = handle()->create_operator<Convolution>();
...@@ -64,16 +74,17 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) { ...@@ -64,16 +74,17 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
dt_byte* workspace_ptr = static_cast<dt_byte*>(malloc(workspace_size)); dt_byte* workspace_ptr = static_cast<dt_byte*>(malloc(workspace_size));
Workspace workspace{workspace_ptr, workspace_size}; Workspace workspace{workspace_ptr, workspace_size};
TensorND masked_src(malloc(tensors[0].layout.span().dist_byte()), tensors[0].layout); TensorND masked_src(
malloc(tensors[0].layout.span().dist_byte()), tensors[0].layout);
TensorNDArray dst_tensors; TensorNDArray dst_tensors;
for(int i=0; i<N; ++i) { for (int i = 0; i < N; ++i) {
dst_tensors.emplace_back(malloc(tensors[4].layout.span().dist_byte()), tensors[4].layout); dst_tensors.emplace_back(
malloc(tensors[4].layout.span().dist_byte()), tensors[4].layout);
} }
for(int i=0; i<N; ++i) { for (int i = 0; i < N; ++i) {
mask_tensor(tensors[0], masked_src, tensors[2], i); mask_tensor(tensors[0], masked_src, tensors[2], i);
conv->exec(masked_src, tensors[1], dst_tensors[i], nullptr, workspace); conv->exec(masked_src, tensors[1], dst_tensors[i], nullptr, workspace);
mask_tensor(dst_tensors[i], dst_tensors[i], tensors[3], i); mask_tensor(dst_tensors[i], dst_tensors[i], tensors[3], i);
} }
free(workspace_ptr); free(workspace_ptr);
...@@ -81,7 +92,7 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) { ...@@ -81,7 +92,7 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
auto add = handle()->create_operator<ElemwiseForward>(); auto add = handle()->create_operator<ElemwiseForward>();
add->param().mode = Mode::ADD; add->param().mode = Mode::ADD;
add->exec({dst_tensors[0], dst_tensors[1]}, tensors[4]); add->exec({dst_tensors[0], dst_tensors[1]}, tensors[4]);
for (int i=2; i<N; ++i) { for (int i = 2; i < N; ++i) {
add->exec({dst_tensors[i], tensors[4]}, tensors[4]); add->exec({dst_tensors[i], tensors[4]}, tensors[4]);
} }
}; };
...@@ -92,6 +103,12 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) { ...@@ -92,6 +103,12 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
.set_dtype(2, dtype::Int32()) .set_dtype(2, dtype::Int32())
.set_dtype(3, dtype::Int32()); .set_dtype(3, dtype::Int32());
checker.execs({{1, 8, 2, 2}, {4, 8, 1, 1}, {1, 2, 2}, {1, 2, 2}, {}})
.execs({{20, 12, 30, 30}, {4, 12, 1, 1}, {20, 30, 30}, {20, 30, 30}, {}})
.execs({{20, 8, 30, 30}, {4, 8, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}});
checker.set_dtype(2, dtype::Uint8()).set_dtype(3, dtype::Uint8());
checker.execs({{1, 8, 2, 2}, {4, 8, 1, 1}, {1, 2, 2}, {1, 2, 2}, {}}) checker.execs({{1, 8, 2, 2}, {4, 8, 1, 1}, {1, 2, 2}, {1, 2, 2}, {}})
.execs({{20, 12, 30, 30}, {4, 12, 1, 1}, {20, 30, 30}, {20, 30, 30}, {}}) .execs({{20, 12, 30, 30}, {4, 12, 1, 1}, {20, 30, 30}, {20, 30, 30}, {}})
.execs({{20, 8, 30, 30}, {4, 8, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}}); .execs({{20, 8, 30, 30}, {4, 8, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}});
...@@ -99,100 +116,19 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) { ...@@ -99,100 +116,19 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
param.sparse = Convolution::Param::Sparse::GROUP; param.sparse = Convolution::Param::Sparse::GROUP;
checker.set_param(param) checker.set_param(param)
.execs({{20, 15, 30, 30}, {5, 4, 3, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}}) .execs({{20, 15, 30, 30}, {5, 4, 3, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}})
.execs({{20, 25, 30, 30}, {25, 1, 1, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}}); .execs({{20, 25, 30, 30},
} {25, 1, 1, 3, 3},
{20, 30, 30},
#if 0 {20, 28, 28},
{}});
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BACKWARD_DATA) {
Checker<RegionRestrictedConvolutionBackwardData> checker(handle()); checker.set_dtype(2, dtype::Int32()).set_dtype(3, dtype::Int32());
using Param = RegionRestrictedConvolutionBackwardData::Param; checker.execs({{20, 15, 30, 30}, {5, 4, 3, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}})
Param param; .execs({{20, 25, 30, 30},
{25, 1, 1, 3, 3},
auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh, {20, 30, 30},
size_t fw, size_t stride, size_t padding, size_t dilate = 1, {20, 28, 28},
size_t group = 1) { {}});
param.pad_h = param.pad_w = padding;
param.stride_h = param.stride_w = stride;
param.dilate_h = param.dilate_w = dilate;
TensorLayout diff = TensorLayout{{n, oc * group, oh, ow}, dtype::Float32()};
TensorLayout grad;
TensorLayout filter;
if (group == 1) {
param.sparse = Param::Sparse::DENSE;
filter = {{oc, ic, fh, fw}, dtype::Float32()};
} else {
param.sparse = Param::Sparse::GROUP;
filter = {{group, oc, ic, fh, fw}, dtype::Float32()};
}
// TensorLayout grad;
{
auto opr = handle()->create_operator<ConvolutionBackwardData>();
opr->param() = param;
opr->deduce_layout(filter, diff, grad);
}
checker.set_param(param);
checker.exec(TensorLayoutArray{filter, diff, grad});
};
for (auto mode : {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) {
param.mode = mode;
run(4, 3, 10, 13, 5, 1, 1, 1, 0, 1, 1);
run(5, 5, 24, 43, 11, 9, 3, 3, 12, 1, 2);
run(4, 3, 10, 45, 2, 1, 1, 1, 0, 4, 3);
run(2, 3, 9, 12, 2, 4, 6, 1, 0, 1, 2);
run(3, 4, 17, 32, 2, 3, 2, 5, 4, 4, 3);
run(5, 5, 24, 43, 11, 9, 3, 3, 12, 2, 2);
run(2, 3, 20, 33, 3, 5, 7, 4, 15, 2, 3);
run(4, 4, 6, 7, 9, 3, 2, 2, 1, 3, 2);
}
}
TEST_F(NAIVE, CONVOLUTION_BACKWARD_DATA) {
Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
using Param = RegionRestrictedConvolutionBackwardData::Param;
Param param;
auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh,
size_t fw, size_t stride, size_t padding, size_t dilate = 1,
size_t group = 1) {
param.pad_h = param.pad_w = padding;
param.stride_h = param.stride_w = stride;
param.dilate_h = param.dilate_w = dilate;
TensorLayout diff = TensorLayout{{n, oc * group, oh, ow}, dtype::Float32()};
TensorLayout grad;
TensorLayout filter;
if (group == 1) {
param.sparse = Param::Sparse::DENSE;
filter = {{oc, ic, fh, fw}, dtype::Float32()};
} else {
param.sparse = Param::Sparse::GROUP;
filter = {{group, oc, ic, fh, fw}, dtype::Float32()};
}
// TensorLayout grad;
{
auto opr = handle()->create_operator<ConvolutionBackwardData>();
opr->param() = param;
opr->deduce_layout(filter, diff, grad);
}
checker.set_param(param);
checker.exec(TensorLayoutArray{filter, diff, grad});
};
for (auto mode : {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) {
param.mode = mode;
run(4, 3, 10, 13, 5, 1, 1, 1, 0, 1, 1);
run(5, 5, 24, 43, 11, 9, 3, 3, 12, 1, 2);
run(4, 3, 10, 45, 2, 1, 1, 1, 0, 4, 3);
run(2, 3, 9, 12, 2, 4, 6, 1, 0, 1, 2);
run(3, 4, 17, 32, 2, 3, 2, 5, 4, 4, 3);
run(5, 5, 24, 43, 11, 9, 3, 3, 12, 2, 2);
run(2, 3, 20, 33, 3, 5, 7, 4, 15, 2, 3);
run(4, 4, 6, 7, 9, 3, 2, 2, 1, 3, 2);
}
} }
#endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册