diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index d8e954e335706a2ef8b66956b43c168def285e42..4b1c28871df99c1d0f713d1ecef1a1e1a8b87d1b 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -192,6 +192,87 @@ class ReduceForward: public OperatorBase { }; using Reduce = ReduceForward; +class CorrelationBase : public OperatorBase { + DEF_OPR_IMPL_CTOR(CorrelationBase, OperatorBase); + DEF_OPR_PARAM(Correlation); + +protected: + void deduce_layout_fwd(const TensorLayout& data1, const TensorLayout& data2, + TensorLayout& dst); + void check_layout_fwd(const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& dst); +}; + +class CorrelationForward : public CorrelationBase { + DEF_OPR_IMPL(CorrelationForward, CorrelationBase, 2, 1); + +public: + /** + * \param[in] data1 (n, c, ih, iw) + * \param[in] data2 (n, c, ih, iw) + * \param[out] dst (n, q, oh, ow), q is the number of neighborhood + * */ + virtual void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& data1, const TensorLayout& data2, + TensorLayout& dst); + virtual size_t get_workspace_in_bytes(const TensorLayout& data1, + const TensorLayout& data2, + const TensorLayout& dst) = 0; +protected: + void check_exec(const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& dst, size_t workspace_in_bytes); +}; +using Correlation = CorrelationForward; + +class CorrelationBackwardData1 : public CorrelationBase { + DEF_OPR_IMPL(CorrelationBackwardData1, CorrelationBase, 3, 1); + +public: + /** + * \param[in] diff the backpropagated gradient wrt. dst + * \param[in] data1 the `data1' parameter in CorrelationForward::exec + * \param[in] data2 the `data2' parameter in CorrelationForward::exec + * \param[out] grad1 the backpropagated gradient wrt. data1 + */ + virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out grad1, _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1, + const TensorLayout& data2, TensorLayout& dst); + virtual size_t get_workspace_in_bytes(const TensorLayout& diff, + const TensorLayout& data1, + const TensorLayout& data2, + const TensorLayout& grad1) = 0; + +protected: + void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& grad1, size_t workspace_in_bytes); +}; + +class CorrelationBackwardData2 : public CorrelationBase { + DEF_OPR_IMPL(CorrelationBackwardData2, CorrelationBase, 3, 1); + +public: + /** + * \param[in] diff the backpropagated gradient wrt. dst + * \param[in] data1 the `data1' parameter in CorrelationForward::exec + * \param[in] data2 the `data2' parameter in CorrelationForward::exec + * \param[out] grad2 the backpropagated gradient wrt. data2 + */ + virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out grad2, _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1, + const TensorLayout& data2, TensorLayout& dst); + virtual size_t get_workspace_in_bytes(const TensorLayout& diff, + const TensorLayout& data1, + const TensorLayout& data2, + const TensorLayout& grad2) = 0; + +protected: + void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& grad2, size_t workspace_in_bytes); +}; + class CumsumForward: public OperatorBase { DEF_OPR_PARAM(Cumsum); DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1); diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 2167fe3b20a28c86e9f9165c2285d219e5fd7f84..264b005f6cf1a5aa84526d40ed1890991ea8a67b 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1053,6 +1053,16 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o 'sample_width', '2') ) +(pdef('Correlation'). + add_enum_alias('Format', 'ConvolutionV0'). + add_fields('uint32', 'kernel_size', '1'). + add_fields('uint32', 'max_displacement', '1'). + add_fields('uint32', 'stride1', '1'). + add_fields('uint32', 'stride2', '1'). + add_fields('uint32', 'pad_size', '0'). + add_fields('bool', 'is_multiply', 'true') + ) + (pdef('DeformablePSROIPooling'). add_fields('bool', 'no_trans', 'true'). add_fields('float32', 'spatial_scale', 1, diff --git a/dnn/src/common/correlation.cpp b/dnn/src/common/correlation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1d211b7c0668c230ebb192dd807cd92c9b99343f --- /dev/null +++ b/dnn/src/common/correlation.cpp @@ -0,0 +1,132 @@ +/** + * \file dnn/src/common/correlation.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "megdnn/oprs.h" + +#include "src/common/utils.h" + +namespace megdnn { + +void CorrelationBase::deduce_layout_fwd(const TensorLayout& data1, + const TensorLayout& data2, + TensorLayout& dst) { + megdnn_assert_contiguous(data1); + megdnn_assert_contiguous(data2); + megdnn_assert_contiguous(dst); + auto errmsg = [&]() { + return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) + + ", " + megdnn_layout_msg(dst); + }; + MEGDNN_MARK_USED_VAR(errmsg); + using Format = CorrelationBase::Param::Format; + megdnn_assert(param().format == Format::NCHW); + auto data1_dtype = data1.dtype, data2_dtype = data2.dtype; + megdnn_assert(data1_dtype == data2_dtype && + data1_dtype.category() == DTypeCategory::FLOAT); + megdnn_assert(data1.ndim == 4_z, "%s", errmsg().c_str()); + megdnn_assert(data2.ndim == 4_z, "%s", errmsg().c_str()); + + uint32_t pad_size = param().pad_size; + uint32_t kernel_size = param().kernel_size; + uint32_t stride1 = param().stride1; + uint32_t stride2 = param().stride2; + uint32_t max_displacement = param().max_displacement; + + int paddedbottomheight = data1[2] + 2 * pad_size; + int paddedbottomwidth = data1[3] + 2 * pad_size; + uint32_t kernel_radius = (kernel_size - 1) / 2; + uint32_t border_size = max_displacement + kernel_radius; + uint32_t top_width = + ceil(static_cast(paddedbottomwidth - border_size * 2) / + static_cast(stride1)); + uint32_t top_height = + ceil(static_cast(paddedbottomheight - border_size * 2) / + static_cast(stride1)); + uint32_t neighborhood_grid_radius = max_displacement / stride2; + uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + uint32_t top_channels = neighborhood_grid_width * neighborhood_grid_width; + megdnn_assert(top_width >= 1 && top_height >= 1); + + dst = TensorLayout{{data1[0], top_channels, top_height, top_width}, + data1.dtype}; +} + +void CorrelationBase::check_layout_fwd(const TensorLayout& data1, + const TensorLayout& data2, + const TensorLayout& dst) { + TensorLayout dst_expected; + megdnn_assert_eq_dtype(data1, dst); + megdnn_assert_eq_shape(data1, data2); + deduce_layout_fwd(data1, data2, dst_expected); + megdnn_assert_eq_shape(dst_expected, dst); +} + +void CorrelationForward::deduce_layout(const TensorLayout& data1, + const TensorLayout& data2, + TensorLayout& dst) { + deduce_layout_fwd(data1, data2, dst); +} + +void CorrelationForward::check_exec(const TensorLayout& data1, + const TensorLayout& data2, + const TensorLayout& dst, + size_t workspace_in_bytes) { + check_layout_fwd(data1, data2, dst); + auto required_workspace_in_bytes = + get_workspace_in_bytes(data1, data2, dst); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +void CorrelationBackwardData1::check_exec(const TensorLayout& diff, + const TensorLayout& data1, + const TensorLayout& data2, + const TensorLayout& grad1, + size_t workspace_in_bytes) { + check_layout_fwd(grad1, data2, diff); + megdnn_assert_eq_shape(data1, data2); + auto required_workspace_in_bytes = + get_workspace_in_bytes(diff, data1, data2, grad1); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +void CorrelationBackwardData2::check_exec(const TensorLayout& diff, + const TensorLayout& data1, + const TensorLayout& data2, + const TensorLayout& grad2, + size_t workspace_in_bytes) { + check_layout_fwd(data1, grad2, diff); + megdnn_assert_eq_shape(data1, data2); + auto required_workspace_in_bytes = + get_workspace_in_bytes(diff, data1, data2, grad2); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +void CorrelationBackwardData2::deduce_layout(const TensorLayout& diff, + const TensorLayout& data1, + const TensorLayout& data2, + TensorLayout& grad) { + megdnn_assert_eq_shape(data1, data2); + check_layout_fwd(data1, data2, diff); + grad = data2; +} + +void CorrelationBackwardData1::deduce_layout(const TensorLayout& diff, + const TensorLayout& data1, + const TensorLayout& data2, + TensorLayout& grad) { + megdnn_assert_eq_shape(data1, data2); + check_layout_fwd(data1, data2, diff); + grad = data1; +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index c318931189f04e0ac8de65c691b01f640f72c369..e05449abfa8a96b4f5501213f77055870fc20ba5 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -194,6 +194,9 @@ private: cb(LocalShareBackwardFilter) \ cb(ROIAlignForward) \ cb(ROIAlignBackward) \ + cb(CorrelationForward) \ + cb(CorrelationBackwardData1) \ + cb(CorrelationBackwardData2) \ cb(BatchConvBiasForward) \ cb(Remap) \ cb(RemapBackwardData) \ diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 378a92796ff57892e6c60dfa7883a76c4ea5cba7..0642345e126b75bcb5a6d941c6ac7e9443b00b54 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -54,6 +54,9 @@ DEF(BNForward, 8, true, true); DEF(BNBackward, 8, true, false); DEF(ROIPoolingForward, 4, true, false); DEF(ROIPoolingBackward, 5, true, false); +DEF(CorrelationForward, 3, true, true); +DEF(CorrelationBackwardData1, 4, true, true); +DEF(CorrelationBackwardData2, 4, true, true); DEF(WarpPerspectiveForward, 3, true, false); DEF(WarpPerspectiveBackwardData, 3, true, false); DEF(WarpPerspectiveBackwardMat, 4, true, false); diff --git a/dnn/src/cuda/correlation/correlation_cuda.cu b/dnn/src/cuda/correlation/correlation_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..beb6db45a0a6faa3b738b6082a6f6582c883b8d2 --- /dev/null +++ b/dnn/src/cuda/correlation/correlation_cuda.cu @@ -0,0 +1,371 @@ +/** + * \file dnn/src/cuda/roi_align/roi_align.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/cuda/correlation/correlation_cuda.cuh" + +#include +#include "megdnn/dtype.h" +#include "src/cuda/query_blocksize.cuh" +#include "src/cuda/utils.cuh" +#define ROUND_OFF 50000 + +using namespace megdnn; +namespace megdnn { +namespace cuda { +namespace correlation { + +#define CUDA_KERNEL_LOOP(vtid, vthreads) \ + for (int vtid = blockIdx.x * blockDim.x + threadIdx.x; vtid < vthreads; \ + vtid += blockDim.x * gridDim.x) + +template +__global__ void forward_kernel(const int nthreads, const T* data1, + const T* data2, T* dst, const int bchannels, + const int bheight, const int bwidth, + const int tchannels, const int theight, + const int twidth, const int kernel_size, + const int max_displacement, const int stride1, + const int stride2, const int pad_size, + const bool is_multiply) { + CUDA_KERNEL_LOOP(idx, nthreads) { + int kernel_radius = (kernel_size - 1) / 2; + int neighborhood_grid_radius = max_displacement / stride2; + int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + + int x = idx % twidth; + int y = (idx / twidth) % theight; + int c = (idx / twidth / theight) % tchannels; + int n = idx / twidth / theight / tchannels; + + // get src center position in image1 + int x1 = x * stride1 + kernel_radius + max_displacement - pad_size; + int y1 = y * stride1 + kernel_radius + max_displacement - pad_size; + + // get offset of center in image2 + int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * + stride2; + int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * + stride2; + + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + // compute kernel correlation + T sum = T(0.f); + for (int i = -kernel_radius; i <= kernel_radius; i++) { + for (int j = -kernel_radius; j <= kernel_radius; j++) { + int in_x1 = x1 + i; + int in_y1 = y1 + j; + int in_x2 = x2 + i; + int in_y2 = y2 + j; + + for (int channel = 0; channel < bchannels; channel++) { + T tmp1 = T(0.f); + T tmp2 = T(0.f); + if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && + in_y1 < bheight) { + int idx1 = + ((n * bchannels + channel) * bheight + in_y1) * + bwidth + + in_x1; + tmp1 = data1[idx1]; + } + + if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && + in_y2 < bheight) { + int idx2 = + ((n * bchannels + channel) * bheight + in_y2) * + bwidth + + in_x2; + tmp2 = data2[idx2]; + } + if (is_multiply) { + sum += tmp1 * tmp2; + } else { + sum += fabsf(tmp1 - tmp2); + } + } + } + } + + const int sumelems = + (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; + dst[idx] = sum / sumelems; + } +} + +template +__global__ void backward_kernel_data1( + const int nthreads, const T* diff, const T* data1, const T* data2, + T* grad1, const int bchannels, const int bheight, const int bwidth, + const int tchannels, const int theight, const int twidth, + const int kernel_size, const int max_displacement, const int stride1, + const int stride2, const int pad_size, const bool is_multiply) { + CUDA_KERNEL_LOOP(idx, nthreads) { + int kernel_radius = (kernel_size - 1) / 2; + int neighborhood_grid_radius = max_displacement / stride2; + int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + + int x = idx % bwidth; + int y = (idx / bwidth) % bheight; + int c = (idx / bwidth / bheight) % bchannels; + int n = idx / bwidth / bheight / bchannels; + + T tmp1 = data1[idx]; + // Get X,Y ranges and clamp + // round_off is a trick to enable integer division with ceil, even for + // negative numbers We use a large offset, for the inner part not to + // become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = stride1 * round_off; + + // we show cal the x_min,y_min,x_max,y_max of diff for grad1(x,y) + // for diff_x_min, diff_y_min, x,y at the position of right-down + // ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1 + int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + + round_off_s1 - 1) / + stride1 + + 1 - round_off; + int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + + round_off_s1 - 1) / + stride1 + + 1 - round_off; + + // floor (l - max_displacement + pad_size) / stride1 + int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 - + round_off; + int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 - + round_off; + + T sum = T(0.f); + + if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && + (ymin <= theight - 1)) { + xmin = max(0, xmin); + xmax = min(twidth - 1, xmax); + + ymin = max(0, ymin); + ymax = min(theight - 1, ymax); + + for (int p = -neighborhood_grid_radius; + p <= neighborhood_grid_radius; p++) { + for (int o = -neighborhood_grid_radius; + o <= neighborhood_grid_radius; o++) { + // Get bottom1 data: + int s2o = stride2 * o; + int s2p = stride2 * p; + int x2 = x + s2o, y2 = y + s2p; + + int idx2 = + ((n * bchannels + c) * bheight + y2) * bwidth + x2; + T tmp2 = T(0.f); + + if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) { + tmp2 = data2[idx2]; + } + + int op = (p + neighborhood_grid_radius) * + neighborhood_grid_width + + (o + neighborhood_grid_radius); + int diff_channels_offset = (n * tchannels + op); + for (int diff_y = ymin; diff_y <= ymax; diff_y++) { + for (int diff_x = xmin; diff_x <= xmax; diff_x++) { + int idxtopdiff = + (diff_channels_offset * theight + diff_y) * + twidth + + diff_x; + + if (is_multiply) { + sum += diff[idxtopdiff] * tmp2; + } else { + T sign = (tmp1 >= tmp2) ? T(1.f) : T(-1.f); + sum += diff[idxtopdiff] * sign; + } + } + } + } + } + } + + const int sumelems = + (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; + grad1[idx] = sum / sumelems; + } +} + +template +__global__ void backward_kernel_data2( + const int nthreads, const T* diff, const T* data1, const T* data2, + T* grad2, const int bchannels, const int bheight, const int bwidth, + const int tchannels, const int theight, const int twidth, + const int kernel_size, const int max_displacement, const int stride1, + const int stride2, const int pad_size, const bool is_multiply) { + CUDA_KERNEL_LOOP(idx, nthreads) { + int kernel_radius = (kernel_size - 1) / 2; + int neighborhood_grid_radius = max_displacement / stride2; + int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + + int x = idx % bwidth; + int y = (idx / bwidth) % bheight; + int c = (idx / bwidth / bheight) % bchannels; + int n = idx / bwidth / bheight / bchannels; + + T tmp2 = data2[idx]; + + T sum = T(0.f); + + for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; + p++) { + for (int o = -neighborhood_grid_radius; + o <= neighborhood_grid_radius; o++) { + int s2o = o * stride2; + int s2p = p * stride2; + + int x1 = x - s2o; + int y1 = y - s2p; + + const int round_off = ROUND_OFF; + const int round_off_s1 = stride1 * round_off; + + int xmin = (x1 + pad_size - 2 * kernel_radius - + max_displacement + round_off_s1 - 1) / + stride1 + + 1 - round_off; + int ymin = (y1 + pad_size - 2 * kernel_radius - + max_displacement + round_off_s1 - 1) / + stride1 + + 1 - round_off; + int xmax = (x1 + pad_size - max_displacement + round_off_s1) / + stride1 - + round_off; + int ymax = (y1 + pad_size - max_displacement + round_off_s1) / + stride1 - + round_off; + + if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && + (ymin <= theight - 1)) { + xmin = max(0, xmin); + xmax = min(twidth - 1, xmax); + + ymin = max(0, ymin); + ymax = min(theight - 1, ymax); + + int idx1 = + ((n * bchannels + c) * bheight + y1) * bwidth + x1; + T tmp1 = T(0.f); + if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) { + tmp1 = data1[idx1]; + } + + int op = (p + neighborhood_grid_radius) * + neighborhood_grid_width + + (o + neighborhood_grid_radius); + int diff_channels_offset = (n * tchannels + op); + for (int diff_y = ymin; diff_y <= ymax; diff_y++) { + for (int diff_x = xmin; diff_x <= xmax; diff_x++) { + int idxtopdiff = + (diff_channels_offset * theight + diff_y) * + twidth + + diff_x; + + if (is_multiply) { + sum += diff[idxtopdiff] * tmp1; + } else { + T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f); + sum += diff[idxtopdiff] * sign; + } + } + } + } + } + } + + const int sumelems = + (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; + grad2[idx] = sum / sumelems; + } +} + +template +void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, + const int bchannels, const int bheight, const int bwidth, + const int tchannels, const int theight, const int twidth, + const int kernel_size, const int max_displacement, + const int stride1, const int stride2, const int pad_size, + const bool is_multiply, cudaStream_t stream) { + int threads_block = query_blocksize_for_kernel(forward_kernel); + forward_kernel + <<>>( + nthreads, data1, data2, dst, bchannels, bheight, bwidth, + tchannels, theight, twidth, kernel_size, max_displacement, + stride1, stride2, pad_size, is_multiply); + after_kernel_launch(); +} + +template +void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, + const T* data2, T* grad1, const int bchannels, + const int bheight, const int bwidth, + const int tchannels, const int theight, + const int twidth, const int kernel_size, + const int max_displacement, const int stride1, + const int stride2, const int pad_size, + const bool is_multiply, cudaStream_t stream) { + int threads_block = query_blocksize_for_kernel(backward_kernel_data1); + backward_kernel_data1 + <<>>( + nthreads, diff, data1, data2, grad1, bchannels, bheight, + bwidth, tchannels, theight, twidth, kernel_size, + max_displacement, stride1, stride2, pad_size, is_multiply); + after_kernel_launch(); +} + +template +void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, + const T* data2, T* grad2, const int bchannels, + const int bheight, const int bwidth, + const int tchannels, const int theight, + const int twidth, const int kernel_size, + const int max_displacement, const int stride1, + const int stride2, const int pad_size, + const bool is_multiply, cudaStream_t stream) { + int threads_block = query_blocksize_for_kernel(backward_kernel_data2); + backward_kernel_data2 + <<>>( + nthreads, diff, data1, data2, grad2, bchannels, bheight, + bwidth, tchannels, theight, twidth, kernel_size, + max_displacement, stride1, stride2, pad_size, is_multiply); + after_kernel_launch(); +} + +#define INST(T) \ + template void forward_proxy( \ + const int, const T*, const T*, T* dst, const int, const int, \ + const int, const int, const int, const int, const int, const int, \ + const int, const int, const int, const bool, cudaStream_t); \ + template void backward_proxy_data1( \ + const int, const T*, const T*, const T*, T*, const int, const int, \ + const int, const int, const int, const int, const int, const int, \ + const int, const int, const int, const bool, cudaStream_t); \ + template void backward_proxy_data2( \ + const int, const T*, const T*, const T*, T*, const int, const int, \ + const int, const int, const int, const int, const int, const int, \ + const int, const int, const int, const bool, cudaStream_t); +INST(dt_float32) +INST(dt_float16) +INST(dt_bfloat16) +#undef INST + +} // namespace roi_align +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/correlation/correlation_cuda.cuh b/dnn/src/cuda/correlation/correlation_cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3562abd0b8d121e474c877962dcc524f6822270f --- /dev/null +++ b/dnn/src/cuda/correlation/correlation_cuda.cuh @@ -0,0 +1,51 @@ +/** + * \file dnn/src/cuda/correlation/correlation.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include + +namespace megdnn { +namespace cuda { +namespace correlation { + +template +void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, + const int bchannels, const int bheight, const int bwidth, + const int tchannels, const int theight, const int twidth, + const int kernel_size, const int max_displacement, + const int stride1, const int stride2, const int pad_size, + const bool is_multiply, cudaStream_t stream); + +template +void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, + const T* data2, T* grad1, const int bchannels, + const int bheight, const int bwidth, + const int tchannels, const int theight, + const int twidth, const int kernel_size, + const int max_displacement, const int stride1, + const int stride2, const int pad_size, + const bool is_multiply, cudaStream_t stream); + +template +void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, + const T* data2, T* grad2, const int bchannels, + const int bheight, const int bwidth, + const int tchannels, const int theight, + const int twidth, const int kernel_size, + const int max_displacement, const int stride1, + const int stride2, const int pad_size, + const bool is_multiply, cudaStream_t stream); + +} // namespace correlation +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/correlation/opr_impl.cpp b/dnn/src/cuda/correlation/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..99a23fb177ecdc446c9fc00d50dc6b6de0cdc216 --- /dev/null +++ b/dnn/src/cuda/correlation/opr_impl.cpp @@ -0,0 +1,129 @@ +/** + * \file dnn/src/naive/correlation/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/cuda/correlation/opr_impl.h" +#include "src/cuda/correlation/correlation_cuda.cuh" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, + _megdnn_tensor_in data2, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec(data1.layout, data2.layout, dst.layout, workspace.size); + auto p = param(); + auto stream = cuda_stream(handle()); + int nthreads = dst.layout.total_nr_elems(); + int stride1 = p.stride1; + int stride2 = p.stride2; + int kernel_size = p.kernel_size; + int max_displacement = p.max_displacement; + int pad_size = p.pad_size; + bool is_multiply = p.is_multiply; + + int tchannels = dst.layout[1]; + int theight = dst.layout[2], twidth = dst.layout[3]; + int bchannels = data1.layout[1]; + int bheight = data1.layout[2], bwidth = data1.layout[3]; + using namespace ::megdnn::cuda::correlation; + +#define cb(DType) \ + if (data1.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + forward_proxy(nthreads, data1.ptr(), data2.ptr(), \ + dst.ptr(), bchannels, bheight, bwidth, tchannels, \ + theight, twidth, kernel_size, max_displacement, \ + stride1, stride2, pad_size, is_multiply, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb +} + +void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, + _megdnn_tensor_in data1, + _megdnn_tensor_in data2, + _megdnn_tensor_out grad1, + _megdnn_workspace workspace) { + check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, + workspace.size); + + auto stream = cuda_stream(handle()); + int nthreads = grad1.layout.total_nr_elems(); + int stride1 = param().stride1; + int stride2 = param().stride2; + int kernel_size = param().kernel_size; + int max_displacement = param().max_displacement; + int pad_size = param().pad_size; + bool is_multiply = param().is_multiply; + + int tchannels = diff.layout[1]; + int theight = diff.layout[2], twidth = diff.layout[3]; + int bchannels = data1.layout[1]; + int bheight = data1.layout[2], bwidth = data1.layout[3]; + + using namespace ::megdnn::cuda::correlation; + +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + backward_proxy_data1(nthreads, diff.ptr(), data1.ptr(), \ + data2.ptr(), grad1.ptr(), bchannels, \ + bheight, bwidth, tchannels, theight, twidth, \ + kernel_size, max_displacement, stride1, \ + stride2, pad_size, is_multiply, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb +} + +void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, + _megdnn_tensor_in data1, + _megdnn_tensor_in data2, + _megdnn_tensor_out grad2, + _megdnn_workspace workspace) { + check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, + workspace.size); + auto p = param(); + auto stream = cuda_stream(handle()); + int nthreads = grad2.layout.total_nr_elems(); + int stride1 = p.stride1; + int stride2 = p.stride2; + int kernel_size = p.kernel_size; + int max_displacement = p.max_displacement; + int pad_size = p.pad_size; + bool is_multiply = p.is_multiply; + + int tchannels = diff.layout[1]; + int theight = diff.layout[2], twidth = diff.layout[3]; + int bchannels = data1.layout[1]; + int bheight = data1.layout[2], bwidth = data1.layout[3]; + + using namespace ::megdnn::cuda::correlation; + +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + backward_proxy_data2(nthreads, diff.ptr(), data1.ptr(), \ + data2.ptr(), grad2.ptr(), bchannels, \ + bheight, bwidth, tchannels, theight, twidth, \ + kernel_size, max_displacement, stride1, \ + stride2, pad_size, is_multiply, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb +} + +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/correlation/opr_impl.h b/dnn/src/cuda/correlation/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..0fc31c481add0fc6ab5eebfce3bf8d34aa9cca5e --- /dev/null +++ b/dnn/src/cuda/correlation/opr_impl.h @@ -0,0 +1,61 @@ +/** + * \file dnn/src/naive/correlation/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/oprs.h" + +#include "src/cuda/cudnn_wrapper.h" + +namespace megdnn { +namespace cuda { + +class CorrelationForwardImpl final : public CorrelationForward { +public: + using CorrelationForward::CorrelationForward; + void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& data1, + const TensorLayout& data2, + const TensorLayout& dst) override { + return 0; + } +}; + +class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 { +public: + using CorrelationBackwardData1::CorrelationBackwardData1; + void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, + _megdnn_tensor_in data2, _megdnn_tensor_out grad1, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 { +public: + using CorrelationBackwardData2::CorrelationBackwardData2; + void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, + _megdnn_tensor_in data2, _megdnn_tensor_out grad2, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 3b9270390a71dc55b98fd6130c721630a8e58890..80ac59638c7fdc0be3288033341165ebacd6a0f8 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -24,6 +24,7 @@ #include "src/cuda/convolution/opr_impl.h" #include "src/cuda/convolution3d/opr_impl.h" #include "src/cuda/convpooling/opr_impl.h" +#include "src/cuda/correlation/opr_impl.h" #include "src/cuda/cumsum/opr_impl.h" #include "src/cuda/cvt_color/opr_impl.h" #include "src/cuda/dct/opr_impl.h" diff --git a/dnn/src/naive/correlation/opr_impl.cpp b/dnn/src/naive/correlation/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..99d04eb6330d410871fa17a61021612072736d9d --- /dev/null +++ b/dnn/src/naive/correlation/opr_impl.cpp @@ -0,0 +1,384 @@ +/** + * \file dnn/src/naive/correlation/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/naive/correlation/opr_impl.h" +#include +#include "src/common/utils.h" +#include "src/naive/handle.h" +#define ROUND_OFF 50000 +using namespace megdnn; +using namespace naive; +using namespace std; +namespace { + +using Param = megdnn::Correlation::Param; + +template +void forward(_megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out dst, const Param& param) { + // data1 treat as no-padding tensor + int total_nr_elems = dst.layout.total_nr_elems(); + + int stride1 = param.stride1, stride2 = param.stride2; + int kernel_size = param.kernel_size; + int kernel_radius = (kernel_size - 1) / 2; + int max_displacement = param.max_displacement; + int pad_size = param.pad_size; + + int tchannels = dst.layout[1]; + int theight = dst.layout[2], twidth = dst.layout[3]; + int bchannels = data1.layout[1]; + int bheight = data1.layout[2], bwidth = data1.layout[3]; + + int neighborhood_grid_radius = max_displacement / stride2; + int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + + for (int idx = 0; idx < total_nr_elems; ++idx) { + int x = idx % twidth; + int y = (idx / twidth) % theight; + int c = (idx / twidth / theight) % tchannels; + int n = idx / twidth / theight / tchannels; + + // get src center position in image1 + int x1 = x * stride1 + kernel_radius + max_displacement - pad_size; + int y1 = y * stride1 + kernel_radius + max_displacement - pad_size; + + // get offset of center in image2 + int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * + stride2; + int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * + stride2; + + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + // compute kernel correlation + float sum = 0.; + for (int i = -kernel_radius; i <= kernel_radius; i++) { + for (int j = -kernel_radius; j <= kernel_radius; j++) { + int in_x1 = x1 + i; + int in_y1 = y1 + j; + int in_x2 = x2 + i; + int in_y2 = y2 + j; + + for (int channel = 0; channel < bchannels; channel++) { + float tmp1 = 0.; + float tmp2 = 0.; + if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && + in_y1 < bheight) { + int idx1 = + ((n * bchannels + channel) * bheight + in_y1) * + bwidth + + in_x1; + tmp1 = data1.ptr()[idx1]; + } + + if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && + in_y2 < bheight) { + int idx2 = + ((n * bchannels + channel) * bheight + in_y2) * + bwidth + + in_x2; + tmp2 = data2.ptr()[idx2]; + } + + if (param.is_multiply) { + sum += tmp1 * tmp2; + } else { + sum += fabsf(tmp1 - tmp2); + } + } + } + } + + const int sumelems = + (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; + dst.ptr()[idx] = sum / sumelems; + } +} + +template +void backward_data1(_megdnn_tensor_in diff, _megdnn_tensor_in data1, + _megdnn_tensor_in data2, _megdnn_tensor_out grad1, + const Param& param) { + // data1 treat as no-padding tensor + // int total_nr_elems = diff.layout.total_nr_elems(); + int total_nr_elems = grad1.layout.total_nr_elems(); + + int stride1 = param.stride1, stride2 = param.stride2; + int kernel_size = param.kernel_size; + int kernel_radius = (kernel_size - 1) / 2; + int max_displacement = param.max_displacement; + int pad_size = param.pad_size; + + int tchannels = diff.layout[1]; + int theight = diff.layout[2], twidth = diff.layout[3]; + int bchannels = grad1.layout[1]; + int bheight = grad1.layout[2], bwidth = grad1.layout[3]; + + int neighborhood_grid_radius = max_displacement / stride2; + int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + + for (int idx = 0; idx < total_nr_elems; ++idx) { + // idx for grad1 + + int x = idx % bwidth; + int y = (idx / bwidth) % bheight; + int c = (idx / bwidth / bheight) % bchannels; + int n = idx / bwidth / bheight / bchannels; + + float tmp1 = data1.ptr()[idx]; + // Get X,Y ranges and clamp + // round_off is a trick to enable integer division with ceil, even for + // negative numbers We use a large offset, for the inner part not to + // become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = stride1 * round_off; + + // we show cal the x_min,y_min,x_max,y_max of diff for grad1(x,y) + // for diff_x_min, diff_y_min, x,y at the position of right-down + // ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1 + int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + + round_off_s1 - 1) / + stride1 + + 1 - round_off; + int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + + round_off_s1 - 1) / + stride1 + + 1 - round_off; + // floor (l - max_displacement + pad_size) / stride1 + int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 - + round_off; + int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 - + round_off; + + float sum = 0.; + if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && + (ymin <= theight - 1)) { + xmin = max(0, xmin); + xmax = min(twidth - 1, xmax); + + ymin = max(0, ymin); + ymax = min(theight - 1, ymax); + + for (int p = -neighborhood_grid_radius; + p <= neighborhood_grid_radius; p++) { + for (int o = -neighborhood_grid_radius; + o <= neighborhood_grid_radius; o++) { + // Get bottom1 data: + int s2o = stride2 * o; + int s2p = stride2 * p; + int x2 = x + s2p, y2 = y + s2o; + + int idx2 = + ((n * bchannels + c) * bheight + y2) * bwidth + x2; + float tmp2 = 0.; + + if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) { + tmp2 = data2.ptr()[idx2]; + } + + int op = (p + neighborhood_grid_radius) * + neighborhood_grid_width + + (o + neighborhood_grid_radius); + int diff_channels_offset = (n * tchannels + op); + + for (int diff_y = ymin; diff_y <= ymax; diff_y++) { + for (int diff_x = xmin; diff_x <= xmax; diff_x++) { + int idxtopdiff = + (diff_channels_offset * theight + diff_y) * + twidth + + diff_x; + + if (param.is_multiply) { + sum += diff.ptr()[idxtopdiff] * tmp2; + } else { + T sign = (tmp1 > tmp2) ? T(1.) : T(-1.); + sum += diff.ptr()[idxtopdiff] * sign; + } + } + } + } + } + } + + const int sumelems = + (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; + grad1.ptr()[idx] = sum / sumelems; + } +} + +template +void backward_data2(_megdnn_tensor_in diff, _megdnn_tensor_in data1, + _megdnn_tensor_in data2, _megdnn_tensor_out grad2, + const Param& param) { + // data1 treat as no-padding tensor + int total_nr_elems = grad2.layout.total_nr_elems(); + + int stride1 = param.stride1, stride2 = param.stride2; + int kernel_size = param.kernel_size; + int kernel_radius = (kernel_size - 1) / 2; + int max_displacement = param.max_displacement; + int pad_size = param.pad_size; + + int tchannels = diff.layout[1]; + int theight = diff.layout[2], twidth = diff.layout[3]; + int bchannels = grad2.layout[1]; + int bheight = grad2.layout[2], bwidth = grad2.layout[3]; + + int neighborhood_grid_radius = max_displacement / stride2; + int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + + for (int idx = 0; idx < total_nr_elems; ++idx) { + int x = idx % bwidth; + int y = (idx / bwidth) % bheight; + int c = (idx / bwidth / bheight) % bchannels; + int n = idx / bwidth / bheight / bchannels; + + T tmp2 = data2.ptr()[idx]; + + T sum = T(0.f); + + for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; + p++) { + for (int o = -neighborhood_grid_radius; + o <= neighborhood_grid_radius; o++) { + int s2o = o * stride2; + int s2p = p * stride2; + + int x1 = x - s2o; + int y1 = y - s2p; + + const int round_off = ROUND_OFF; + const int round_off_s1 = stride1 * round_off; + + int xmin = (x1 + pad_size - 2 * kernel_radius - + max_displacement + round_off_s1 - 1) / + stride1 + + 1 - round_off; + int ymin = (y1 + pad_size - 2 * kernel_radius - + max_displacement + round_off_s1 - 1) / + stride1 + + 1 - round_off; + int xmax = (x1 + pad_size - max_displacement + round_off_s1) / + stride1 - + round_off; + int ymax = (y1 + pad_size - max_displacement + round_off_s1) / + stride1 - + round_off; + + if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && + (ymin <= theight - 1)) { + xmin = max(0, xmin); + xmax = min(twidth - 1, xmax); + + ymin = max(0, ymin); + ymax = min(theight - 1, ymax); + + int idx1 = + ((n * bchannels + c) * bheight + y1) * bwidth + x1; + T tmp1 = T(0.f); + if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) { + tmp1 = data1.ptr()[idx1]; + } + + int op = (p + neighborhood_grid_radius) * + neighborhood_grid_width + + (o + neighborhood_grid_radius); + int diff_channels_offset = (n * tchannels + op); + for (int diff_y = ymin; diff_y <= ymax; diff_y++) { + for (int diff_x = xmin; diff_x <= xmax; diff_x++) { + int idxtopdiff = + (diff_channels_offset * theight + diff_y) * + twidth + + diff_x; + + if (param.is_multiply) { + sum += diff.ptr()[idxtopdiff] * tmp1; + } else { + T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f); + sum += diff.ptr()[idxtopdiff] * sign; + } + } + } + } + } + } + + const int sumelems = + (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; + grad2.ptr()[idx] = sum / sumelems; + } +} + +} // namespace + +namespace megdnn { +namespace naive { + +void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, + _megdnn_tensor_in data2, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec(data1.layout, data2.layout, dst.layout, workspace.size); +#define cb(DType) \ + if (data1.layout.dtype == DType()) { \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + forward::ctype>(data1, data2, dst, \ + param())); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + megdnn_throw("bad dtype"); +} + +void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, + _megdnn_tensor_in data1, + _megdnn_tensor_in data2, + _megdnn_tensor_out grad1, + _megdnn_workspace workspace) { + check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, + workspace.size); +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + backward_data1::ctype>( \ + diff, data1, data2, grad1, param())); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + megdnn_throw("bad dtype"); +} + +void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, + _megdnn_tensor_in data1, + _megdnn_tensor_in data2, + _megdnn_tensor_out grad2, + _megdnn_workspace workspace) { + check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, + workspace.size); +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + backward_data2::ctype>( \ + diff, data1, data2, grad2, param())); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + megdnn_throw("bad dtype"); +} + +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/correlation/opr_impl.h b/dnn/src/naive/correlation/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..157928958e4933f7392ca4691472a29ae8c12639 --- /dev/null +++ b/dnn/src/naive/correlation/opr_impl.h @@ -0,0 +1,58 @@ +/** + * \file dnn/src/naive/correlation/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class CorrelationForwardImpl final : public CorrelationForward { +public: + using CorrelationForward::CorrelationForward; + void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 { +public: + using CorrelationBackwardData1::CorrelationBackwardData1; + void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, + _megdnn_tensor_in data2, _megdnn_tensor_out grad1, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 { +public: + using CorrelationBackwardData2::CorrelationBackwardData2; + void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, + _megdnn_tensor_in data2, _megdnn_tensor_out grad2, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 0c386de8e9d4d33a6b6477fe85e77391a05b21fe..e641737b33d1f6fb995209f2f7561ca9f01b8e65 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -30,6 +30,7 @@ #include "src/naive/convpooling/opr_impl.h" #include "src/naive/cumsum/opr_impl.h" #include "src/naive/cvt_color/opr_impl.h" +#include "src/naive/correlation/opr_impl.h" #include "src/naive/dct/opr_impl.h" #include "src/naive/deformable_conv/opr_impl.h" #include "src/naive/deformable_ps_roi_pooling/opr_impl.h" diff --git a/dnn/test/common/correlation.h b/dnn/test/common/correlation.h new file mode 100644 index 0000000000000000000000000000000000000000..37375124496b778421ef5e6404cc35452a4e08ae --- /dev/null +++ b/dnn/test/common/correlation.h @@ -0,0 +1,73 @@ +/** + * \file dnn/test/common/correlation.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/basic_types.h" +#include "megdnn/opr_param_defs.h" + +namespace megdnn { +namespace test { +namespace correlation { + +struct TestArg { + param::Correlation param; + TensorShape data1, data2; + TestArg(param::Correlation param, TensorShape data1, TensorShape data2) + : param(param), data1(data1), data2(data2) {} +}; + +inline static std::vector get_args() { + std::vector args; + + param::Correlation cur_param; + for (size_t batch_size : {2}) { + for (size_t channel : {2}) { + for (size_t height : {160}) { + for (size_t width : {160}) { + cur_param.is_multiply = true; + cur_param.kernel_size = 3; + cur_param.max_displacement = 3; + cur_param.pad_size = 0; + cur_param.stride1 = 1; + cur_param.stride2 = 1; + cur_param.format = megdnn::param::Correlation::Format::NCHW; + + args.emplace_back( + cur_param, + TensorShape{batch_size, channel, height, width}, + TensorShape{batch_size, channel, height, width}); + + // cur_param.is_multiply = false; + // cur_param.kernel_size = 1; + // cur_param.max_displacement = 2; + // cur_param.pad_size = 1; + // cur_param.stride1 = 1; + // cur_param.stride2 = 1; + // cur_param.format = + // megdnn::param::Correlation::Format::NCHW; + + // args.emplace_back( + // cur_param, + // TensorShape{batch_size, channel, height, width}, + // TensorShape{batch_size, channel, height, width}); + } + } + } + } + + return args; +} + +} // namespace correlation +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/cuda/correlation.cpp b/dnn/test/cuda/correlation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..64ff64b83256d4f10ecc08a6c47acdb9645d8280 --- /dev/null +++ b/dnn/test/cuda/correlation.cpp @@ -0,0 +1,160 @@ +/** + * \file dnn/test/cuda/correlation.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "test/cuda/fixture.h" + +#include "test/common/checker.h" +#include "test/common/correlation.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, CORRELATION_FORWARD) { + using namespace correlation; + std::vector args = get_args(); + Checker checker(handle_cuda()); + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .execs({arg.data1, arg.data2, {}}); + } +} + +TEST_F(CUDA, CORRELATION_BACKWARDDATA1) { + ConstValue const_0{0}; + using Param = CorrelationBackwardData1::Param; + Param param; + param.is_multiply = true; + param.format = Param::Format::NCHW; + param.stride1 = 2; + param.stride2 = 2; + param.kernel_size = 3; + param.pad_size = 4; + Checker checker(handle_cuda()); + checker.set_epsilon(1e-2); + + uint32_t pad_size = param.pad_size; + uint32_t kernel_size = param.kernel_size; + uint32_t stride1 = param.stride1; + uint32_t stride2 = param.stride2; + uint32_t max_displacement = param.max_displacement; + + auto run = [&](DType dtype) { + for (size_t N : {1, 3}) + for (size_t C : {1, 3}) + for (size_t OH : {10, 100}) + for (size_t OW : {10, 100}) { + int paddedbottomheight = OH + 2 * pad_size; + int paddedbottomwidth = OW + 2 * pad_size; + uint32_t kernel_radius = (kernel_size - 1) / 2; + uint32_t border_size = max_displacement + kernel_radius; + uint32_t top_width = + ceil(static_cast(paddedbottomwidth - + border_size * 2) / + static_cast(stride1)); + uint32_t top_height = + ceil(static_cast(paddedbottomheight - + border_size * 2) / + static_cast(stride1)); + uint32_t neighborhood_grid_radius = + max_displacement / stride2; + uint32_t neighborhood_grid_width = + neighborhood_grid_radius * 2 + 1; + uint32_t top_channels = neighborhood_grid_width * + neighborhood_grid_width; + + checker.set_param(param) + .set_dtype(0, dtype) + .set_dtype(1, dtype) + .set_dtype(2, dtype) + .set_dtype(3, dtype) + .execs({{N, top_channels, top_height, + top_width}, + {N, C, OH, OW}, + {N, C, OH, OW}, + {N, C, OH, OW}}); + } + }; + + run(dtype::Float32()); + run(dtype::Float16()); + checker.set_epsilon(5e-2); + run(dtype::BFloat16()); +} + +TEST_F(CUDA, CORRELATION_BACKWARDDATA2) { + ConstValue const_0{0}; + using Param = CorrelationBackwardData2::Param; + Param param; + param.is_multiply = true; + param.format = Param::Format::NCHW; + param.stride1 = 2; + param.stride2 = 2; + param.kernel_size = 3; + param.pad_size = 4; + Checker checker(handle_cuda()); + checker.set_epsilon(1e-2); + + uint32_t pad_size = param.pad_size; + uint32_t kernel_size = param.kernel_size; + uint32_t stride1 = param.stride1; + uint32_t stride2 = param.stride2; + uint32_t max_displacement = param.max_displacement; + + auto run = [&](DType dtype) { + for (size_t N : {1, 3}) + for (size_t C : {1, 3}) + for (size_t OH : {10, 100}) + for (size_t OW : {10, 100}) { + int paddedbottomheight = OH + 2 * pad_size; + int paddedbottomwidth = OW + 2 * pad_size; + uint32_t kernel_radius = (kernel_size - 1) / 2; + uint32_t border_size = max_displacement + kernel_radius; + uint32_t top_width = + ceil(static_cast(paddedbottomwidth - + border_size * 2) / + static_cast(stride1)); + uint32_t top_height = + ceil(static_cast(paddedbottomheight - + border_size * 2) / + static_cast(stride1)); + uint32_t neighborhood_grid_radius = + max_displacement / stride2; + uint32_t neighborhood_grid_width = + neighborhood_grid_radius * 2 + 1; + uint32_t top_channels = neighborhood_grid_width * + neighborhood_grid_width; + + checker.set_param(param) + .set_dtype(0, dtype) + .set_dtype(1, dtype) + .set_dtype(2, dtype) + .set_dtype(3, dtype) + .execs({{N, top_channels, top_height, + top_width}, + {N, C, OH, OW}, + {N, C, OH, OW}, + {N, C, OH, OW}}); + } + }; + + run(dtype::Float32()); + run(dtype::Float16()); + checker.set_epsilon(5e-2); + run(dtype::BFloat16()); +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen