diff --git a/dnn/src/common/region_restricted_convolution.cpp b/dnn/src/common/region_restricted_convolution.cpp index 2f2110db2d19856e3c1ac208d7796dcde08a789d..efdaf9283ce1b63a24fe9fd1f44a721d1e8966b4 100644 --- a/dnn/src/common/region_restricted_convolution.cpp +++ b/dnn/src/common/region_restricted_convolution.cpp @@ -38,7 +38,7 @@ void RegionRestrictedConvolutionForward::deduce_dtype( "only float type is supported for region_restricted_conv forward"); megdnn_assert( rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()), - "the dtype of rin/rout should be Int32, got %s.", rin.name()); + "the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name()); } void RegionRestrictedConvolutionForward::deduce_layout( @@ -91,12 +91,12 @@ RegionRestrictedConvolutionBackwardData::check_exec( auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd); #define err_msg(lhs, rhs) \ megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs); - err_msg(rin.shape[0], grad_fwd.shape[0]); - err_msg(rin.shape[1], grad_fwd.shape[2]); - err_msg(rin.shape[2], grad_fwd.shape[3]); - err_msg(rout.shape[0], diff_fwd.shape[0]); - err_msg(rout.shape[1], diff_fwd.shape[2]); - err_msg(rout.shape[2], diff_fwd.shape[3]); + err_msg(rin.shape[0], grad_fwd.shape[0]); // batch + err_msg(rin.shape[1], grad_fwd.shape[2]); // ih + err_msg(rin.shape[2], grad_fwd.shape[3]); // iw + err_msg(rout.shape[0], diff_fwd.shape[0]); // batch + err_msg(rout.shape[1], diff_fwd.shape[2]); // oh + err_msg(rout.shape[2], diff_fwd.shape[3]); // ow #undef err_msg auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, rin, rout, grad); @@ -106,45 +106,22 @@ RegionRestrictedConvolutionBackwardData::check_exec( void RegionRestrictedConvolutionBackwardData::deduce_dtype( DType filter, DType diff, DType rin, DType rout, DType& grad) { - SmallVector<DType> supported_dst_dtype; - if (filter.category() == diff.category() && - filter.category() == DTypeCategory::FLOAT) { - supported_dst_dtype.push_back(filter); - } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) { - supported_dst_dtype.push_back(dtype::Int32()); - } else if ( - (filter.enumv() == DTypeEnum::QuantizedS8 && - diff.enumv() == DTypeEnum::QuantizedS8) || - (filter.enumv() == DTypeEnum::Quantized8Asymm && - diff.enumv() == DTypeEnum::Quantized8Asymm)) { - supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff))); - if (grad.valid() && grad.enumv() == diff.enumv()) { - supported_dst_dtype.push_back(grad); - } - } else { - megdnn_throw(ssprintf( - "unsupported input / diff DType: %s x %s", filter.name(), diff.name())); - } - if (!grad.valid()) { - grad = supported_dst_dtype.at(0); - } else { - megdnn_assert( - vec_contains(supported_dst_dtype, grad), - "unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(), - grad.name()); - } - megdnn_assert( - param().compute_mode != Param::ComputeMode::FLOAT32 + // FIXME: inferring the dtype of grad via the naive impl only supports fp32 + // (quantized dtype inference is still missing) and may not be suitable in the future #if !MEGDNN_DISABLE_FLOAT16 - || filter.enumv() == DTypeEnum::Float16 || - filter.enumv() == DTypeEnum::BFloat16 + if (diff.enumv() == DTypeEnum::Float32 || diff.enumv() == DTypeEnum::Float16) { + grad = diff; + } #endif - , - "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " - "input / output."); + megdnn_assert(grad.valid(), "dtype of grad should be deduced or assigned"); megdnn_assert( - rin == rout && rin == dtype::Int32(), - "the dtype of rin/rout should be 
Int32, got %s.", rin.name()); + diff.category() == DTypeCategory::FLOAT && + filter.category() == DTypeCategory::FLOAT && + grad.category() == DTypeCategory::FLOAT, + "only float type is supported for region_restricted_conv backward data"); + megdnn_assert( + rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()), + "the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name()); } void RegionRestrictedConvolutionBackwardData::deduce_layout( diff --git a/dnn/src/cuda/region_restricted_convolution/chanwise/bwd_large_filter.cu b/dnn/src/cuda/region_restricted_convolution/chanwise/bwd_large_filter.cu index 84c314efbde00e87ede2a69b9e92c8684f3491f6..9ece4f50e2747a77885bd453397b56060cf53010 100644 --- a/dnn/src/cuda/region_restricted_convolution/chanwise/bwd_large_filter.cu +++ b/dnn/src/cuda/region_restricted_convolution/chanwise/bwd_large_filter.cu @@ -1,7 +1,7 @@ -#include "./kern.cuh" #include "cuda.h" #include "cuda_fp16.h" #include "src/cuda/fp16_help.cuh" +#include "src/cuda/region_restricted_convolution/chanwise/kern.cuh" using namespace megdnn; using namespace cuda; @@ -15,7 +15,7 @@ namespace cuda { namespace region_restricted_convolution { namespace chanwise { -// =====================================fwd===================================== +// =====================================bwd===================================== template <> void run_bwd_depthwise_large_filter( diff --git a/dnn/src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh b/dnn/src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh index cc6ff7a4df728ba553e18ff26846e3c34f420474..c35df2b16b248557f693c56587f685dce8abb0cd 100644 --- a/dnn/src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh +++ b/dnn/src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh @@ -498,16 +498,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( SrcGlobal2ShareVisitor gl2sh_src = { smem_src, static_cast(param.src_w), - static_cast( - is_fwd ? src_start_h - : src_start_h - - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - - param.src_h * param.stride_h / 2)), - static_cast( - is_fwd ? src_start_w - : src_start_w - - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - - param.src_w * param.stride_w / 2)), + static_cast(src_start_h), + static_cast(src_start_w), static_cast(is_fwd ? param.src_h : param.src_h * param.stride_h), static_cast(is_fwd ? param.src_w : param.src_w * param.stride_w), is_fwd ? 1 : static_cast(param.stride_h), @@ -516,16 +508,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( RinGlobal2ShareVisitor gl2sh_rin = { smem_rin, static_cast(param.src_w), - static_cast( - is_fwd ? src_start_h - : src_start_h - - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - - param.src_h * param.stride_h / 2)), - static_cast( - is_fwd ? src_start_w - : src_start_w - - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - - param.src_w * param.stride_w / 2)), + static_cast(src_start_h), + static_cast(src_start_w), static_cast(is_fwd ? param.src_h : param.src_h * param.stride_h), static_cast(is_fwd ? param.src_w : param.src_w * param.stride_w), is_fwd ? 
1 : static_cast(param.stride_h), @@ -790,14 +774,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h; T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w; - static_assert((FilterTileConfig::unroll_w & 3) == 0); + static_assert( + (FilterTileConfig::unroll_w & 3) == 0, "filter tile unroll_w & 3 != 0"); int* smem_rin_ptr = smem_rin + (off_ow * FilterTileConfig::unroll_w >> 2); T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w; T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w; const uint8_t* rout_base_ptr = rout + batch * param.out_h * param.out_w; - static_assert((OutTileConfig::unroll_w & 3) == 0); - static_assert((OutTileConfig::block_w & 3) == 0); + static_assert((OutTileConfig::unroll_w & 3) == 0, "output tile unroll_w & 3 != 0"); + static_assert((OutTileConfig::block_w & 3) == 0, "output block_w & 3 != 0"); int reg_rout[OutTileConfig::unroll_size] = {0}; #pragma unroll for (int i = 0; i < OutTileConfig::unroll_h; ++i) { @@ -821,16 +806,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( SrcGlobal2ShareVisitor gl2sh_src = { smem_src, static_cast(param.src_w), - static_cast( - is_fwd ? src_start_h - : src_start_h - - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - - param.src_h * param.stride_h / 2)), - static_cast( - is_fwd ? src_start_w - : src_start_w - - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - - param.src_w * param.stride_w / 2)), + static_cast(src_start_h), + static_cast(src_start_w), static_cast(is_fwd ? param.src_h : param.src_h * param.stride_h), static_cast(is_fwd ? param.src_w : param.src_w * param.stride_w), is_fwd ? 1 : static_cast(param.stride_h), @@ -839,16 +816,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( RinGlobal2ShareVisitor gl2sh_rin = { smem_rin, static_cast(param.src_w), - static_cast( - is_fwd ? src_start_h - : src_start_h - - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - - param.src_h * param.stride_h / 2)), - static_cast( - is_fwd ? src_start_w - : src_start_w - - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - - param.src_w * param.stride_w / 2)), + static_cast(src_start_h), + static_cast(src_start_w), static_cast(is_fwd ? param.src_h : param.src_h * param.stride_h), static_cast(is_fwd ? param.src_w : param.src_w * param.stride_w), is_fwd ? 
1 : static_cast(param.stride_h), @@ -1134,14 +1103,20 @@ void LaunchDepthwiseConv2dGPU( RinTileCount::smem_size * sizeof(int); void (*kernel)(const Param, const T*, const T*, const RT*, const RT*, T*); + const bool is_fwd = (kDirection == DIRECTION_FORWARD); if (param.is_compute_deafult) { kernel = DepthwiseConv2dGPUKernelNCHW; } else { megdnn_assert_internal(0); } - kernel<<>>( - param, input, filter, rin, rout, output); + if (is_fwd) { + kernel<<>>( + param, input, filter, rin, rout, output); + } else { + kernel<<>>( + param, input, filter, rout, rin, output); + } after_kernel_launch(); } diff --git a/dnn/src/cuda/region_restricted_convolution/opr_impl.cpp b/dnn/src/cuda/region_restricted_convolution/opr_impl.cpp index d429e0bba64ec5c4197107fc1f08b14649217726..5f574b1b84a30bb0de5d4f4d5fde6ebcf9b07b2e 100644 --- a/dnn/src/cuda/region_restricted_convolution/opr_impl.cpp +++ b/dnn/src/cuda/region_restricted_convolution/opr_impl.cpp @@ -55,25 +55,65 @@ size_t RegionRestrictedConvolutionBackwardDataImpl::get_workspace_in_bytes( void RegionRestrictedConvolutionBackwardDataImpl::exec( _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin, _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) { - megdnn_throw(ssprintf( - "unsupported RegionRestrictedConvolutionBackwardData(%s, %s, %s, %s) -> %s", - filter.layout.dtype.name(), diff.layout.dtype.name(), - rin.layout.dtype.name(), rout.layout.dtype.name(), - grad.layout.dtype.name())); + auto fm = check_exec( + filter.layout, diff.layout, rin.layout, rout.layout, grad.layout, + workspace.size); + // XXX: a naive impl to set deconv padding to param, needs optimization in future. + [&]() -> void { + size_t stride = fm.stride[0]; + size_t src_size = grad.layout.shape[2]; + size_t fwd_pad = fm.padding[0]; + size_t filter_size = fm.spatial[0]; + size_t deconv_pad = (stride * src_size - stride + stride * filter_size - + src_size - 2 * fwd_pad + filter_size - 1) / + (2 * stride); + fm.padding[0] = fm.padding[1] = deconv_pad; + return; + }(); + auto kparam = chanwise::Param::load( + diff.layout, grad.layout, fm, + param().compute_mode == Param::ComputeMode::DEFAULT); + megdnn_assert( + fm.group > 1 && diff.layout.dtype.category() == DTypeCategory::FLOAT && + param().compute_mode == Param::ComputeMode::DEFAULT && + fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip && + param().stride_h == 1 && param().stride_w == 1); + // NOTE: uint8 dtype region mask requires the spatial size of src&dst is 4*N + if (rin.layout.dtype == dtype::Uint8()) { + megdnn_assert( + (grad.layout.shape[3] & 3) == 0 && (diff.layout.shape[3] & 3) == 0); + } + auto stream = cuda_stream(handle()); + if (filter.layout.dtype == dtype::Float32() && rin.layout.dtype == dtype::Int32() && + rout.layout.dtype == dtype::Int32()) { + chanwise::run_bwd_depthwise_large_filter( + grad.ptr(), diff.ptr(), + filter.ptr(), rin.ptr(), rout.ptr(), + kparam, stream); + } else if ( + filter.layout.dtype == dtype::Float32() && + rin.layout.dtype == dtype::Uint8() && rout.layout.dtype == dtype::Uint8()) { + chanwise::run_bwd_depthwise_large_filter( + grad.ptr(), diff.ptr(), + filter.ptr(), rin.ptr(), rout.ptr(), + kparam, stream); + } else { + megdnn_throw("undefined or unimplemented region restricted conv mode"); + } } size_t RegionRestrictedConvolutionBackwardFilterImpl::get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& diff, const TensorLayout&, const TensorLayout&, const 
TensorLayout& grad) { - size_t workspace_size = 0; - return workspace_size; + return 0; } /* ============== RegionRestrictedConvolutionBackwardFilterImpl ============== */ void RegionRestrictedConvolutionBackwardFilterImpl::exec( _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin, _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) { - megdnn_assert_internal(0); + megdnn_throw("Region Restricted Conv BackwardFilter unimplemented"); } // vim: syntax=cpp.doxygen diff --git a/dnn/test/cuda/region_restricted_convolution.cpp b/dnn/test/cuda/region_restricted_convolution.cpp index 4e3ebb901cabf1c676e802f101123c3c71568f55..ab03cdf2922f04f3d64af902de027afb0803543e 100644 --- a/dnn/test/cuda/region_restricted_convolution.cpp +++ b/dnn/test/cuda/region_restricted_convolution.cpp @@ -117,7 +117,7 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) { .set_dtype(1, dtype::Float32()) .set_dtype(2, dtype::Int32()) .set_dtype(3, dtype::Int32()); - rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng).set_rng(0, &r_rng); + rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng); rr_bencher.set_times(nr_times); size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h); @@ -169,6 +169,202 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) { run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10); } +TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_LARGE_FILTER_FP32) { + require_compute_capability(7, 5); + Benchmarker bencher(handle_cuda()); + bencher.set_display(false); + bencher.set_before_exec_callback( + AlgoChecker("DEPTHWISE_LARGE_FILTER")); + + Benchmarker rr_bencher(handle_cuda()); + rr_bencher.set_display(false); + + ConvolutionBackwardData::Param param; + param.format = ConvolutionBackwardData::Param::Format::NCHW; + param.sparse = ConvolutionBackwardData::Param::Sparse::GROUP; + + RegionRestrictedConvolutionBackwardData::Param rr_param; + rr_param.format = RegionRestrictedConvolutionBackwardData::Param::Format::NCHW; + rr_param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; + + UniformIntRNG r_rng{1, 3}; + + auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh, + size_t fw, size_t sh, size_t sw, size_t nr_times) { + param.pad_h = fh / 2; + param.pad_w = fw / 2; + param.stride_h = sh; + param.stride_w = sw; + + rr_param.pad_h = fh / 2; + rr_param.pad_w = fw / 2; + rr_param.stride_h = sh; + rr_param.stride_w = sw; + + bencher.set_param(param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_dtype(4, dtype::Float32()); + bencher.set_times(nr_times); + + rr_bencher.set_param(rr_param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Int32()) + .set_dtype(3, dtype::Int32()); + rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng); + rr_bencher.set_times(nr_times); + + size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h); + size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w); + TensorShape inp{batch, g, hi, wi} /*src*/, kern{g, 1, 1, fh, fw} /*filter*/, + rin{batch, hi, wi}, rout{batch, ho, wo}, + out{batch, g, ho, wo} /*output*/; + + float bandwith = static_cast( + inp.total_nr_elems() + kern.total_nr_elems() + + out.total_nr_elems()) / + (1024 * 1024 * 1024) * 1e3; + + float rr_bandwith = static_cast( + inp.total_nr_elems() + kern.total_nr_elems() + + rin.total_nr_elems() + rout.total_nr_elems() + + out.total_nr_elems()) / + (1024 * 1024 * 1024) * 1e3; + + auto time_in_ms = 
bencher.execs({kern, out, inp}) / nr_times; + auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12; + + auto rr_time_in_ms = rr_bencher.execs({kern, out, rin, rout, inp}) / nr_times; + auto rr_ops = + 2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12; + printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: " + "grad=%s, " + "kern=%s, diff=%s\n" + "time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n" + "bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n", + inp.to_string().c_str(), kern.to_string().c_str(), + out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops, + bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms, + time_in_ms / rr_time_in_ms); + }; + + run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10); + run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10); + run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10); + run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10); + run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10); + run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10); + run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10); + run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10); + run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10); + run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10); + run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10); + run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10); + run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10); + run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10); + run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10); +} + +TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_LARGE_FILTER_FP32_UINT8) { + require_compute_capability(7, 5); + Benchmarker bencher(handle_cuda()); + bencher.set_display(false); + bencher.set_before_exec_callback( + AlgoChecker("DEPTHWISE_LARGE_FILTER")); + + Benchmarker rr_bencher(handle_cuda()); + rr_bencher.set_display(false); + + ConvolutionBackwardData::Param param; + param.format = ConvolutionBackwardData::Param::Format::NCHW; + param.sparse = ConvolutionBackwardData::Param::Sparse::GROUP; + + RegionRestrictedConvolutionBackwardData::Param rr_param; + rr_param.format = RegionRestrictedConvolutionBackwardData::Param::Format::NCHW; + rr_param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; + + UniformIntRNG r_rng{1, 3}; + + auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh, + size_t fw, size_t sh, size_t sw, size_t nr_times) { + param.pad_h = fh / 2; + param.pad_w = fw / 2; + param.stride_h = sh; + param.stride_w = sw; + + rr_param.pad_h = fh / 2; + rr_param.pad_w = fw / 2; + rr_param.stride_h = sh; + rr_param.stride_w = sw; + + bencher.set_param(param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_dtype(4, dtype::Float32()); + bencher.set_times(nr_times); + + rr_bencher.set_param(rr_param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Uint8()) + .set_dtype(3, dtype::Uint8()); + rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng); + rr_bencher.set_times(nr_times); + + size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h); + size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w); + TensorShape inp{batch, g, hi, wi} /*src*/, kern{g, 1, 1, fh, fw} /*filter*/, + rin{batch, hi, wi}, rout{batch, ho, wo}, + out{batch, g, ho, wo} /*output*/; + + float bandwith = static_cast( + inp.total_nr_elems() + kern.total_nr_elems() + + out.total_nr_elems()) / + (1024 * 1024 * 1024) * 1e3; + + float rr_bandwith = static_cast( + inp.total_nr_elems() + kern.total_nr_elems() + + 
rin.total_nr_elems() + rout.total_nr_elems() + + out.total_nr_elems()) / + (1024 * 1024 * 1024) * 1e3; + + auto time_in_ms = bencher.execs({kern, out, inp}) / nr_times; + auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12; + + auto rr_time_in_ms = rr_bencher.execs({kern, out, rin, rout, inp}) / nr_times; + auto rr_ops = + 2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12; + printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: " + "grad=%s, " + "kern=%s, diff=%s\n" + "time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n" + "bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n", + inp.to_string().c_str(), kern.to_string().c_str(), + out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops, + bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms, + time_in_ms / rr_time_in_ms); + }; + + run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10); + run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10); + run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10); + run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10); + run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10); + run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10); + run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10); + run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10); + run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10); + run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10); + run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10); + run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10); + run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10); + run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10); + run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10); +} + TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) { require_compute_capability(7, 5); Benchmarker bencher(handle_cuda()); @@ -271,6 +467,124 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) { #endif +TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32) { + Checker checker(handle_cuda()); + + for (auto dt : std::vector{dtype::Int32(), dtype::Uint8()}) { + auto run = [&checker, &dt]( + size_t n, size_t g, size_t ih, size_t fh, size_t padding, + size_t stride) { + RegionRestrictedConvolutionBackwardData::Param cur_param; + cur_param.mode = RegionRestrictedConvolutionBackwardData::Param::Mode:: + CROSS_CORRELATION; + cur_param.compute_mode = RegionRestrictedConvolutionBackwardData::Param:: + ComputeMode::DEFAULT; + cur_param.sparse = + RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dt) + .set_dtype(3, dt); + float scale = 64.f / sqrt(fh * fh); + UniformFloatRNG rng(scale, 2 * scale); + UniformIntRNG r_rng{1, 2}; + checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng( + 3, &r_rng); + cur_param.pad_h = cur_param.pad_w = padding; + cur_param.stride_h = cur_param.stride_w = stride; + + size_t oh = (ih + 2 * padding - fh + 1) / stride; + checker.set_param(cur_param).execs({ + {g, 1, 1, fh, fh}, // filter + {n, g * 1, oh, oh}, // diff + {n, ih, ih}, // rin + {n, oh, oh}, // rout + {n, g * 1, ih, ih} // grad + }); + }; + if (dt == dtype::Int32()) { + run(4, 8, 32, 5, 5 / 2, 1); + run(1, 2, 2, 2, 0, 1); + run(1, 2, 3, 3, 0, 1); + run(1, 2, 4, 4, 0, 1); + run(1, 2, 5, 5, 0, 1); + run(1, 2, 6, 6, 0, 1); + run(1, 2, 7, 7, 0, 1); + } + run(4, 8, 32, 7, 7 / 2, 1); + run(4, 8, 32, 9, 9 / 2, 1); + run(4, 8, 32, 11, 11 / 2, 1); + run(4, 8, 32, 13, 13 / 2, 1); + run(4, 8, 32, 15, 15 / 2, 1); + run(4, 8, 32, 17, 17 / 2, 1); + run(4, 8, 32, 19, 19 / 2, 
1); + run(4, 8, 32, 21, 21 / 2, 1); + run(4, 8, 32, 23, 23 / 2, 1); + run(4, 8, 32, 25, 25 / 2, 1); + run(4, 8, 32, 27, 27 / 2, 1); + run(4, 8, 32, 29, 29 / 2, 1); + run(4, 8, 32, 31, 31 / 2, 1); + } +} + +TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32_RIN_EQ_ROUT) { + Checker checker(handle_cuda()); + + for (auto dt : std::vector{dtype::Int32()}) { + auto run = [&checker, &dt]( + size_t n, size_t g, size_t ih, size_t fh, size_t padding, + size_t stride) { + RegionRestrictedConvolutionBackwardData::Param cur_param; + cur_param.mode = RegionRestrictedConvolutionBackwardData::Param::Mode:: + CROSS_CORRELATION; + cur_param.compute_mode = RegionRestrictedConvolutionBackwardData::Param:: + ComputeMode::DEFAULT; + cur_param.sparse = + RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; + checker.set_dtype(2, dt).set_dtype(3, dt); + float scale = 64.f / sqrt(fh * fh); + UniformFloatRNG rng(scale, 2 * scale); + // value 0 mask may cause unexpected behaviour. + UniformIntRNG r_rng{1, 1}; + checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng( + 3, &r_rng); + cur_param.pad_h = cur_param.pad_w = padding; + cur_param.stride_h = cur_param.stride_w = stride; + + size_t oh = (ih + 2 * padding - fh + 1) / stride; + checker.set_param(cur_param).execs( + {/*filter*/ {g, 1, 1, fh, fh}, + /*diff*/ {n, g * 1, oh, oh}, + /*rin*/ {n, ih, ih}, + /*rout*/ {n, oh, oh}, + /*grad*/ {n, g * 1, ih, ih}}); + }; + if (dt == dtype::Int32()) { + // NOTE: UINT8 assert the spatial size of src&dst is 4*N + run(4, 8, 32, 5, 5 / 2, 1); + run(1, 2, 2, 2, 0, 1); + run(1, 2, 3, 3, 0, 1); + run(1, 2, 4, 4, 0, 1); + run(1, 2, 5, 5, 0, 1); + run(1, 2, 6, 6, 0, 1); + run(1, 2, 7, 7, 0, 1); + } + run(4, 8, 32, 7, 7 / 2, 1); + run(4, 8, 32, 9, 9 / 2, 1); + run(4, 8, 32, 11, 11 / 2, 1); + run(4, 8, 32, 13, 13 / 2, 1); + run(4, 8, 32, 15, 15 / 2, 1); + run(4, 8, 32, 17, 17 / 2, 1); + run(4, 8, 32, 19, 19 / 2, 1); + run(4, 8, 32, 21, 21 / 2, 1); + run(4, 8, 32, 23, 23 / 2, 1); + run(4, 8, 32, 25, 25 / 2, 1); + run(4, 8, 32, 27, 27 / 2, 1); + run(4, 8, 32, 29, 29 / 2, 1); + run(4, 8, 32, 31, 31 / 2, 1); + } +} + } // namespace test } // namespace megdnn diff --git a/dnn/test/naive/region_restricted_convolution.cpp b/dnn/test/naive/region_restricted_convolution.cpp index 7bf9abda0cd9b2558c40f79622d151c7ab4e0160..4012f2a906c6665a33faa9981aac9dcf4a4604c4 100644 --- a/dnn/test/naive/region_restricted_convolution.cpp +++ b/dnn/test/naive/region_restricted_convolution.cpp @@ -131,4 +131,110 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) { {}}); } +TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD_DENSE_BRUTE) { + Checker checker(handle()); + RegionRestrictedConvolutionForward::Param param; + checker.set_param(param).exect( + Testcase{ + TensorValue( // src + {1, 1, 4, 4}, dtype::Float32(), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + TensorValue( // filter + {1, 1, 2, 2}, dtype::Float32(), {1, 1, 1, 1}), + TensorValue( // rin + {1, 4, 4}, dtype::Int32(), + {1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}), + TensorValue( // rout + {1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}), + {}, // output + }, + Testcase{ + {}, + {}, + {}, + {}, + TensorValue( + {1, 1, 3, 3}, dtype::Float32(), + {4, 14, 18, 5, 9, 0, 13, 9, 50})}); +} + +TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_DENSE_BRUTE) { + Checker checker(handle()); + RegionRestrictedConvolutionBackwardData::Param param; + checker.set_param(param).exect( + Testcase{ + // filter + TensorValue( + {1, 1, 2, 2}, // shape + 
dtype::Float32(), // dtype + {1.f, 1.f, 1.f, 1.f}), + // diff + TensorValue( + {1, 1, 3, 3}, dtype::Float32(), + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}), + // rin + TensorValue( + {1, 4, 4}, dtype::Int32(), + {1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}), + // rout + TensorValue({1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}), + // grad + {}}, + Testcase{// filter + {}, + // diff + {}, + // rin + {}, + // rout + {}, + // grad + TensorValue( + {1, 1, 4, 4}, dtype::Float32(), + {0., 2., 5., 3., 1., 6., 5., 3., 0., 13., 9., 9., 0., 7., + 9., 9.})}); +} + +TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_GROUP_BRUTE) { + Checker checker(handle()); + + // params + RegionRestrictedConvolutionBackwardData::Param param; + param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; + param.mode = RegionRestrictedConvolutionBackwardData::Mode::CROSS_CORRELATION; + param.compute_mode = + RegionRestrictedConvolutionBackwardData::Param::ComputeMode::DEFAULT; + param.pad_h = param.pad_w = + 0; // forward param, naive backward data doesn't matter with deconv padding + param.stride_h = param.stride_w = 1; + + // checker setting + checker.set_param(param).exect( + Testcase{// filter + TensorValue( + {2, 1, 1, 2, 2}, // shape + dtype::Float32(), // dtype + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}), + // diff + TensorValue({1, 2, 1, 1}, dtype::Float32(), {1, 2}), + // rin + TensorValue({1, 2, 2}, dtype::Int32(), {1, 1, 1, 1}), + // rout + TensorValue({1, 1, 1}, dtype::Int32(), {1}), + // grad + {}}, + Testcase{// filter + {}, + // diff + {}, + // rin + {}, + // rout + {}, + // grad + TensorValue( + {1, 2, 2, 2}, dtype::Float32(), + {1, 2, 3, 4, 10, 12, 14, 16})}); +} + // vim: syntax=cpp.doxygen
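Editor's note (not part of the patch): the naive testcases above imply the gating rule used throughout this patch: a source pixel contributes to an output pixel only when its rin region id equals that output pixel's rout region id, and backward data accumulates grad under the same gate. The host-side C++ sketch below is an illustrative assumption-checking reference, not MegDNN code: it reproduces the REGIONRESTRICTEDCONVOLUTION_BWD_DATA_DENSE_BRUTE expected grad values (0 2 5 3 / 1 6 5 3 / 0 13 9 9 / 0 7 9 9) and also evaluates the deconv-padding expression from the CUDA exec() lambda, which for stride == 1 reduces to flt - 1 - fwd_pad.

// minimal sketch, assuming the rin == rout gating semantics shown in the naive tests
#include <cstdio>
#include <vector>

int main() {
    const int IH = 4, IW = 4, OH = 3, OW = 3, FH = 2, FW = 2;  // stride 1, pad 0
    const float filter[FH * FW] = {1, 1, 1, 1};
    const float diff[OH * OW] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
    const int rin[IH * IW] = {1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1};
    const int rout[OH * OW] = {0, 1, 1, 1, 0, 0, 1, 0, 1};

    // region-gated backward data: grad(ih, iw) accumulates diff(oh, ow) * filter(fh, fw)
    // only when rin(ih, iw) == rout(oh, ow)
    std::vector<float> grad(IH * IW, 0.f);
    for (int oh = 0; oh < OH; ++oh)
        for (int ow = 0; ow < OW; ++ow)
            for (int fh = 0; fh < FH; ++fh)
                for (int fw = 0; fw < FW; ++fw) {
                    int ih = oh + fh, iw = ow + fw;  // CROSS_CORRELATION, stride 1, pad 0
                    if (rin[ih * IW + iw] == rout[oh * OW + ow])
                        grad[ih * IW + iw] += diff[oh * OW + ow] * filter[fh * FW + fw];
                }
    // expected (from the DENSE_BRUTE testcase): 0 2 5 3 / 1 6 5 3 / 0 13 9 9 / 0 7 9 9
    for (int i = 0; i < IH * IW; ++i)
        printf("%g%c", grad[i], (i % IW == IW - 1) ? '\n' : ' ');

    // deconv padding expression copied from the patch's exec() lambda;
    // for stride == 1 it simplifies to flt - 1 - fwd_pad (here 2 - 1 - 0 = 1)
    size_t stride = 1, src_size = IH, fwd_pad = 0, flt = FH;
    size_t deconv_pad = (stride * src_size - stride + stride * flt - src_size -
                         2 * fwd_pad + flt - 1) /
                        (2 * stride);
    printf("deconv_pad = %zu\n", deconv_pad);
    return 0;
}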