提交 1997b1a2 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add correlation kernel

GitOrigin-RevId: 25e58b61e66856e7926afaa7d4e6c8147e800ae9
上级 6d376623
......@@ -192,6 +192,87 @@ class ReduceForward: public OperatorBase {
};
using Reduce = ReduceForward;
class CorrelationBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(CorrelationBase, OperatorBase);
DEF_OPR_PARAM(Correlation);
protected:
void deduce_layout_fwd(const TensorLayout& data1, const TensorLayout& data2,
TensorLayout& dst);
void check_layout_fwd(const TensorLayout& data1, const TensorLayout& data2,
const TensorLayout& dst);
};
class CorrelationForward : public CorrelationBase {
DEF_OPR_IMPL(CorrelationForward, CorrelationBase, 2, 1);
public:
/**
* \param[in] data1 (n, c, ih, iw)
* \param[in] data2 (n, c, ih, iw)
* \param[out] dst (n, q, oh, ow), q is the number of neighborhood
* */
virtual void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& data1, const TensorLayout& data2,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& data1,
const TensorLayout& data2,
const TensorLayout& dst) = 0;
protected:
void check_exec(const TensorLayout& data1, const TensorLayout& data2,
const TensorLayout& dst, size_t workspace_in_bytes);
};
using Correlation = CorrelationForward;
class CorrelationBackwardData1 : public CorrelationBase {
DEF_OPR_IMPL(CorrelationBackwardData1, CorrelationBase, 3, 1);
public:
/**
* \param[in] diff the backpropagated gradient wrt. dst
* \param[in] data1 the `data1' parameter in CorrelationForward::exec
* \param[in] data2 the `data2' parameter in CorrelationForward::exec
* \param[out] grad1 the backpropagated gradient wrt. data1
*/
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out grad1, _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1,
const TensorLayout& data2, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& data1,
const TensorLayout& data2,
const TensorLayout& grad1) = 0;
protected:
void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
const TensorLayout& grad1, size_t workspace_in_bytes);
};
class CorrelationBackwardData2 : public CorrelationBase {
DEF_OPR_IMPL(CorrelationBackwardData2, CorrelationBase, 3, 1);
public:
/**
* \param[in] diff the backpropagated gradient wrt. dst
* \param[in] data1 the `data1' parameter in CorrelationForward::exec
* \param[in] data2 the `data2' parameter in CorrelationForward::exec
* \param[out] grad2 the backpropagated gradient wrt. data2
*/
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out grad2, _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1,
const TensorLayout& data2, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& data1,
const TensorLayout& data2,
const TensorLayout& grad2) = 0;
protected:
void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
const TensorLayout& grad2, size_t workspace_in_bytes);
};
class CumsumForward: public OperatorBase {
DEF_OPR_PARAM(Cumsum);
DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1);
......
......@@ -1053,6 +1053,16 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
'sample_width', '2')
)
(pdef('Correlation').
add_enum_alias('Format', 'ConvolutionV0').
add_fields('uint32', 'kernel_size', '1').
add_fields('uint32', 'max_displacement', '1').
add_fields('uint32', 'stride1', '1').
add_fields('uint32', 'stride2', '1').
add_fields('uint32', 'pad_size', '0').
add_fields('bool', 'is_multiply', 'true')
)
(pdef('DeformablePSROIPooling').
add_fields('bool', 'no_trans', 'true').
add_fields('float32', 'spatial_scale', 1,
......
/**
* \file dnn/src/common/correlation.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void CorrelationBase::deduce_layout_fwd(const TensorLayout& data1,
const TensorLayout& data2,
TensorLayout& dst) {
megdnn_assert_contiguous(data1);
megdnn_assert_contiguous(data2);
megdnn_assert_contiguous(dst);
auto errmsg = [&]() {
return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) +
", " + megdnn_layout_msg(dst);
};
MEGDNN_MARK_USED_VAR(errmsg);
using Format = CorrelationBase::Param::Format;
megdnn_assert(param().format == Format::NCHW);
auto data1_dtype = data1.dtype, data2_dtype = data2.dtype;
megdnn_assert(data1_dtype == data2_dtype &&
data1_dtype.category() == DTypeCategory::FLOAT);
megdnn_assert(data1.ndim == 4_z, "%s", errmsg().c_str());
megdnn_assert(data2.ndim == 4_z, "%s", errmsg().c_str());
uint32_t pad_size = param().pad_size;
uint32_t kernel_size = param().kernel_size;
uint32_t stride1 = param().stride1;
uint32_t stride2 = param().stride2;
uint32_t max_displacement = param().max_displacement;
int paddedbottomheight = data1[2] + 2 * pad_size;
int paddedbottomwidth = data1[3] + 2 * pad_size;
uint32_t kernel_radius = (kernel_size - 1) / 2;
uint32_t border_size = max_displacement + kernel_radius;
uint32_t top_width =
ceil(static_cast<float>(paddedbottomwidth - border_size * 2) /
static_cast<float>(stride1));
uint32_t top_height =
ceil(static_cast<float>(paddedbottomheight - border_size * 2) /
static_cast<float>(stride1));
uint32_t neighborhood_grid_radius = max_displacement / stride2;
uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
uint32_t top_channels = neighborhood_grid_width * neighborhood_grid_width;
megdnn_assert(top_width >= 1 && top_height >= 1);
dst = TensorLayout{{data1[0], top_channels, top_height, top_width},
data1.dtype};
}
void CorrelationBase::check_layout_fwd(const TensorLayout& data1,
const TensorLayout& data2,
const TensorLayout& dst) {
TensorLayout dst_expected;
megdnn_assert_eq_dtype(data1, dst);
megdnn_assert_eq_shape(data1, data2);
deduce_layout_fwd(data1, data2, dst_expected);
megdnn_assert_eq_shape(dst_expected, dst);
}
void CorrelationForward::deduce_layout(const TensorLayout& data1,
const TensorLayout& data2,
TensorLayout& dst) {
deduce_layout_fwd(data1, data2, dst);
}
void CorrelationForward::check_exec(const TensorLayout& data1,
const TensorLayout& data2,
const TensorLayout& dst,
size_t workspace_in_bytes) {
check_layout_fwd(data1, data2, dst);
auto required_workspace_in_bytes =
get_workspace_in_bytes(data1, data2, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void CorrelationBackwardData1::check_exec(const TensorLayout& diff,
const TensorLayout& data1,
const TensorLayout& data2,
const TensorLayout& grad1,
size_t workspace_in_bytes) {
check_layout_fwd(grad1, data2, diff);
megdnn_assert_eq_shape(data1, data2);
auto required_workspace_in_bytes =
get_workspace_in_bytes(diff, data1, data2, grad1);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void CorrelationBackwardData2::check_exec(const TensorLayout& diff,
const TensorLayout& data1,
const TensorLayout& data2,
const TensorLayout& grad2,
size_t workspace_in_bytes) {
check_layout_fwd(data1, grad2, diff);
megdnn_assert_eq_shape(data1, data2);
auto required_workspace_in_bytes =
get_workspace_in_bytes(diff, data1, data2, grad2);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void CorrelationBackwardData2::deduce_layout(const TensorLayout& diff,
const TensorLayout& data1,
const TensorLayout& data2,
TensorLayout& grad) {
megdnn_assert_eq_shape(data1, data2);
check_layout_fwd(data1, data2, diff);
grad = data2;
}
void CorrelationBackwardData1::deduce_layout(const TensorLayout& diff,
const TensorLayout& data1,
const TensorLayout& data2,
TensorLayout& grad) {
megdnn_assert_eq_shape(data1, data2);
check_layout_fwd(data1, data2, diff);
grad = data1;
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -194,6 +194,9 @@ private:
cb(LocalShareBackwardFilter) \
cb(ROIAlignForward) \
cb(ROIAlignBackward) \
cb(CorrelationForward) \
cb(CorrelationBackwardData1) \
cb(CorrelationBackwardData2) \
cb(BatchConvBiasForward) \
cb(Remap) \
cb(RemapBackwardData) \
......
......@@ -54,6 +54,9 @@ DEF(BNForward, 8, true, true);
DEF(BNBackward, 8, true, false);
DEF(ROIPoolingForward, 4, true, false);
DEF(ROIPoolingBackward, 5, true, false);
DEF(CorrelationForward, 3, true, true);
DEF(CorrelationBackwardData1, 4, true, true);
DEF(CorrelationBackwardData2, 4, true, true);
DEF(WarpPerspectiveForward, 3, true, false);
DEF(WarpPerspectiveBackwardData, 3, true, false);
DEF(WarpPerspectiveBackwardMat, 4, true, false);
......
/**
* \file dnn/src/cuda/roi_align/roi_align.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/correlation/correlation_cuda.cuh"
#include <cfloat>
#include "megdnn/dtype.h"
#include "src/cuda/query_blocksize.cuh"
#include "src/cuda/utils.cuh"
#define ROUND_OFF 50000
using namespace megdnn;
namespace megdnn {
namespace cuda {
namespace correlation {
#define CUDA_KERNEL_LOOP(vtid, vthreads) \
for (int vtid = blockIdx.x * blockDim.x + threadIdx.x; vtid < vthreads; \
vtid += blockDim.x * gridDim.x)
template <typename T>
__global__ void forward_kernel(const int nthreads, const T* data1,
const T* data2, T* dst, const int bchannels,
const int bheight, const int bwidth,
const int tchannels, const int theight,
const int twidth, const int kernel_size,
const int max_displacement, const int stride1,
const int stride2, const int pad_size,
const bool is_multiply) {
CUDA_KERNEL_LOOP(idx, nthreads) {
int kernel_radius = (kernel_size - 1) / 2;
int neighborhood_grid_radius = max_displacement / stride2;
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
int x = idx % twidth;
int y = (idx / twidth) % theight;
int c = (idx / twidth / theight) % tchannels;
int n = idx / twidth / theight / tchannels;
// get src center position in image1
int x1 = x * stride1 + kernel_radius + max_displacement - pad_size;
int y1 = y * stride1 + kernel_radius + max_displacement - pad_size;
// get offset of center in image2
int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) *
stride2;
int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) *
stride2;
int x2 = x1 + s2o;
int y2 = y1 + s2p;
// compute kernel correlation
T sum = T(0.f);
for (int i = -kernel_radius; i <= kernel_radius; i++) {
for (int j = -kernel_radius; j <= kernel_radius; j++) {
int in_x1 = x1 + i;
int in_y1 = y1 + j;
int in_x2 = x2 + i;
int in_y2 = y2 + j;
for (int channel = 0; channel < bchannels; channel++) {
T tmp1 = T(0.f);
T tmp2 = T(0.f);
if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 &&
in_y1 < bheight) {
int idx1 =
((n * bchannels + channel) * bheight + in_y1) *
bwidth +
in_x1;
tmp1 = data1[idx1];
}
if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 &&
in_y2 < bheight) {
int idx2 =
((n * bchannels + channel) * bheight + in_y2) *
bwidth +
in_x2;
tmp2 = data2[idx2];
}
if (is_multiply) {
sum += tmp1 * tmp2;
} else {
sum += fabsf(tmp1 - tmp2);
}
}
}
}
const int sumelems =
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels;
dst[idx] = sum / sumelems;
}
}
template <typename T>
__global__ void backward_kernel_data1(
const int nthreads, const T* diff, const T* data1, const T* data2,
T* grad1, const int bchannels, const int bheight, const int bwidth,
const int tchannels, const int theight, const int twidth,
const int kernel_size, const int max_displacement, const int stride1,
const int stride2, const int pad_size, const bool is_multiply) {
CUDA_KERNEL_LOOP(idx, nthreads) {
int kernel_radius = (kernel_size - 1) / 2;
int neighborhood_grid_radius = max_displacement / stride2;
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
int x = idx % bwidth;
int y = (idx / bwidth) % bheight;
int c = (idx / bwidth / bheight) % bchannels;
int n = idx / bwidth / bheight / bchannels;
T tmp1 = data1[idx];
// Get X,Y ranges and clamp
// round_off is a trick to enable integer division with ceil, even for
// negative numbers We use a large offset, for the inner part not to
// become negative.
const int round_off = ROUND_OFF;
const int round_off_s1 = stride1 * round_off;
// we show cal the x_min,y_min,x_max,y_max of diff for grad1(x,y)
// for diff_x_min, diff_y_min, x,y at the position of right-down
// ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1
int xmin = (x + pad_size - 2 * kernel_radius - max_displacement +
round_off_s1 - 1) /
stride1 +
1 - round_off;
int ymin = (y + pad_size - 2 * kernel_radius - max_displacement +
round_off_s1 - 1) /
stride1 +
1 - round_off;
// floor (l - max_displacement + pad_size) / stride1
int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 -
round_off;
int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 -
round_off;
T sum = T(0.f);
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) &&
(ymin <= theight - 1)) {
xmin = max(0, xmin);
xmax = min(twidth - 1, xmax);
ymin = max(0, ymin);
ymax = min(theight - 1, ymax);
for (int p = -neighborhood_grid_radius;
p <= neighborhood_grid_radius; p++) {
for (int o = -neighborhood_grid_radius;
o <= neighborhood_grid_radius; o++) {
// Get bottom1 data:
int s2o = stride2 * o;
int s2p = stride2 * p;
int x2 = x + s2o, y2 = y + s2p;
int idx2 =
((n * bchannels + c) * bheight + y2) * bwidth + x2;
T tmp2 = T(0.f);
if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) {
tmp2 = data2[idx2];
}
int op = (p + neighborhood_grid_radius) *
neighborhood_grid_width +
(o + neighborhood_grid_radius);
int diff_channels_offset = (n * tchannels + op);
for (int diff_y = ymin; diff_y <= ymax; diff_y++) {
for (int diff_x = xmin; diff_x <= xmax; diff_x++) {
int idxtopdiff =
(diff_channels_offset * theight + diff_y) *
twidth +
diff_x;
if (is_multiply) {
sum += diff[idxtopdiff] * tmp2;
} else {
T sign = (tmp1 >= tmp2) ? T(1.f) : T(-1.f);
sum += diff[idxtopdiff] * sign;
}
}
}
}
}
}
const int sumelems =
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels;
grad1[idx] = sum / sumelems;
}
}
template <typename T>
__global__ void backward_kernel_data2(
const int nthreads, const T* diff, const T* data1, const T* data2,
T* grad2, const int bchannels, const int bheight, const int bwidth,
const int tchannels, const int theight, const int twidth,
const int kernel_size, const int max_displacement, const int stride1,
const int stride2, const int pad_size, const bool is_multiply) {
CUDA_KERNEL_LOOP(idx, nthreads) {
int kernel_radius = (kernel_size - 1) / 2;
int neighborhood_grid_radius = max_displacement / stride2;
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
int x = idx % bwidth;
int y = (idx / bwidth) % bheight;
int c = (idx / bwidth / bheight) % bchannels;
int n = idx / bwidth / bheight / bchannels;
T tmp2 = data2[idx];
T sum = T(0.f);
for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius;
p++) {
for (int o = -neighborhood_grid_radius;
o <= neighborhood_grid_radius; o++) {
int s2o = o * stride2;
int s2p = p * stride2;
int x1 = x - s2o;
int y1 = y - s2p;
const int round_off = ROUND_OFF;
const int round_off_s1 = stride1 * round_off;
int xmin = (x1 + pad_size - 2 * kernel_radius -
max_displacement + round_off_s1 - 1) /
stride1 +
1 - round_off;
int ymin = (y1 + pad_size - 2 * kernel_radius -
max_displacement + round_off_s1 - 1) /
stride1 +
1 - round_off;
int xmax = (x1 + pad_size - max_displacement + round_off_s1) /
stride1 -
round_off;
int ymax = (y1 + pad_size - max_displacement + round_off_s1) /
stride1 -
round_off;
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) &&
(ymin <= theight - 1)) {
xmin = max(0, xmin);
xmax = min(twidth - 1, xmax);
ymin = max(0, ymin);
ymax = min(theight - 1, ymax);
int idx1 =
((n * bchannels + c) * bheight + y1) * bwidth + x1;
T tmp1 = T(0.f);
if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) {
tmp1 = data1[idx1];
}
int op = (p + neighborhood_grid_radius) *
neighborhood_grid_width +
(o + neighborhood_grid_radius);
int diff_channels_offset = (n * tchannels + op);
for (int diff_y = ymin; diff_y <= ymax; diff_y++) {
for (int diff_x = xmin; diff_x <= xmax; diff_x++) {
int idxtopdiff =
(diff_channels_offset * theight + diff_y) *
twidth +
diff_x;
if (is_multiply) {
sum += diff[idxtopdiff] * tmp1;
} else {
T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f);
sum += diff[idxtopdiff] * sign;
}
}
}
}
}
}
const int sumelems =
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels;
grad2[idx] = sum / sumelems;
}
}
template <typename T>
void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst,
const int bchannels, const int bheight, const int bwidth,
const int tchannels, const int theight, const int twidth,
const int kernel_size, const int max_displacement,
const int stride1, const int stride2, const int pad_size,
const bool is_multiply, cudaStream_t stream) {
int threads_block = query_blocksize_for_kernel(forward_kernel<T>);
forward_kernel<T>
<<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>(
nthreads, data1, data2, dst, bchannels, bheight, bwidth,
tchannels, theight, twidth, kernel_size, max_displacement,
stride1, stride2, pad_size, is_multiply);
after_kernel_launch();
}
template <typename T>
void backward_proxy_data1(const int nthreads, const T* diff, const T* data1,
const T* data2, T* grad1, const int bchannels,
const int bheight, const int bwidth,
const int tchannels, const int theight,
const int twidth, const int kernel_size,
const int max_displacement, const int stride1,
const int stride2, const int pad_size,
const bool is_multiply, cudaStream_t stream) {
int threads_block = query_blocksize_for_kernel(backward_kernel_data1<T>);
backward_kernel_data1<T>
<<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>(
nthreads, diff, data1, data2, grad1, bchannels, bheight,
bwidth, tchannels, theight, twidth, kernel_size,
max_displacement, stride1, stride2, pad_size, is_multiply);
after_kernel_launch();
}
template <typename T>
void backward_proxy_data2(const int nthreads, const T* diff, const T* data1,
const T* data2, T* grad2, const int bchannels,
const int bheight, const int bwidth,
const int tchannels, const int theight,
const int twidth, const int kernel_size,
const int max_displacement, const int stride1,
const int stride2, const int pad_size,
const bool is_multiply, cudaStream_t stream) {
int threads_block = query_blocksize_for_kernel(backward_kernel_data2<T>);
backward_kernel_data2<T>
<<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>(
nthreads, diff, data1, data2, grad2, bchannels, bheight,
bwidth, tchannels, theight, twidth, kernel_size,
max_displacement, stride1, stride2, pad_size, is_multiply);
after_kernel_launch();
}
#define INST(T) \
template void forward_proxy<T>( \
const int, const T*, const T*, T* dst, const int, const int, \
const int, const int, const int, const int, const int, const int, \
const int, const int, const int, const bool, cudaStream_t); \
template void backward_proxy_data1<T>( \
const int, const T*, const T*, const T*, T*, const int, const int, \
const int, const int, const int, const int, const int, const int, \
const int, const int, const int, const bool, cudaStream_t); \
template void backward_proxy_data2<T>( \
const int, const T*, const T*, const T*, T*, const int, const int, \
const int, const int, const int, const int, const int, const int, \
const int, const int, const int, const bool, cudaStream_t);
INST(dt_float32)
INST(dt_float16)
INST(dt_bfloat16)
#undef INST
} // namespace roi_align
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/correlation/correlation.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <cuda_runtime_api.h>
namespace megdnn {
namespace cuda {
namespace correlation {
template <typename T>
void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst,
const int bchannels, const int bheight, const int bwidth,
const int tchannels, const int theight, const int twidth,
const int kernel_size, const int max_displacement,
const int stride1, const int stride2, const int pad_size,
const bool is_multiply, cudaStream_t stream);
template <typename T>
void backward_proxy_data1(const int nthreads, const T* diff, const T* data1,
const T* data2, T* grad1, const int bchannels,
const int bheight, const int bwidth,
const int tchannels, const int theight,
const int twidth, const int kernel_size,
const int max_displacement, const int stride1,
const int stride2, const int pad_size,
const bool is_multiply, cudaStream_t stream);
template <typename T>
void backward_proxy_data2(const int nthreads, const T* diff, const T* data1,
const T* data2, T* grad2, const int bchannels,
const int bheight, const int bwidth,
const int tchannels, const int theight,
const int twidth, const int kernel_size,
const int max_displacement, const int stride1,
const int stride2, const int pad_size,
const bool is_multiply, cudaStream_t stream);
} // namespace correlation
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/correlation/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/correlation/opr_impl.h"
#include "src/cuda/correlation/correlation_cuda.cuh"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
void CorrelationForwardImpl::exec(_megdnn_tensor_in data1,
_megdnn_tensor_in data2,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(data1.layout, data2.layout, dst.layout, workspace.size);
auto p = param();
auto stream = cuda_stream(handle());
int nthreads = dst.layout.total_nr_elems();
int stride1 = p.stride1;
int stride2 = p.stride2;
int kernel_size = p.kernel_size;
int max_displacement = p.max_displacement;
int pad_size = p.pad_size;
bool is_multiply = p.is_multiply;
int tchannels = dst.layout[1];
int theight = dst.layout[2], twidth = dst.layout[3];
int bchannels = data1.layout[1];
int bheight = data1.layout[2], bwidth = data1.layout[3];
using namespace ::megdnn::cuda::correlation;
#define cb(DType) \
if (data1.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
forward_proxy<T>(nthreads, data1.ptr<T>(), data2.ptr<T>(), \
dst.ptr<T>(), bchannels, bheight, bwidth, tchannels, \
theight, twidth, kernel_size, max_displacement, \
stride1, stride2, pad_size, is_multiply, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
}
void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in data1,
_megdnn_tensor_in data2,
_megdnn_tensor_out grad1,
_megdnn_workspace workspace) {
check_exec(diff.layout, data1.layout, data2.layout, grad1.layout,
workspace.size);
auto stream = cuda_stream(handle());
int nthreads = grad1.layout.total_nr_elems();
int stride1 = param().stride1;
int stride2 = param().stride2;
int kernel_size = param().kernel_size;
int max_displacement = param().max_displacement;
int pad_size = param().pad_size;
bool is_multiply = param().is_multiply;
int tchannels = diff.layout[1];
int theight = diff.layout[2], twidth = diff.layout[3];
int bchannels = data1.layout[1];
int bheight = data1.layout[2], bwidth = data1.layout[3];
using namespace ::megdnn::cuda::correlation;
#define cb(DType) \
if (diff.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
backward_proxy_data1<T>(nthreads, diff.ptr<T>(), data1.ptr<T>(), \
data2.ptr<T>(), grad1.ptr<T>(), bchannels, \
bheight, bwidth, tchannels, theight, twidth, \
kernel_size, max_displacement, stride1, \
stride2, pad_size, is_multiply, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
}
void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in data1,
_megdnn_tensor_in data2,
_megdnn_tensor_out grad2,
_megdnn_workspace workspace) {
check_exec(diff.layout, data1.layout, data2.layout, grad2.layout,
workspace.size);
auto p = param();
auto stream = cuda_stream(handle());
int nthreads = grad2.layout.total_nr_elems();
int stride1 = p.stride1;
int stride2 = p.stride2;
int kernel_size = p.kernel_size;
int max_displacement = p.max_displacement;
int pad_size = p.pad_size;
bool is_multiply = p.is_multiply;
int tchannels = diff.layout[1];
int theight = diff.layout[2], twidth = diff.layout[3];
int bchannels = data1.layout[1];
int bheight = data1.layout[2], bwidth = data1.layout[3];
using namespace ::megdnn::cuda::correlation;
#define cb(DType) \
if (diff.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
backward_proxy_data2<T>(nthreads, diff.ptr<T>(), data1.ptr<T>(), \
data2.ptr<T>(), grad2.ptr<T>(), bchannels, \
bheight, bwidth, tchannels, theight, twidth, \
kernel_size, max_displacement, stride1, \
stride2, pad_size, is_multiply, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/correlation/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/cudnn_wrapper.h"
namespace megdnn {
namespace cuda {
class CorrelationForwardImpl final : public CorrelationForward {
public:
using CorrelationForward::CorrelationForward;
void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& data1,
const TensorLayout& data2,
const TensorLayout& dst) override {
return 0;
}
};
class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 {
public:
using CorrelationBackwardData1::CorrelationBackwardData1;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1,
_megdnn_tensor_in data2, _megdnn_tensor_out grad1,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 {
public:
using CorrelationBackwardData2::CorrelationBackwardData2;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1,
_megdnn_tensor_in data2, _megdnn_tensor_out grad2,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -24,6 +24,7 @@
#include "src/cuda/convolution/opr_impl.h"
#include "src/cuda/convolution3d/opr_impl.h"
#include "src/cuda/convpooling/opr_impl.h"
#include "src/cuda/correlation/opr_impl.h"
#include "src/cuda/cumsum/opr_impl.h"
#include "src/cuda/cvt_color/opr_impl.h"
#include "src/cuda/dct/opr_impl.h"
......
/**
* \file dnn/src/naive/correlation/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/correlation/opr_impl.h"
#include <algorithm>
#include "src/common/utils.h"
#include "src/naive/handle.h"
#define ROUND_OFF 50000
using namespace megdnn;
using namespace naive;
using namespace std;
namespace {
using Param = megdnn::Correlation::Param;
template <typename T>
void forward(_megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out dst, const Param& param) {
// data1 treat as no-padding tensor
int total_nr_elems = dst.layout.total_nr_elems();
int stride1 = param.stride1, stride2 = param.stride2;
int kernel_size = param.kernel_size;
int kernel_radius = (kernel_size - 1) / 2;
int max_displacement = param.max_displacement;
int pad_size = param.pad_size;
int tchannels = dst.layout[1];
int theight = dst.layout[2], twidth = dst.layout[3];
int bchannels = data1.layout[1];
int bheight = data1.layout[2], bwidth = data1.layout[3];
int neighborhood_grid_radius = max_displacement / stride2;
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
for (int idx = 0; idx < total_nr_elems; ++idx) {
int x = idx % twidth;
int y = (idx / twidth) % theight;
int c = (idx / twidth / theight) % tchannels;
int n = idx / twidth / theight / tchannels;
// get src center position in image1
int x1 = x * stride1 + kernel_radius + max_displacement - pad_size;
int y1 = y * stride1 + kernel_radius + max_displacement - pad_size;
// get offset of center in image2
int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) *
stride2;
int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) *
stride2;
int x2 = x1 + s2o;
int y2 = y1 + s2p;
// compute kernel correlation
float sum = 0.;
for (int i = -kernel_radius; i <= kernel_radius; i++) {
for (int j = -kernel_radius; j <= kernel_radius; j++) {
int in_x1 = x1 + i;
int in_y1 = y1 + j;
int in_x2 = x2 + i;
int in_y2 = y2 + j;
for (int channel = 0; channel < bchannels; channel++) {
float tmp1 = 0.;
float tmp2 = 0.;
if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 &&
in_y1 < bheight) {
int idx1 =
((n * bchannels + channel) * bheight + in_y1) *
bwidth +
in_x1;
tmp1 = data1.ptr<T>()[idx1];
}
if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 &&
in_y2 < bheight) {
int idx2 =
((n * bchannels + channel) * bheight + in_y2) *
bwidth +
in_x2;
tmp2 = data2.ptr<T>()[idx2];
}
if (param.is_multiply) {
sum += tmp1 * tmp2;
} else {
sum += fabsf(tmp1 - tmp2);
}
}
}
}
const int sumelems =
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels;
dst.ptr<T>()[idx] = sum / sumelems;
}
}
template <typename T>
void backward_data1(_megdnn_tensor_in diff, _megdnn_tensor_in data1,
_megdnn_tensor_in data2, _megdnn_tensor_out grad1,
const Param& param) {
// data1 treat as no-padding tensor
// int total_nr_elems = diff.layout.total_nr_elems();
int total_nr_elems = grad1.layout.total_nr_elems();
int stride1 = param.stride1, stride2 = param.stride2;
int kernel_size = param.kernel_size;
int kernel_radius = (kernel_size - 1) / 2;
int max_displacement = param.max_displacement;
int pad_size = param.pad_size;
int tchannels = diff.layout[1];
int theight = diff.layout[2], twidth = diff.layout[3];
int bchannels = grad1.layout[1];
int bheight = grad1.layout[2], bwidth = grad1.layout[3];
int neighborhood_grid_radius = max_displacement / stride2;
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
for (int idx = 0; idx < total_nr_elems; ++idx) {
// idx for grad1
int x = idx % bwidth;
int y = (idx / bwidth) % bheight;
int c = (idx / bwidth / bheight) % bchannels;
int n = idx / bwidth / bheight / bchannels;
float tmp1 = data1.ptr<T>()[idx];
// Get X,Y ranges and clamp
// round_off is a trick to enable integer division with ceil, even for
// negative numbers We use a large offset, for the inner part not to
// become negative.
const int round_off = ROUND_OFF;
const int round_off_s1 = stride1 * round_off;
// we show cal the x_min,y_min,x_max,y_max of diff for grad1(x,y)
// for diff_x_min, diff_y_min, x,y at the position of right-down
// ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1
int xmin = (x + pad_size - 2 * kernel_radius - max_displacement +
round_off_s1 - 1) /
stride1 +
1 - round_off;
int ymin = (y + pad_size - 2 * kernel_radius - max_displacement +
round_off_s1 - 1) /
stride1 +
1 - round_off;
// floor (l - max_displacement + pad_size) / stride1
int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 -
round_off;
int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 -
round_off;
float sum = 0.;
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) &&
(ymin <= theight - 1)) {
xmin = max(0, xmin);
xmax = min(twidth - 1, xmax);
ymin = max(0, ymin);
ymax = min(theight - 1, ymax);
for (int p = -neighborhood_grid_radius;
p <= neighborhood_grid_radius; p++) {
for (int o = -neighborhood_grid_radius;
o <= neighborhood_grid_radius; o++) {
// Get bottom1 data:
int s2o = stride2 * o;
int s2p = stride2 * p;
int x2 = x + s2p, y2 = y + s2o;
int idx2 =
((n * bchannels + c) * bheight + y2) * bwidth + x2;
float tmp2 = 0.;
if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) {
tmp2 = data2.ptr<T>()[idx2];
}
int op = (p + neighborhood_grid_radius) *
neighborhood_grid_width +
(o + neighborhood_grid_radius);
int diff_channels_offset = (n * tchannels + op);
for (int diff_y = ymin; diff_y <= ymax; diff_y++) {
for (int diff_x = xmin; diff_x <= xmax; diff_x++) {
int idxtopdiff =
(diff_channels_offset * theight + diff_y) *
twidth +
diff_x;
if (param.is_multiply) {
sum += diff.ptr<T>()[idxtopdiff] * tmp2;
} else {
T sign = (tmp1 > tmp2) ? T(1.) : T(-1.);
sum += diff.ptr<T>()[idxtopdiff] * sign;
}
}
}
}
}
}
const int sumelems =
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels;
grad1.ptr<T>()[idx] = sum / sumelems;
}
}
template <typename T>
void backward_data2(_megdnn_tensor_in diff, _megdnn_tensor_in data1,
_megdnn_tensor_in data2, _megdnn_tensor_out grad2,
const Param& param) {
// data1 treat as no-padding tensor
int total_nr_elems = grad2.layout.total_nr_elems();
int stride1 = param.stride1, stride2 = param.stride2;
int kernel_size = param.kernel_size;
int kernel_radius = (kernel_size - 1) / 2;
int max_displacement = param.max_displacement;
int pad_size = param.pad_size;
int tchannels = diff.layout[1];
int theight = diff.layout[2], twidth = diff.layout[3];
int bchannels = grad2.layout[1];
int bheight = grad2.layout[2], bwidth = grad2.layout[3];
int neighborhood_grid_radius = max_displacement / stride2;
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
for (int idx = 0; idx < total_nr_elems; ++idx) {
int x = idx % bwidth;
int y = (idx / bwidth) % bheight;
int c = (idx / bwidth / bheight) % bchannels;
int n = idx / bwidth / bheight / bchannels;
T tmp2 = data2.ptr<T>()[idx];
T sum = T(0.f);
for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius;
p++) {
for (int o = -neighborhood_grid_radius;
o <= neighborhood_grid_radius; o++) {
int s2o = o * stride2;
int s2p = p * stride2;
int x1 = x - s2o;
int y1 = y - s2p;
const int round_off = ROUND_OFF;
const int round_off_s1 = stride1 * round_off;
int xmin = (x1 + pad_size - 2 * kernel_radius -
max_displacement + round_off_s1 - 1) /
stride1 +
1 - round_off;
int ymin = (y1 + pad_size - 2 * kernel_radius -
max_displacement + round_off_s1 - 1) /
stride1 +
1 - round_off;
int xmax = (x1 + pad_size - max_displacement + round_off_s1) /
stride1 -
round_off;
int ymax = (y1 + pad_size - max_displacement + round_off_s1) /
stride1 -
round_off;
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) &&
(ymin <= theight - 1)) {
xmin = max(0, xmin);
xmax = min(twidth - 1, xmax);
ymin = max(0, ymin);
ymax = min(theight - 1, ymax);
int idx1 =
((n * bchannels + c) * bheight + y1) * bwidth + x1;
T tmp1 = T(0.f);
if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) {
tmp1 = data1.ptr<T>()[idx1];
}
int op = (p + neighborhood_grid_radius) *
neighborhood_grid_width +
(o + neighborhood_grid_radius);
int diff_channels_offset = (n * tchannels + op);
for (int diff_y = ymin; diff_y <= ymax; diff_y++) {
for (int diff_x = xmin; diff_x <= xmax; diff_x++) {
int idxtopdiff =
(diff_channels_offset * theight + diff_y) *
twidth +
diff_x;
if (param.is_multiply) {
sum += diff.ptr<T>()[idxtopdiff] * tmp1;
} else {
T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f);
sum += diff.ptr<T>()[idxtopdiff] * sign;
}
}
}
}
}
}
const int sumelems =
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels;
grad2.ptr<T>()[idx] = sum / sumelems;
}
}
} // namespace
namespace megdnn {
namespace naive {
void CorrelationForwardImpl::exec(_megdnn_tensor_in data1,
_megdnn_tensor_in data2,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(data1.layout, data2.layout, dst.layout, workspace.size);
#define cb(DType) \
if (data1.layout.dtype == DType()) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
forward<typename DTypeTrait<DType>::ctype>(data1, data2, dst, \
param())); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in data1,
_megdnn_tensor_in data2,
_megdnn_tensor_out grad1,
_megdnn_workspace workspace) {
check_exec(diff.layout, data1.layout, data2.layout, grad1.layout,
workspace.size);
#define cb(DType) \
if (diff.layout.dtype == DType()) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
backward_data1<typename DTypeTrait<DType>::ctype>( \
diff, data1, data2, grad1, param())); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in data1,
_megdnn_tensor_in data2,
_megdnn_tensor_out grad2,
_megdnn_workspace workspace) {
check_exec(diff.layout, data1.layout, data2.layout, grad2.layout,
workspace.size);
#define cb(DType) \
if (diff.layout.dtype == DType()) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
backward_data2<typename DTypeTrait<DType>::ctype>( \
diff, data1, data2, grad2, param())); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/correlation/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class CorrelationForwardImpl final : public CorrelationForward {
public:
using CorrelationForward::CorrelationForward;
void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 {
public:
using CorrelationBackwardData1::CorrelationBackwardData1;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1,
_megdnn_tensor_in data2, _megdnn_tensor_out grad1,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 {
public:
using CorrelationBackwardData2::CorrelationBackwardData2;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1,
_megdnn_tensor_in data2, _megdnn_tensor_out grad2,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -30,6 +30,7 @@
#include "src/naive/convpooling/opr_impl.h"
#include "src/naive/cumsum/opr_impl.h"
#include "src/naive/cvt_color/opr_impl.h"
#include "src/naive/correlation/opr_impl.h"
#include "src/naive/dct/opr_impl.h"
#include "src/naive/deformable_conv/opr_impl.h"
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h"
......
/**
* \file dnn/test/common/correlation.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
namespace megdnn {
namespace test {
namespace correlation {
struct TestArg {
param::Correlation param;
TensorShape data1, data2;
TestArg(param::Correlation param, TensorShape data1, TensorShape data2)
: param(param), data1(data1), data2(data2) {}
};
inline static std::vector<TestArg> get_args() {
std::vector<TestArg> args;
param::Correlation cur_param;
for (size_t batch_size : {2}) {
for (size_t channel : {2}) {
for (size_t height : {160}) {
for (size_t width : {160}) {
cur_param.is_multiply = true;
cur_param.kernel_size = 3;
cur_param.max_displacement = 3;
cur_param.pad_size = 0;
cur_param.stride1 = 1;
cur_param.stride2 = 1;
cur_param.format = megdnn::param::Correlation::Format::NCHW;
args.emplace_back(
cur_param,
TensorShape{batch_size, channel, height, width},
TensorShape{batch_size, channel, height, width});
// cur_param.is_multiply = false;
// cur_param.kernel_size = 1;
// cur_param.max_displacement = 2;
// cur_param.pad_size = 1;
// cur_param.stride1 = 1;
// cur_param.stride2 = 1;
// cur_param.format =
// megdnn::param::Correlation::Format::NCHW;
// args.emplace_back(
// cur_param,
// TensorShape{batch_size, channel, height, width},
// TensorShape{batch_size, channel, height, width});
}
}
}
}
return args;
}
} // namespace correlation
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/cuda/correlation.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/cuda/fixture.h"
#include "test/common/checker.h"
#include "test/common/correlation.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, CORRELATION_FORWARD) {
using namespace correlation;
std::vector<TestArg> args = get_args();
Checker<Correlation> checker(handle_cuda());
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.execs({arg.data1, arg.data2, {}});
}
}
TEST_F(CUDA, CORRELATION_BACKWARDDATA1) {
ConstValue const_0{0};
using Param = CorrelationBackwardData1::Param;
Param param;
param.is_multiply = true;
param.format = Param::Format::NCHW;
param.stride1 = 2;
param.stride2 = 2;
param.kernel_size = 3;
param.pad_size = 4;
Checker<CorrelationBackwardData1> checker(handle_cuda());
checker.set_epsilon(1e-2);
uint32_t pad_size = param.pad_size;
uint32_t kernel_size = param.kernel_size;
uint32_t stride1 = param.stride1;
uint32_t stride2 = param.stride2;
uint32_t max_displacement = param.max_displacement;
auto run = [&](DType dtype) {
for (size_t N : {1, 3})
for (size_t C : {1, 3})
for (size_t OH : {10, 100})
for (size_t OW : {10, 100}) {
int paddedbottomheight = OH + 2 * pad_size;
int paddedbottomwidth = OW + 2 * pad_size;
uint32_t kernel_radius = (kernel_size - 1) / 2;
uint32_t border_size = max_displacement + kernel_radius;
uint32_t top_width =
ceil(static_cast<float>(paddedbottomwidth -
border_size * 2) /
static_cast<float>(stride1));
uint32_t top_height =
ceil(static_cast<float>(paddedbottomheight -
border_size * 2) /
static_cast<float>(stride1));
uint32_t neighborhood_grid_radius =
max_displacement / stride2;
uint32_t neighborhood_grid_width =
neighborhood_grid_radius * 2 + 1;
uint32_t top_channels = neighborhood_grid_width *
neighborhood_grid_width;
checker.set_param(param)
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype)
.set_dtype(3, dtype)
.execs({{N, top_channels, top_height,
top_width},
{N, C, OH, OW},
{N, C, OH, OW},
{N, C, OH, OW}});
}
};
run(dtype::Float32());
run(dtype::Float16());
checker.set_epsilon(5e-2);
run(dtype::BFloat16());
}
TEST_F(CUDA, CORRELATION_BACKWARDDATA2) {
ConstValue const_0{0};
using Param = CorrelationBackwardData2::Param;
Param param;
param.is_multiply = true;
param.format = Param::Format::NCHW;
param.stride1 = 2;
param.stride2 = 2;
param.kernel_size = 3;
param.pad_size = 4;
Checker<CorrelationBackwardData2> checker(handle_cuda());
checker.set_epsilon(1e-2);
uint32_t pad_size = param.pad_size;
uint32_t kernel_size = param.kernel_size;
uint32_t stride1 = param.stride1;
uint32_t stride2 = param.stride2;
uint32_t max_displacement = param.max_displacement;
auto run = [&](DType dtype) {
for (size_t N : {1, 3})
for (size_t C : {1, 3})
for (size_t OH : {10, 100})
for (size_t OW : {10, 100}) {
int paddedbottomheight = OH + 2 * pad_size;
int paddedbottomwidth = OW + 2 * pad_size;
uint32_t kernel_radius = (kernel_size - 1) / 2;
uint32_t border_size = max_displacement + kernel_radius;
uint32_t top_width =
ceil(static_cast<float>(paddedbottomwidth -
border_size * 2) /
static_cast<float>(stride1));
uint32_t top_height =
ceil(static_cast<float>(paddedbottomheight -
border_size * 2) /
static_cast<float>(stride1));
uint32_t neighborhood_grid_radius =
max_displacement / stride2;
uint32_t neighborhood_grid_width =
neighborhood_grid_radius * 2 + 1;
uint32_t top_channels = neighborhood_grid_width *
neighborhood_grid_width;
checker.set_param(param)
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype)
.set_dtype(3, dtype)
.execs({{N, top_channels, top_height,
top_width},
{N, C, OH, OW},
{N, C, OH, OW},
{N, C, OH, OW}});
}
};
run(dtype::Float32());
run(dtype::Float16());
checker.set_epsilon(5e-2);
run(dtype::BFloat16());
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册