提交 977c2071 编写于 作者: M Megvii Engine Team

feat(dnn): add RegionRestrictedConv DGRAD support int32 and uint8

GitOrigin-RevId: 814b8a83f8ac5c3f395d6760be8bd3f30fcedfbb
上级 543c9b77
...@@ -38,7 +38,7 @@ void RegionRestrictedConvolutionForward::deduce_dtype( ...@@ -38,7 +38,7 @@ void RegionRestrictedConvolutionForward::deduce_dtype(
"only float type is supported for region_restricted_conv forward"); "only float type is supported for region_restricted_conv forward");
megdnn_assert( megdnn_assert(
rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()), rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()),
"the dtype of rin/rout should be Int32, got %s.", rin.name()); "the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name());
} }
void RegionRestrictedConvolutionForward::deduce_layout( void RegionRestrictedConvolutionForward::deduce_layout(
...@@ -91,12 +91,12 @@ RegionRestrictedConvolutionBackwardData::check_exec( ...@@ -91,12 +91,12 @@ RegionRestrictedConvolutionBackwardData::check_exec(
auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd); auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
#define err_msg(lhs, rhs) \ #define err_msg(lhs, rhs) \
megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs); megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs);
err_msg(rin.shape[0], grad_fwd.shape[0]); err_msg(rin.shape[0], grad_fwd.shape[0]); // batch
err_msg(rin.shape[1], grad_fwd.shape[2]); err_msg(rin.shape[1], grad_fwd.shape[2]); // ih
err_msg(rin.shape[2], grad_fwd.shape[3]); err_msg(rin.shape[2], grad_fwd.shape[3]); // iw
err_msg(rout.shape[0], diff_fwd.shape[0]); err_msg(rout.shape[0], diff_fwd.shape[0]); // batch
err_msg(rout.shape[1], diff_fwd.shape[2]); err_msg(rout.shape[1], diff_fwd.shape[2]); // oh
err_msg(rout.shape[2], diff_fwd.shape[3]); err_msg(rout.shape[2], diff_fwd.shape[3]); // ow
#undef err_msg #undef err_msg
auto required_workspace_in_bytes = auto required_workspace_in_bytes =
get_workspace_in_bytes(filter, diff, rin, rout, grad); get_workspace_in_bytes(filter, diff, rin, rout, grad);
...@@ -106,45 +106,22 @@ RegionRestrictedConvolutionBackwardData::check_exec( ...@@ -106,45 +106,22 @@ RegionRestrictedConvolutionBackwardData::check_exec(
void RegionRestrictedConvolutionBackwardData::deduce_dtype( void RegionRestrictedConvolutionBackwardData::deduce_dtype(
DType filter, DType diff, DType rin, DType rout, DType& grad) { DType filter, DType diff, DType rin, DType rout, DType& grad) {
SmallVector<DType> supported_dst_dtype; // FIXME: infering dtype of grad via naive impl only support fp32
if (filter.category() == diff.category() && // (lack of quantized dtype infering or others) may not suitable in the furture
filter.category() == DTypeCategory::FLOAT) {
supported_dst_dtype.push_back(filter);
} else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) {
supported_dst_dtype.push_back(dtype::Int32());
} else if (
(filter.enumv() == DTypeEnum::QuantizedS8 &&
diff.enumv() == DTypeEnum::QuantizedS8) ||
(filter.enumv() == DTypeEnum::Quantized8Asymm &&
diff.enumv() == DTypeEnum::Quantized8Asymm)) {
supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
if (grad.valid() && grad.enumv() == diff.enumv()) {
supported_dst_dtype.push_back(grad);
}
} else {
megdnn_throw(ssprintf(
"unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
}
if (!grad.valid()) {
grad = supported_dst_dtype.at(0);
} else {
megdnn_assert(
vec_contains(supported_dst_dtype, grad),
"unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(),
grad.name());
}
megdnn_assert(
param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
|| filter.enumv() == DTypeEnum::Float16 || if (diff.enumv() == DTypeEnum::Float32 || diff.enumv() == DTypeEnum::Float16) {
filter.enumv() == DTypeEnum::BFloat16 grad = diff;
}
#endif #endif
, megdnn_assert(grad.valid(), "dtype of grad requires deducing of assigned");
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
"input / output.");
megdnn_assert( megdnn_assert(
rin == rout && rin == dtype::Int32(), diff.category() == DTypeCategory::FLOAT &&
"the dtype of rin/rout should be Int32, got %s.", rin.name()); filter.category() == DTypeCategory::FLOAT &&
grad.category() == DTypeCategory::FLOAT,
"only float type is supported for region_restricted_conv backward data");
megdnn_assert(
rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()),
"the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name());
} }
void RegionRestrictedConvolutionBackwardData::deduce_layout( void RegionRestrictedConvolutionBackwardData::deduce_layout(
......
#include "./kern.cuh"
#include "cuda.h" #include "cuda.h"
#include "cuda_fp16.h" #include "cuda_fp16.h"
#include "src/cuda/fp16_help.cuh" #include "src/cuda/fp16_help.cuh"
#include "src/cuda/region_restricted_convolution/chanwise/kern.cuh"
using namespace megdnn; using namespace megdnn;
using namespace cuda; using namespace cuda;
...@@ -15,7 +15,7 @@ namespace cuda { ...@@ -15,7 +15,7 @@ namespace cuda {
namespace region_restricted_convolution { namespace region_restricted_convolution {
namespace chanwise { namespace chanwise {
// =====================================fwd===================================== // =====================================bwd=====================================
template <> template <>
void run_bwd_depthwise_large_filter( void run_bwd_depthwise_large_filter(
......
...@@ -498,16 +498,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( ...@@ -498,16 +498,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
SrcGlobal2ShareVisitor gl2sh_src = { SrcGlobal2ShareVisitor gl2sh_src = {
smem_src, smem_src,
static_cast<int>(param.src_w), static_cast<int>(param.src_w),
static_cast<int>( static_cast<int>(src_start_h),
is_fwd ? src_start_h static_cast<int>(src_start_w),
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h), is_fwd ? 1 : static_cast<int>(param.stride_h),
...@@ -516,16 +508,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( ...@@ -516,16 +508,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
RinGlobal2ShareVisitor gl2sh_rin = { RinGlobal2ShareVisitor gl2sh_rin = {
smem_rin, smem_rin,
static_cast<int>(param.src_w), static_cast<int>(param.src_w),
static_cast<int>( static_cast<int>(src_start_h),
is_fwd ? src_start_h static_cast<int>(src_start_w),
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h), is_fwd ? 1 : static_cast<int>(param.stride_h),
...@@ -790,14 +774,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( ...@@ -790,14 +774,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h; out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w; T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
static_assert((FilterTileConfig::unroll_w & 3) == 0); static_assert(
(FilterTileConfig::unroll_w & 3) == 0, "filter tile unroll_w & 3 != 0");
int* smem_rin_ptr = smem_rin + (off_ow * FilterTileConfig::unroll_w >> 2); int* smem_rin_ptr = smem_rin + (off_ow * FilterTileConfig::unroll_w >> 2);
T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w; T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w; T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
const uint8_t* rout_base_ptr = rout + batch * param.out_h * param.out_w; const uint8_t* rout_base_ptr = rout + batch * param.out_h * param.out_w;
static_assert((OutTileConfig::unroll_w & 3) == 0); static_assert((OutTileConfig::unroll_w & 3) == 0, "output tile unroll_w & 3 != 0");
static_assert((OutTileConfig::block_w & 3) == 0); static_assert((OutTileConfig::block_w & 3) == 0, "output block_w & 3 != 0");
int reg_rout[OutTileConfig::unroll_size] = {0}; int reg_rout[OutTileConfig::unroll_size] = {0};
#pragma unroll #pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) { for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
...@@ -821,16 +806,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( ...@@ -821,16 +806,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
SrcGlobal2ShareVisitor gl2sh_src = { SrcGlobal2ShareVisitor gl2sh_src = {
smem_src, smem_src,
static_cast<int>(param.src_w), static_cast<int>(param.src_w),
static_cast<int>( static_cast<int>(src_start_h),
is_fwd ? src_start_h static_cast<int>(src_start_w),
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h), is_fwd ? 1 : static_cast<int>(param.stride_h),
...@@ -839,16 +816,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( ...@@ -839,16 +816,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
RinGlobal2ShareVisitor gl2sh_rin = { RinGlobal2ShareVisitor gl2sh_rin = {
smem_rin, smem_rin,
static_cast<int>(param.src_w), static_cast<int>(param.src_w),
static_cast<int>( static_cast<int>(src_start_h),
is_fwd ? src_start_h static_cast<int>(src_start_w),
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h), is_fwd ? 1 : static_cast<int>(param.stride_h),
...@@ -1134,14 +1103,20 @@ void LaunchDepthwiseConv2dGPU( ...@@ -1134,14 +1103,20 @@ void LaunchDepthwiseConv2dGPU(
RinTileCount::smem_size * sizeof(int); RinTileCount::smem_size * sizeof(int);
void (*kernel)(const Param, const T*, const T*, const RT*, const RT*, T*); void (*kernel)(const Param, const T*, const T*, const RT*, const RT*, T*);
const bool is_fwd = (kDirection == DIRECTION_FORWARD);
if (param.is_compute_deafult) { if (param.is_compute_deafult) {
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection, stride>; kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection, stride>;
} else { } else {
megdnn_assert_internal(0); megdnn_assert_internal(0);
} }
kernel<<<grid, block, shared_storage, stream>>>( if (is_fwd) {
param, input, filter, rin, rout, output); kernel<<<grid, block, shared_storage, stream>>>(
param, input, filter, rin, rout, output);
} else {
kernel<<<grid, block, shared_storage, stream>>>(
param, input, filter, rout, rin, output);
}
after_kernel_launch(); after_kernel_launch();
} }
......
...@@ -55,25 +55,65 @@ size_t RegionRestrictedConvolutionBackwardDataImpl::get_workspace_in_bytes( ...@@ -55,25 +55,65 @@ size_t RegionRestrictedConvolutionBackwardDataImpl::get_workspace_in_bytes(
void RegionRestrictedConvolutionBackwardDataImpl::exec( void RegionRestrictedConvolutionBackwardDataImpl::exec(
_megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin, _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) { _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
megdnn_throw(ssprintf( auto fm = check_exec(
"unsupported RegionRestrictedConvolutionBackwardData(%s, %s, %s, %s) -> %s", filter.layout, diff.layout, rin.layout, rout.layout, grad.layout,
filter.layout.dtype.name(), diff.layout.dtype.name(), workspace.size);
rin.layout.dtype.name(), rout.layout.dtype.name(), // XXX: a naive impl to set deconv padding to param, needs optimization in future.
grad.layout.dtype.name())); [&]() -> void {
size_t stride = fm.stride[0];
size_t src_size = grad.layout.shape[2];
size_t fwd_pad = fm.padding[0];
size_t filter_size = fm.spatial[0];
size_t deconv_pad = (stride * src_size - stride + stride * filter_size -
src_size - 2 * fwd_pad + filter_size - 1) /
(2 * stride);
fm.padding[0] = fm.padding[1] = deconv_pad;
return;
}();
auto kparam = chanwise::Param::load(
diff.layout, grad.layout, fm,
param().compute_mode == Param::ComputeMode::DEFAULT);
megdnn_assert(
fm.group > 1 && diff.layout.dtype.category() == DTypeCategory::FLOAT &&
param().compute_mode == Param::ComputeMode::DEFAULT &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip &&
param().stride_h == 1 && param().stride_w == 1);
// NOTE: uint8 dtype region mask requires the spatial size of src&dst is 4*N
if (rin.layout.dtype == dtype::Uint8()) {
megdnn_assert(
(grad.layout.shape[3] & 3) == 0 && (diff.layout.shape[3] & 3) == 0);
}
auto stream = cuda_stream(handle());
if (filter.layout.dtype == dtype::Float32() && rin.layout.dtype == dtype::Int32() &&
rout.layout.dtype == dtype::Int32()) {
chanwise::run_bwd_depthwise_large_filter(
grad.ptr<dt_float32>(), diff.ptr<dt_float32>(),
filter.ptr<dt_float32>(), rin.ptr<dt_int32>(), rout.ptr<dt_int32>(),
kparam, stream);
} else if (
filter.layout.dtype == dtype::Float32() &&
rin.layout.dtype == dtype::Uint8() && rout.layout.dtype == dtype::Uint8()) {
chanwise::run_bwd_depthwise_large_filter(
grad.ptr<dt_float32>(), diff.ptr<dt_float32>(),
filter.ptr<dt_float32>(), rin.ptr<dt_uint8>(), rout.ptr<dt_uint8>(),
kparam, stream);
} else {
megdnn_throw("undefined or unimplemented region restricted conv mode");
}
} }
size_t RegionRestrictedConvolutionBackwardFilterImpl::get_workspace_in_bytes( size_t RegionRestrictedConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout&, const TensorLayout& src, const TensorLayout& diff, const TensorLayout&,
const TensorLayout&, const TensorLayout& grad) { const TensorLayout&, const TensorLayout& grad) {
size_t workspace_size = 0; return 0;
return workspace_size;
} }
/* ============== RegionRestrictedConvolutionBackwardFilterImpl ============== */ /* ============== RegionRestrictedConvolutionBackwardFilterImpl ============== */
void RegionRestrictedConvolutionBackwardFilterImpl::exec( void RegionRestrictedConvolutionBackwardFilterImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin, _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) { _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
megdnn_assert_internal(0); megdnn_throw("Region Restricted Conv BackwardFilter unimplemented");
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -117,7 +117,7 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) { ...@@ -117,7 +117,7 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) {
.set_dtype(1, dtype::Float32()) .set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Int32()) .set_dtype(2, dtype::Int32())
.set_dtype(3, dtype::Int32()); .set_dtype(3, dtype::Int32());
rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng).set_rng(0, &r_rng); rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
rr_bencher.set_times(nr_times); rr_bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h); size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
...@@ -169,6 +169,202 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) { ...@@ -169,6 +169,202 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) {
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10); run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
} }
TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_LARGE_FILTER_FP32) {
require_compute_capability(7, 5);
Benchmarker<ConvolutionBackwardData> bencher(handle_cuda());
bencher.set_display(false);
bencher.set_before_exec_callback(
AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
Benchmarker<RegionRestrictedConvolutionBackwardData> rr_bencher(handle_cuda());
rr_bencher.set_display(false);
ConvolutionBackwardData::Param param;
param.format = ConvolutionBackwardData::Param::Format::NCHW;
param.sparse = ConvolutionBackwardData::Param::Sparse::GROUP;
RegionRestrictedConvolutionBackwardData::Param rr_param;
rr_param.format = RegionRestrictedConvolutionBackwardData::Param::Format::NCHW;
rr_param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
UniformIntRNG r_rng{1, 3};
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2;
param.pad_w = fw / 2;
param.stride_h = sh;
param.stride_w = sw;
rr_param.pad_h = fh / 2;
rr_param.pad_w = fw / 2;
rr_param.stride_h = sh;
rr_param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
bencher.set_times(nr_times);
rr_bencher.set_param(rr_param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Int32())
.set_dtype(3, dtype::Int32());
rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
rr_bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
TensorShape inp{batch, g, hi, wi} /*src*/, kern{g, 1, 1, fh, fw} /*filter*/,
rin{batch, hi, wi}, rout{batch, ho, wo},
out{batch, g, ho, wo} /*output*/;
float bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
float rr_bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
rin.total_nr_elems() + rout.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
auto time_in_ms = bencher.execs({kern, out, inp}) / nr_times;
auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
auto rr_time_in_ms = rr_bencher.execs({kern, out, rin, rout, inp}) / nr_times;
auto rr_ops =
2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: "
"grad=%s, "
"kern=%s, diff=%s\n"
"time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n"
"bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n",
inp.to_string().c_str(), kern.to_string().c_str(),
out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms,
time_in_ms / rr_time_in_ms);
};
run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
}
TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_LARGE_FILTER_FP32_UINT8) {
require_compute_capability(7, 5);
Benchmarker<ConvolutionBackwardData> bencher(handle_cuda());
bencher.set_display(false);
bencher.set_before_exec_callback(
AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
Benchmarker<RegionRestrictedConvolutionBackwardData> rr_bencher(handle_cuda());
rr_bencher.set_display(false);
ConvolutionBackwardData::Param param;
param.format = ConvolutionBackwardData::Param::Format::NCHW;
param.sparse = ConvolutionBackwardData::Param::Sparse::GROUP;
RegionRestrictedConvolutionBackwardData::Param rr_param;
rr_param.format = RegionRestrictedConvolutionBackwardData::Param::Format::NCHW;
rr_param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
UniformIntRNG r_rng{1, 3};
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2;
param.pad_w = fw / 2;
param.stride_h = sh;
param.stride_w = sw;
rr_param.pad_h = fh / 2;
rr_param.pad_w = fw / 2;
rr_param.stride_h = sh;
rr_param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
bencher.set_times(nr_times);
rr_bencher.set_param(rr_param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Uint8())
.set_dtype(3, dtype::Uint8());
rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
rr_bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
TensorShape inp{batch, g, hi, wi} /*src*/, kern{g, 1, 1, fh, fw} /*filter*/,
rin{batch, hi, wi}, rout{batch, ho, wo},
out{batch, g, ho, wo} /*output*/;
float bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
float rr_bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
rin.total_nr_elems() + rout.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
auto time_in_ms = bencher.execs({kern, out, inp}) / nr_times;
auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
auto rr_time_in_ms = rr_bencher.execs({kern, out, rin, rout, inp}) / nr_times;
auto rr_ops =
2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: "
"grad=%s, "
"kern=%s, diff=%s\n"
"time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n"
"bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n",
inp.to_string().c_str(), kern.to_string().c_str(),
out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms,
time_in_ms / rr_time_in_ms);
};
run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
}
TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) { TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) {
require_compute_capability(7, 5); require_compute_capability(7, 5);
Benchmarker<ConvBiasForward> bencher(handle_cuda()); Benchmarker<ConvBiasForward> bencher(handle_cuda());
...@@ -271,6 +467,124 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) { ...@@ -271,6 +467,124 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) {
#endif #endif
TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32) {
Checker<RegionRestrictedConvolutionBackwardData> checker(handle_cuda());
for (auto dt : std::vector<DType>{dtype::Int32(), dtype::Uint8()}) {
auto run = [&checker, &dt](
size_t n, size_t g, size_t ih, size_t fh, size_t padding,
size_t stride) {
RegionRestrictedConvolutionBackwardData::Param cur_param;
cur_param.mode = RegionRestrictedConvolutionBackwardData::Param::Mode::
CROSS_CORRELATION;
cur_param.compute_mode = RegionRestrictedConvolutionBackwardData::Param::
ComputeMode::DEFAULT;
cur_param.sparse =
RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dt)
.set_dtype(3, dt);
float scale = 64.f / sqrt(fh * fh);
UniformFloatRNG rng(scale, 2 * scale);
UniformIntRNG r_rng{1, 2};
checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng(
3, &r_rng);
cur_param.pad_h = cur_param.pad_w = padding;
cur_param.stride_h = cur_param.stride_w = stride;
size_t oh = (ih + 2 * padding - fh + 1) / stride;
checker.set_param(cur_param).execs({
{g, 1, 1, fh, fh}, // filter
{n, g * 1, oh, oh}, // diff
{n, ih, ih}, // rin
{n, oh, oh}, // rout
{n, g * 1, ih, ih} // grad
});
};
if (dt == dtype::Int32()) {
run(4, 8, 32, 5, 5 / 2, 1);
run(1, 2, 2, 2, 0, 1);
run(1, 2, 3, 3, 0, 1);
run(1, 2, 4, 4, 0, 1);
run(1, 2, 5, 5, 0, 1);
run(1, 2, 6, 6, 0, 1);
run(1, 2, 7, 7, 0, 1);
}
run(4, 8, 32, 7, 7 / 2, 1);
run(4, 8, 32, 9, 9 / 2, 1);
run(4, 8, 32, 11, 11 / 2, 1);
run(4, 8, 32, 13, 13 / 2, 1);
run(4, 8, 32, 15, 15 / 2, 1);
run(4, 8, 32, 17, 17 / 2, 1);
run(4, 8, 32, 19, 19 / 2, 1);
run(4, 8, 32, 21, 21 / 2, 1);
run(4, 8, 32, 23, 23 / 2, 1);
run(4, 8, 32, 25, 25 / 2, 1);
run(4, 8, 32, 27, 27 / 2, 1);
run(4, 8, 32, 29, 29 / 2, 1);
run(4, 8, 32, 31, 31 / 2, 1);
}
}
TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32_RIN_EQ_ROUT) {
Checker<RegionRestrictedConvolutionBackwardData> checker(handle_cuda());
for (auto dt : std::vector<DType>{dtype::Int32()}) {
auto run = [&checker, &dt](
size_t n, size_t g, size_t ih, size_t fh, size_t padding,
size_t stride) {
RegionRestrictedConvolutionBackwardData::Param cur_param;
cur_param.mode = RegionRestrictedConvolutionBackwardData::Param::Mode::
CROSS_CORRELATION;
cur_param.compute_mode = RegionRestrictedConvolutionBackwardData::Param::
ComputeMode::DEFAULT;
cur_param.sparse =
RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
checker.set_dtype(2, dt).set_dtype(3, dt);
float scale = 64.f / sqrt(fh * fh);
UniformFloatRNG rng(scale, 2 * scale);
// value 0 mask may cause unexpected behaviour.
UniformIntRNG r_rng{1, 1};
checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng(
3, &r_rng);
cur_param.pad_h = cur_param.pad_w = padding;
cur_param.stride_h = cur_param.stride_w = stride;
size_t oh = (ih + 2 * padding - fh + 1) / stride;
checker.set_param(cur_param).execs(
{/*filter*/ {g, 1, 1, fh, fh},
/*diff*/ {n, g * 1, oh, oh},
/*rin*/ {n, ih, ih},
/*rout*/ {n, oh, oh},
/*grad*/ {n, g * 1, ih, ih}});
};
if (dt == dtype::Int32()) {
// NOTE: UINT8 assert the spatial size of src&dst is 4*N
run(4, 8, 32, 5, 5 / 2, 1);
run(1, 2, 2, 2, 0, 1);
run(1, 2, 3, 3, 0, 1);
run(1, 2, 4, 4, 0, 1);
run(1, 2, 5, 5, 0, 1);
run(1, 2, 6, 6, 0, 1);
run(1, 2, 7, 7, 0, 1);
}
run(4, 8, 32, 7, 7 / 2, 1);
run(4, 8, 32, 9, 9 / 2, 1);
run(4, 8, 32, 11, 11 / 2, 1);
run(4, 8, 32, 13, 13 / 2, 1);
run(4, 8, 32, 15, 15 / 2, 1);
run(4, 8, 32, 17, 17 / 2, 1);
run(4, 8, 32, 19, 19 / 2, 1);
run(4, 8, 32, 21, 21 / 2, 1);
run(4, 8, 32, 23, 23 / 2, 1);
run(4, 8, 32, 25, 25 / 2, 1);
run(4, 8, 32, 27, 27 / 2, 1);
run(4, 8, 32, 29, 29 / 2, 1);
run(4, 8, 32, 31, 31 / 2, 1);
}
}
} // namespace test } // namespace test
} // namespace megdnn } // namespace megdnn
......
...@@ -131,4 +131,110 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) { ...@@ -131,4 +131,110 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
{}}); {}});
} }
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD_DENSE_BRUTE) {
Checker<RegionRestrictedConvolutionForward> checker(handle());
RegionRestrictedConvolutionForward::Param param;
checker.set_param(param).exect(
Testcase{
TensorValue( // src
{1, 1, 4, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}),
TensorValue( // filter
{1, 1, 2, 2}, dtype::Float32(), {1, 1, 1, 1}),
TensorValue( // rin
{1, 4, 4}, dtype::Int32(),
{1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}),
TensorValue( // rout
{1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}),
{}, // output
},
Testcase{
{},
{},
{},
{},
TensorValue(
{1, 1, 3, 3}, dtype::Float32(),
{4, 14, 18, 5, 9, 0, 13, 9, 50})});
}
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_DENSE_BRUTE) {
Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
RegionRestrictedConvolutionBackwardData::Param param;
checker.set_param(param).exect(
Testcase{
// filter
TensorValue(
{1, 1, 2, 2}, // shape
dtype::Float32(), // dtype
{1.f, 1.f, 1.f, 1.f}),
// diff
TensorValue(
{1, 1, 3, 3}, dtype::Float32(),
{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}),
// rin
TensorValue(
{1, 4, 4}, dtype::Int32(),
{1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}),
// rout
TensorValue({1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}),
// grad
{}},
Testcase{// filter
{},
// diff
{},
// rin
{},
// rout
{},
// grad
TensorValue(
{1, 1, 4, 4}, dtype::Float32(),
{0., 2., 5., 3., 1., 6., 5., 3., 0., 13., 9., 9., 0., 7.,
9., 9.})});
}
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_GROUP_BRUTE) {
Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
// params
RegionRestrictedConvolutionBackwardData::Param param;
param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP;
param.mode = RegionRestrictedConvolutionBackwardData::Mode::CROSS_CORRELATION;
param.compute_mode =
RegionRestrictedConvolutionBackwardData::Param::ComputeMode::DEFAULT;
param.pad_h = param.pad_w =
0; // forward param, naive backward data doesn't matter with deconv padding
param.stride_h = param.stride_w = 1;
// checker setting
checker.set_param(param).exect(
Testcase{// filter
TensorValue(
{2, 1, 1, 2, 2}, // shape
dtype::Float32(), // dtype
{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}),
// diff
TensorValue({1, 2, 1, 1}, dtype::Float32(), {1, 2}),
// rin
TensorValue({1, 2, 2}, dtype::Int32(), {1, 1, 1, 1}),
// rout
TensorValue({1, 1, 1}, dtype::Int32(), {1}),
// grad
{}},
Testcase{// filter
{},
// diff
{},
// rin
{},
// rout
{},
// grad
TensorValue(
{1, 2, 2, 2}, dtype::Float32(),
{1, 2, 3, 4, 10, 12, 14, 16})});
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册